1、决策树的工作原理
(1)找到划分数据的特征,作为决策点
(2)利用找到的特征对数据进行划分成n个数据子集。
(3)如果同一个子集中的数据属于同一类型就不再划分,如果不属于同一类型,继续利用特征进行划分。
(4)指导每一个子集的数据属于同一类型停止划分。
2、决策树的优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关的特征数据
缺点:可能产生过度匹配的问题
适用数据类型:数值型(需要进行离散化)和标称型
ps:产生过度匹配的原因:
过度匹配的现象:一个假设在训练数据上能够获得比其他假设更好的拟合,但是在训练数据外的数据集上却不能很好的拟合数据。此时我们就叫这个假设出现了overfitting的现象。
原因:在决策树模型搭建中,我们使用的算法对于决策树的生长没有合理的限制和修剪的话,决策树的*生长有可能每片叶子里只包含单纯的事件数据或非事件数据,可以想象,这种决策树当然可以完美匹配(拟合)训练数据,但是一旦应用到新的业务真实数据时,效果是一塌糊涂。
ps:数值型和标称型数据
标称型:标称型目标变量的结果只在有限目标集中取值,如真与假(标称型目标变量主要用于分类)
数值型:数值型目标变量则可以从无限的数值集合中取值,如0.100,42.001等 (数值型目标变量主要用于回归分析)
3、决策树创建分支createBranch()的伪代码实现
检测数据集中的每个子项是否属于同一分类
if so return 类标签
else
寻找划分数据集的最好特征
划分数据集
创建分支节点
for 每个划分的子集
调用createBranch并增加返回结果到分支节点中(递归调用)
return 分支节点
4、决策树的目标就是将散乱的数据进行划分成有序的数据,那么这个划分前后信息的变化就是信息增益,也就是信息熵
那么对于每个类别分类前后都有相应的信息增益,所以就要计算所有类别的信息期望值
(n表示分类的数目)
下面用具体的代码实现平均信息熵的计算过程:
from math import log
import operator
#计算所有已经分类好的子集的信息商
def calcShannonEnt(dataSet):
#计算给定的数据集的长度,也就是类别的数目
numEntries = len(dataSet)
#创建一个空字典
labelCounts = {}
#遍历所有分类的子集
for featVec in dataSet: #the the number of unique elements and their occurance
#取出每个子集的键值,也就是对应的类标签,-1的索引值表示最后一个
currentLabel = featVec[-1]
#如果当前类标签不在标签库中,就将当前子集的标签加入到标签库中
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
#如果已经存在于标签库中,就将标签库对应的加1
labelCounts[currentLabel] += 1
shannonEnt = 0.0
#遍历标签库中所有的标签
for key in labelCounts:
#根据键值对取出每个类别的次数/总的类别数,也就是每个类别的概率p(i)
prob = float(labelCounts[key])/numEntries
#计算所有类别的期望值,也就是平均信息熵
shannonEnt -= prob * log(prob,2) #log base 2
return shannonEnt
下面创建一个已经分类好的数据集
#创建一个数据集
def createDataSet():
dataSet = [[1, 1, \'yes\'],
[1, 1, \'yes\'],
[1, 0, \'no\'],
[0, 1, \'no\'],
[0, 1, \'no\']]
labels = [\'no surfacing\',\'flippers\']
#change to discrete values
return dataSet, labels
再命令行里面进行测试,计算这些数据集的平均信息熵
>>> import tree
>>> myDat,labels=tree.createDataSet()
>>> myDat
[[1, 1, \'yes\'], [1, 1, \'yes\'], [1, 0, \'no\'], [0, 1, \'no\'], [0, 1, \'no\']]
>>> labels
[\'no surfacing\', \'flippers\']
>>> tree.calcShannonEnt()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: calcShannonEnt() takes exactly 1 argument (0 given)
>>> tree.calcShannonEnt(myDat)
0.9709505944546686
>>>
熵越高,则混合的数据也越多,我们可以在数据集中添加更多的分类,观察熵是如何变化的。这里我们就不尝试了,到这里我们已经学会了如何计算数据集的无序程度
d=[]
a = [[1, 1, \'yes\'],
[1, 1, \'yes\'],
[1, 0, \'no\'],
[0, 1, \'no\'],
[0, 1, \'no\']]
for i in a:
# print i[:2]
# print "---"
# print i[:]#所有的元素
# print i[0:]#所有的元素
# print i[1:]#第二列之后所有的 取值为1 ,。。。。
# print i[2:]#第三列之后所有的 取值为2,。。。。
# print i[:1]#第一列 取值为 0
# print i[:2]#前两列 取值为0,1
print i[:3]#前三列 取值为0,1,2
要知道矩阵里面的[:]是左闭右开区间
a = [[1, 1, \'yes\'],
[1, 1, \'yes\'],
[1, 0, \'no\'],
[0, 1, \'no\'],
[0, 1, \'no\']]
for i in a:
if i[1] == 0:
b = i[:1]
print b
print "=="
b.extend(i[2:])
print b
print "==="
d.append(b)
print d
[1]
==
[1, \'no\']
===
[[1, \'no\']]
相当于判断每个元素的第二个字符是否等于0,如果等于,则将这个元素的剩下的字符 组成新的列表矩阵
下面给出按照给定的标准划分数据集的代码
看了上面的简单的例子相信下面的函数应该很容易懂了
#按照给定的标准对数据进行划分
#dataSet:待划分的数据集
#axis:划分的标准
#value:返回值
def splitDataSet(dataSet, axis, value):
# python吾言在函数中传递的是列表的引用,在函数内部对列表对象的修改。
# 将会影响该列表对象的整个生存周期。为了消除这个不良
# 影响,我们需要在函数的开始声明一个新列表对象
# 因为该函数代码在同一数据集上被调用多次,
# 为了不修改原始数据集,创建一个新的列表对象0
retDataSet = []
#遍历数据集中的每个元素,这里的每个元素也是一个列表
for featVec in dataSet:
#如果满足分类标准
#axis=0 value=1 如果每个元素的第一个字符为1
#axis=0 value=0 如果每个元素的第一个字符为0
if featVec[axis] == value:
#取出前axis列的数据
reducedFeatVec = featVec[:axis] #chop out axis used for splitting
# list.append(object)向列表中添加一个对象object
# list.extend(sequence)把一个序列seq的内容添加到列表中
#把featVec元素的axis+1列后面的数据取出来添加到reducedFeatVec
reducedFeatVec.extend(featVec[axis+1:])
#将reducedFeatVec作为一个对象添加到retDataSet
retDataSet.append(reducedFeatVec)
return retDataSet
下面我们在命令行里面进行测试
>>> import tree
>>> myDat,labels = tree.createDataSet()
>>> myDat
[[1, 1, \'yes\'], [1, 1, \'yes\'], [1, 0, \'no\'], [0, 1, \'no\'], [0, 1, \'no\']]
>>> tree.splitDataSet(myDat,0,1)
[[1, \'yes\'], [1, \'yes\'], [0, \'no\']]
>>> tree.splitDataSet(myDat,0,0)
[[1, \'no\'], [1, \'no\']]
>>> tree.splitDataSet(myDat,1,0)
[[1, \'no\']]
>>> tree.splitDataSet(myDat,1,0)
可以看出我们的分类标准不一样,最终的结果也就不一样,这一步是根据我们的标准选出符合我们确定的标准的数据
现在我们可以循环计算分类后的香农熵以及splitDataSet()函数来寻找最好的分类标准
#计算所有已经分类好的子集的信息商
def calcShannonEnt(dataSet):
#计算给定的数据集的长度,也就是类别的数目
numEntries = len(dataSet)
#创建一个空字典
labelCounts = {}
#遍历所有分类的子集
for featVec in dataSet: #the the number of unique elements and their occurance
#取出每个子集的键值,也就是对应的类标签,-1的索引值表示最后一个
currentLabel = featVec[-1]
#如果当前类标签不在标签库中,就将当前子集的标签加入到标签库中
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
#如果已经存在于标签库中,就将标签库对应的加1
labelCounts[currentLabel] += 1
shannonEnt = 0.0
#遍历标签库中所有的标签
for key in labelCounts:
#根据键值对取出每个类别的次数/总的类别数,也就是每个类别的概率p(i)
prob = float(labelCounts[key])/numEntries
#计算所有类别的期望值,也就是平均信息熵
shannonEnt -= prob * log(prob,2) #log base 2
return shannonEnt
#寻找最好的分类标准
def chooseBestFeatureToSplit(dataSet):
#计算原始数据的特征属性的个数=len-标签列
numFeatures = len(dataSet[0]) - 1 #the last column is used for the labels
#计算原始数据的原始熵
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
#遍历每个特征属性,遍历每一列
for i in range(numFeatures): #iterate over all the features
#遍历数据集中每行除去最后一列标签的每个元素
featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
#取出每行的特征属性存入set集合
uniqueVals = set(featList) #get a set of unique values
newEntropy = 0.0
#遍历每个特征属性
for value in uniqueVals:
#dataset:带分类的数据集 i:分类标准 value:返回值
#取出符合分类标准的数据
subDataSet = splitDataSet(dataSet, i, value)
#符合分类标准的数据长度/数据集的总长度=p(i)
prob = len(subDataSet)/float(len(dataSet))
#计算分类之后平均信息熵---》香农熵
newEntropy += prob * calcShannonEnt(subDataSet)
#分类前后的香农熵之差
infoGain = baseEntropy - newEntropy #calculate the info gain; ie reduction in entropy
#如果差大于最好的时候的差
if (infoGain > bestInfoGain): #compare this to the best gain so far
#则继续分类,注意的是香农熵越大,越无序,所以我们要找的是差值最大的时候,也就是分类最有序的时候
#将前一次的赋值给bestInfoGain,下一次和前一次比较,如果下一次的差小于前者则停止
bestInfoGain = infoGain #if better than current best, set to best
#将当前的分类标准i赋值给bestFeature
bestFeature = i
return bestFeature
我们在命令行里面测试
>>> reload(tree)
<module \'tree\' from \'E:\python2.7\tree.py\'>
>>> myDat,labels=trees.createDataSet()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NameError: name \'trees\' is not defined
>>> myDat,labels=tree.createDataSet()
>>> myDat
[[1, 1, \'yes\'], [1, 1, \'yes\'], [1, 0, \'no\'], [0, 1, \'no\'], [0, 1, \'no\']]
>>> tree.chooseBestFeatureToSplit(myDat)
0
>>>
我们通过chooseBestFeatureToSplit()得到第0个特征是最好的分类标准
subDataSet = splitDataSet(dataSet, i, value) 也就是这里的i就是选取的分类标准 value就是第i个特征值的值
#选出次数最多的类别名称
def majorityCnt(classList):
#创建一个空字典,用于统计每个类别出现的次数
classCount={}
#遍历所有分类的子集中类别的出现的次数
for vote in classList:
#如果子集中没有该类别标签,则将该类别添加到字典classCount中
if vote not in classCount.keys():
classCount[vote] = 0
#否则将该字典里面的标签的次数加1
classCount[vote] += 1
#iteritems()返回一个迭代器,工作效率高,不需要额外的内存
#items()返回字典的所有项,以列表的形式返回
#这里通过迭代返回每个类别出现的次数
#key=operator.itemgetter(1)获取每个迭代器的第二个域的值,也就是次数,按照 类别出现的次数降序排列
#reverse是一个bool变量,表示升序还是降序排列,默认为false(升序排列),定义为True时表示降序排列
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
#取出类别最高的类别以及对应的次数
return sortedClassCount[0][0]
下面开始创建树#创建决策树
#dataset:数据集 labels:包含了数据集中的所有标签
def createTree(dataSet, labels):
#创建一个classList的列表,取出每个元素的最后一列:标签类[\'yes\', \'yes\', \'no\', \'no\', \'no\']
classList = [example[-1] for example in dataSet]
\'\'\'
a = [[1, 1, \'yes\'],
[1, 1, \'yes\'],
[1, 0, \'no\'],
[0, 1, \'no\'],
[0, 1, \'no\']]
classList = [example[-1] for example in a]
print classList
print classList[0]
print classList.count(classList[0])
print len(classList)
# [\'yes\', \'yes\', \'no\', \'no\', \'no\']
# yes
# 2
# 5
\'\'\'
#统计标签列表中的第一个元素的个数是否等于标签列表的长度
#相等就意味着所有的元素属于同一个类别,那么就可以不再划分,这是最简单的情况
if classList.count(classList[0]) == len(classList):
#如果相等就返回标签列表的第一个元素
return classList[0] # stop splitting when all of the classes are equal
#或者数据集的第一个元素的长度等于1,表示该元素只有一个特征值,同样停止划分
if len(dataSet[0]) == 1: # stop splitting when there are no more features in dataSet
#返回次数最多的类别的名称
return majorityCnt(classList)
#在数据集中寻找最好的分类标准:最鲜明的特征属性
#chooseBestFeatureToSplit()函数返回的是一个整数,表示第几个特征
bestFeat = chooseBestFeatureToSplit(dataSet)
#从标签库中将该特征选出来
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}}
#从标签列表中删除该特征属性
del (labels[bestFeat])
#选出数据集中每行的第bestFeat个元素组成一个列表
#取出第bestFeati列元素
featValues = [example[bestFeat] for example in dataSet]
#从列表中创建一个不重复的集合
uniqueVals = set(featValues)
#遍历这些不重复的特征值
for value in uniqueVals:
#复制所有的标签,这里创建一个新的标签是为了防止函数调用createtree()时改变原始列表的内容
subLabels = labels[:] # copy all of labels, so trees don\'t mess up existing labels
#递归调用创建决策树的函数
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
命令行:
>>> import tree
>>> myDat,labels =tree.createDataSet()
>>> myTree
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
NameError: name \'myTree\' is not defined
>>> myTree=tree.createTree(myDat,labels)
>>> myTree
{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
>>>
下面利用matplolib绘制树形图
import matplotlib.pyplot as plt
#决策点
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
#子节点
leafNode = dict(boxstyle="round4", fc="0.8")
#箭头
arrow_args = dict(arrowstyle="<-")
#使用文本注解绘制树节点
# 节点注释 当前子节点 父节点 节点类型
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
#nodeTxt:注释的内容,xy:设置箭头尖的坐标 ,被注释的地方(x,y)
# xytext:xytext设置注释内容显示的起始位置,文字注释的地方,
#arrowprops用来设置箭头
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords=\'axes fraction\',
xytext=centerPt, textcoords=\'axes fraction\',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
\'\'\'
annotate
# 添加注释
# 第一个参数是注释的内容
# xy设置箭头尖的坐标
# xytext设置注释内容显示的起始位置
# arrowprops 用来设置箭头
# facecolor 设置箭头的颜色
# headlength 箭头的头的长度
# headwidth 箭头的宽度
# width 箭身的宽度
\'\'\'
def createPlot():
#创建一个白色画布
fig = plt.figure(1, facecolor=\'white\')
#清除画布
fig.clf()
#在画布上创建1行1列的图形
createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotNode(\'a decision node\', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode(\'a leaf node\', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
在命令行里面测试
>>> from imp import reload
>>> reload(treeplot)
<module \'treeplot\' from \'E:\\Python36\\treeplot.py\'>
>>> treeplot.createPlot()
>>>
效果图如下
#获取叶节点数目
def getNumLeafs(myTree):
#初始化叶节点的值为0
numLeafs = 0
#取出树的第一个关键字{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
#myTree.keys()[0]=no surfacing
#myTree[firstStr]={0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}=secondDict
#secondDict.keys()=[0,1]
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
#secondDict[1]是一个字典 如果子节点为字典类型
if type(secondDict[key]).__name__==\'dict\':#test to see if the nodes are dictonaires, if not they are leaf nodes
#递归调用getNumLeafs(myTree)
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
return numLeafs
#获取树的层数
def getTreeDepth(myTree):
#初始化树的最大深度为0
maxDepth = 0
# #取出树的第一个关键字{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
#myTree.keys()[0]=no surfacing
#myTree[firstStr]={0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}=secondDict
#secondDict.keys()=[0,1]
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
# secondDict[1]是一个字典 如果子节点为字典
if type(secondDict[key]).__name__==\'dict\':#test to see if the nodes are dictonaires, if not they are leaf nodes
#将深度加1之后继续递归调用函数getTreeDepth()
thisDepth = 1 + getTreeDepth(secondDict[key])
#如果是叶子节点,深度为1
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
命令行:>>> reload(treeplot)
<module \'treeplot\' from \'E:\\Python36\\treeplot.py\'>
>>> treeplot.retrieveTree(0)
{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
>>> treeplot.getNumLeafs(myTree)
3(树的子节点数,相当于树的宽度)
>>> treeplot.getTreeDepth(myTree)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "E:\Python36\treeplot.py", line 77, in getTreeDepth
firstStr = myTree.keys()[0]
TypeError: \'dict_keys\' object does not support indexing(要注意的是python3不能直接解析字典的keys列表,需要手动将keys转为list列表)
>>> reload(myTree)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "E:\Python36\lib\imp.py", line 315, in reload
return importlib.reload(module)
File "E:\Python36\lib\importlib\__init__.py", line 139, in reload
raise TypeError("reload() argument must be a module")
TypeError: reload() argument must be a module
>>> reload(treeplot)
<module \'treeplot\' from \'E:\\Python36\\treeplot.py\'>
>>> treeplot.getTreeDepth
<function getTreeDepth at 0x0000020A95A799D8>
>>> treeplot.getTreeDepth(myTree)
2(树的深度)
>>>
利用递归画出整个树
#在父子节点中添加文本信息
def plotMidText(cntrPt, parentPt, txtString):
#父节点和子节点的中点坐标
#parentPt[0]:父节点的x坐标 cntrPt[0]:左孩子节点的x坐标
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
#画树
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
#获取所有的叶子节点的个数,决定了x轴的宽度
numLeafs = getNumLeafs(myTree) #this determines the x width of this tree
#获取树的深度,决定了y轴的高度
depth = getTreeDepth(myTree)
# #取出树的第一个关键字{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
# myTree.keys()[0]=no surfacing
#取出第一个关键字,作为第一个节点的文本注释
firstStr = myTree.keys()[0] #the text label for this node should be this
#==============参考博客地址:https://www.cnblogs.com/fantasy01/p/4595902.html========================#
#plotTree.xOff即为最近绘制的一个叶子节点的x坐标
#plotTree.yOff 最近绘制的一个叶子节点的y的y坐标
#在确定当前节点位置时每次只需确定当前节点有几个叶子节点,
# 因此其叶子节点所占的总距离就确定了即为float(numLeafs)/plotTree.totalW*1(因为总长度为1),
#因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为float(numLeafs)/2.0/plotTree.totalW*1,
# 但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,因此还需加上半个表格距离即为1/2/plotTree.totalW*1,
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
#在当前节点和父节点之间添加文本信息
plotMidText(cntrPt, parentPt, nodeTxt)
#使用文本注解绘制节点
# 节点注释 当前子节点 父节点 节点类型
# def plotNode(nodeTxt, centerPt, parentPt, nodeType):
plotNode(firstStr, cntrPt, parentPt, decisionNode)
#myTree[firstStr]={0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}=secondDict
secondDict = myTree[firstStr]
#当前节点y坐标的偏移,绘制一层就减少树的1/深度
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
#遍历第二层的节点
for key in secondDict.keys():
#如果该节点是一个字典
if type(secondDict[key]).__name__==\'dict\':#test to see if the nodes are dictonaires, if not they are leaf nodes
#以该节点继续画子节点,
#def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
plotTree(secondDict[key],cntrPt,str(key)) #recursion
else: #it\'s a leaf node print the leaf node
#如果该节点不是一个字典,那就是一个子节点
#计算当前节点的x坐标的偏移
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
#def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
#plotMidText(cntrPt, parentPt, txtString):
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it\'s a tree, and the first element will be another dict
#创建绘图区
def createPlot(inTree):
fig = plt.figure(1, facecolor=\'white\')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
#计算树形图的全局尺寸
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW;
plotTree.yOff = 1.0;
#递归调用plotTree函数
plotTree(inTree, (0.5,1.0), \'\')
plt.show()
命令行测试:
>>> reload(treeplot)
<module \'treeplot\' from \'E:\\Python36\\treeplot.py\'>
>>> myTree=treeplot.retrieveTree(0)
>>> treeplot.createPlot(myTree)
#{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
#使用决策树的分类函数
\'\'\'#>>> tree.classify(myTree,labels,[1,0])
\'no\'
>>> tree.classify(myTree,labels,[1,1])
\'yes\'\'\'
def classify(inputTree,featLabels,testVec):
#找到第一个特征值\'no surfacing\'
firstStr = list(inputTree.keys())[0]
#取出第一个特征值的值value作为第二个字典树
secondDict = inputTree[firstStr]
#在所有的标签列表中找到第一个特征值\'no surfacing\'对应特征的名称的下标
#featLabels[\'no surfacing\', \'flippers\']
#featLabels.index(\'no surfacing\')=0
#featIndex=0
featIndex = featLabels.index(firstStr)
#在测试集中找到该下标对应的特征属性
#testVec[0]=1
key = testVec[featIndex]
# {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}中找到对应的value
#在字典中根据这个属性扎到对应的值
#secondDict[1]={\'flippers\': {0: \'no\', 1: \'yes\'}}
valueOfFeat = secondDict[key]
#如果该值为一个字典
if isinstance(valueOfFeat, dict):
#继续执行分类函数(递归分类)
classLabel = classify(valueOfFeat, featLabels, testVec)
#否则将该值复制给分类标签标签,返回当前的分类标签
else: classLabel = valueOfFeat
return classLabel
>>> reload(tree)
<module \'tree\' from \'E:\\Python36\\tree.py\'>
>>> dataSet,labels = tree.createDataSet()
>>> labels (分类的标签列表)
[\'no surfacing\', \'flippers\']
>>> myTree = treeplot.retrieveTree(0)
>>> myTree(构建的决策树)
{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
>>> tree.classify(myTree,labels,[1,0])
\'no\'
>>> tree.classify(myTree,labels,[1,1])
\'yes\'
>>>
构建好一颗可以进行分类的决策树之后需要进行序列化,存储在硬盘上,可以根据需要随便调用任一个对象‘
#存储决策树到硬盘
def storeTree(inputTree, filename):
import pickle
fw = open(filename, \'w\')
pickle.dump(inputTree, fw)
fw.close()
#从硬盘加载决策树
def grabTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
命令行运行:>>> reload(tree)
<module \'tree\' from \'E:\\Python36\\tree.py\'>
>>> fr = open(\'E:\Python36\lenses.txt\')
>>> lenses = [inst.strip().split(\'\t\') for inst in fr.readlines()](读取每行,并且根据tab符号进行分隔)
>>> lensesLabels =[\'age\',\'prescript\',\'astigmatic\',\'tearRate\']
>>> lensesTree = tree.createTree(lenses,lensesLabels)
>>> lensesTree
{\'tearRate\': {\'reduced\': \'no lenses\', \'normal\': {\'astigmatic\': {\'yes\': {\'prescript\': {\'myope\': \'hard\', \'hyper\': {\'age\': {\'pre\': \'no lenses\', \'young\': \'hard\', \'presbyopic\': \'no lenses\'}}}}, \'no\': {\'age\': {\'pre\': \'soft\', \'young\': \'soft\', \'presbyopic\': {\'prescript\': {\'myope\': \'no lenses\', \'hyper\': \'soft\'}}}}}}}}
>>> treeplot.createPlot(lensesTree)
最后画出决策树