决策树
决策这一节相对KNN算法来说难了点,因为本节需要先理解熵和信息增益的概念,理解后再看就比较容易了。不过我也是先看的代码,在看代码的过程中没明白它为什么要这么做,然后再去查相关的书籍,再把熵和信息增益的概念理解了,再去看代码,就明白了。
基本概念
基本概念不懂没关系,先去看源码。然后再回顾和总结。
香农熵(也叫信息熵)简称熵,其计算公式如下:
其中的 是该类别的概率。
熵越小纯度越高,熵越大越杂乱无章。比如左手一把盐,右手一瓶水,此时熵很小,但如果把盐倒到水里,那么此时熵就很大了。
决策树的根结点的熵是最大的,我们的目标就是进行分类,让节点的熵变成0,就表示节点都是同一类的了,需要关注的是在这个过程中使熵变小的最快的分类是最优分类,我们要做的就是找到这样的分类。
信息熵和信息增益
信息熵计算
例子1
以MLiA书上的数据,计算信息熵和信息增益的过程
如上图的数据集,最终的类别只有两类:是鱼类和不是鱼类,分别占2/5和3/5。
按公式 (MLiA P35)可知,这里的n=2,即共2个类别, 表示是鱼类; 表示不是鱼类。 表示某一类的概率,所以:
以上计算可通过python验证:
例子2
在例子1的基础上,把数据增加一类,比如把第1条样本的结果分类为“可能”,那么一共就有三类:“是鱼类”、“不是鱼类”,“可能是鱼类”,这三者占比分别为1/5,1/5,3/5。
此时的信息熵为:
对比例子1,可发现类别越多,数据就越不纯,信息熵就越大;反之,类别越少,数据就越纯,信息熵就越小,当只有一个类别时,信息熵就是0。
信息增益
从信息熵的概念我们可以知道,熵越小表示纯度越高,直到熵为0时就表示某一类已经完全分好类了。而在每一次分类时,我们需要找到一个类别,以这个类别分类后的熵最小,就是我们想要的,当前熵最小也就是上一级熵减当前类别的熵最大,把这个差就叫信息增益,所以我们的目标就是找信息增益最大的即可。
还是以这个数据为例子:
上一节已经计算出来当前熵H=0.97,那第一个分类到底是拿“不浮出水面是否可以生存”(后续简称第0列属性)这个属性去分类,还是拿“是否有脚蹼”(后续称第1列属性)去分类呢?这就需要遍历这两个属性并计算每个属性信息增益,找到信息增益最大的属性作为最优的分类属性。下面我们分别计算这两个属性的信息增益。
计算第0列属性信息增益
它有两个可能的取值:{是,否},使用该属性对样本进行划分,可得到2个子集,分别记为D1 (不浮出水面是否可以生存=是),D2(不浮出水面是否可以生存=否)。D1和D2分别占3/5和2/5。
子集D1包含编号为{1,2,3}的3个样例,其中正例(是鱼类)占2/3,反例占1/3;子集D2 包含编号为{4,5}的2个样例,都是反例。所以按照第0列属性划分之后获得的两个分支结点的信息熵为:
信息增益为:
:当前样本集合.
:以某一属性进行划分,这个属性中的某类别样本就是 ,比如以“不浮出水面是否可以生存”来划分,这个属性值为“是”的 ,为“否”的 ,所以 就好理解了,前者就是3/5,后者就是2/5。
:表示样本的信息熵。西瓜书上记作
如果想了解的更清楚,可参考西瓜书4.2节划分选择。
计算第1列属性信息增益
同样的。该列属性有两个可能的取值:{是,否},使用该属性对样本进行划分,可得到2个子集,分别记为 (是否有脚蹼=是), (是否有脚蹼=否)。 和 分别占4/5和1/5。
子集 包含编号为{1,2,4,5}的4个样例,其中正例(是鱼类)占1/2,反例占1/2;子集 包含编号为{3}的1个样例,都是反例。所以按照第1列属性划分之后获得的两个分支结点的信息熵为:
信息增益为:
因为 Gain(D,不浮出水面是否可以生存) >Gain(D,是否有脚蹼) ,所以以第0列来分类。
可结合着代码一起看,选择信息增益最好的代码为chooseBestFeatureToSplit()函数,另外一个例子可看西瓜书上4.2.1节的例子。
MLiA中的决策树代码
贴上带自己注释的完整代码,供参考。
#coding:utf-8
#原文地址:http://blog.csdn.net/rosetta
#python3.6
#author:sweird
#date:2018.2.5
from math import log
import operator
import matplotlib.pyplot as plt
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing','flippers']
#change to discrete values
return dataSet, labels
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet: #the the number of unique elements and their occurance
currentLabel = featVec[-1] #取最后一列值作为Lable
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 #统计每一类的个数,比如'yes"类别2个,“no"类别3个
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries #计算该分类的概率,比如yes分类概率为2/5,no分类概率为3/5
shannonEnt -= prob * log(prob,2) #log(prob,2),以2为底的log。这里计算香农熵。使用书本P35最底下的公式。
#从这个公式里可以看出,如果只有一个类别,那么prob=1.0,熵就是0,也就是说只有一类那就不需要再分类了!
return shannonEnt
def calcShannonEnt_test():
myDat, labels = createDataSet()
print(myDat)
shannonEnt = calcShannonEnt(myDat)
print(shannonEnt)#0.97
#增加一个类别后看熵的变化
myDat[0][-1]='maybe'
print(myDat)
shannonEnt2=calcShannonEnt(myDat)
print(shannonEnt2)#1.37
#只有一个类别的熵
myDat[0][-1]='no'
myDat[1][-1]='no'
print(myDat)
shannonEnt2=calcShannonEnt(myDat)
print(shannonEnt2)#0.0
#所以从上面的实验可知,熵越小纯度越高,熵越大越杂乱无章。
#dataSet = {list}[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
#axis用来取dataSet中每一个元素中第0列一样的
#value表示选择axis这一列的值是多少。
#比如axis=0,value=1,被选择的数据就是[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no']],返回除其自身以外的数据[[1, 'yes'], [1, 'yes'], [0, 'no']]
#如果axis=0,value=0,被选择的数据就是[[0, 1, 'no'], [0, 1, 'no']],返回[[1, 'no'], [1, 'no']]
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] #chop out axis used for splitting
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def splitDataSet_test():
myDat,labels = createDataSet()
print(myDat)
retDataSet1 = splitDataSet(myDat,0,1)
retDataSet2 = splitDataSet(myDat,0,0)
print(retDataSet1)
print(retDataSet2)
#这个可看西瓜书上4.2.1节的解释,涉及到熵的计算和信息增益的计算。
#返回最好的用于划分数据集的特征,0表示dataSet数据集的第0列(即是否可浮出水面),1表示第1列(即是否有脚蹼)
#该函数的主要作用:遍历样本的所有属性(带标签的),计算按照该属性分类后的信息增益,选择最大的信息增益所在的属性来分类。
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #计算特征的个数,这里是2个(是否可浮出水面可以生存和是否有脚蹼)
baseEntropy = calcShannonEnt(dataSet)#计算熵,值为0.97,这个在calcShannonEnt_test()中已经学过了。
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatures): #iterate over all the features
featList = [example[i] for example in dataSet]#取出每个样本第i个属性,放到featList中。
#dataSet = {list}[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]#
#如果i=0时
#featList = [1, 1, 1, 0, 0]
#如果i=1时
#featList = [1, 1, 0, 1, 1]
uniqueVals = set(featList) #去重后的放到uniqueVals set中,如uniqueVals = {0, 1}
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy #calculate the info gain; ie reduction in entropy
if (infoGain > bestInfoGain): #compare this to the best gain so far
bestInfoGain = infoGain #if better than current best, set to best
bestFeature = i
return bestFeature #returns an integer
def chooseBestFeatureToSplit_test():
myDat, lables = createDataSet()
bestFeature = chooseBestFeatureToSplit(myDat)
#疑问:信息熵是越大越好还是?
#信息增益和信息熵的关系?
print(myDat)
print(bestFeature)
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
#这块代码看起来有点费劲。
#主要功能是给定实验数据和标签,创建一棵决策数。注意labels会被改写。
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]#stop splitting when all of the classes are equal
if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:] #copy all of labels, so trees don't mess up existing labels
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree
def createTree_test():
myDat,labels=createDataSet()#注意,这里的labels不是结果类别,而是属性类别。
myTree = createTree(myDat,labels)
print(myTree)
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1:{'feet':{0:'no', 1:'yes'}}}}}},
{'脐部':{"凹陷":'好瓜', "稍凹":{"根蒂":{"稍蜷":{"色泽":{"青绿":"好瓜", "乌黑":"好瓜", "浅白":"好瓜"}}, "蜷缩":"坏瓜", "硬挺":"好瓜"}}, "平坦":"坏瓜"}},
]#第四条参考西瓜书P83 图4.7,自己写的数据,然后可以正常显示到图中。
return listOfTrees[i]
#该函数就是决策树预测函数。
#输入已经创建好的决策树和属性类别标签,以及待预测样本的属性。
#输出该标本属于哪一类。
def classify(inputTree,featLabels,testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
key = testVec[featIndex]
valueOfFeat = secondDict[key]
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else: classLabel = valueOfFeat
return classLabel
def classify_test():
myDat,ori_labels = createDataSet()#书上使用了种方法,实际上只要labels就可以,myDat没有用。
print("myDat", myDat)
# labels = ['no surfacing', 'flippers']#可注释上述两句,放开这句也能达到效果。
print("labels", ori_labels)
labels = ori_labels.copy()
# myTree = retrieveTree(0)
# print("myTree 1", type(myTree),myTree)
myTree=createTree(myDat,labels)#该函数会改变labels值,所以上面进行了一次拷贝,因为原始标签ori_labels要在classify()中使用。
#myTree也可以使用retrieveTree()获取手动创建的决策树,用于测试。
print("myTree 2", type(myTree),myTree)
result = classify(myTree,ori_labels,[1,0])
print(result)
def storeTree(inputTree,filename):
import pickle
fw = open(filename,'wb')#python3改成一定要用二进制存储,所以打开属性一定要有‘b'。
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename, 'rb')#同样读取时也要加'b’
return pickle.load(fr)
def store_trees_test():
filename = "myTreeStorage.txt"
myDat,ori_labels = createDataSet()
labels = ori_labels.copy()
myTree=createTree(myDat,labels)
print("myTree", myTree)
storeTree(myTree,filename)
result = grabTree(filename)
print("result", result)
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, fontproperties="SimHei")
#变量createPlot.ax1是调用plotNode函数的地方传进来的,python中的变量默认全局有效。
#annototate在图形中增加带箭头的注释。可参考“pyplot_test.py”代码“pyplot的文本显示”一节。
#详细可参考:https://matplotlib.org/api/_as_gen/matplotlib.pyplot.annotate.html?highlight=annotate#matplotlib.pyplot.annotate
#第1个参数nodeTxt是要注释的内容
#第2个参数xy=()被注释的地方
#第4个参数xytext=()是插入文本的地方
#fontproperties="SimHei",支持中文。
#使用递归遍历的方法,获取叶子个数。
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in list(secondDict.keys()):
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in list(secondDict.keys()):
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30,fontproperties="SimHei")
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree) #this determines the x width of this tree
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0] #the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in list(secondDict.keys()):
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key],cntrPt,str(key)) #recursion
else: #it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show()
#lenses_test()用于对给定的隐形眼睛数据集(数据来源于UCI数据库,UCI数据库是加州大学欧文分校(University of CaliforniaIrvine)提出的用于机器学习的数据库,这个数据库目前共有335个数据集,其数目还在不断增加)创建决策树,并使用Matplot画出决策树,然后使用classify对给定的输入预测结果。
#有四种分类属性,分别是:age(年龄)、prescript(症状)、astigmatic(是否散光)、tearRate(眼泪数量)
#age(年龄)有三种值:young(年经的),pre(翻译成啥?), presbyopic(老花眼)
#prescript(症状):hyper(高度近视)和myope(普通近视)
#astigmatic(是否散光)
#tearRate(眼泪数量):normal(正常)和reduced(减少)
#预测的结果有三种:hard(硬材质)、soft(软材质)和no lenses(不适合佩戴隐形眼镜)
def lenses_test():
fr = open("lenses.txt")
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
ori_lensesLabels = ['age','prescript','astigmatic','tearRate']#有四种分类属性,分别是:年龄、症状、是否散光、眼泪数量。
lensesLabels = ori_lensesLabels.copy()#因为createTree()函数会修改lensesLabels,所以这里做一个拷贝。
lensesTree = createTree(lenses,lensesLabels)
print("lensesTree:",lensesTree)
print("lensesLabels:",lensesLabels)
createPlot(lensesTree)#使用matplotlib画图。
#进行分类预测的时候,可肉眼看决策树,也可使用下面的classify()函数进行。
result = classify(lensesTree,ori_lensesLabels,["presbyopic", "myope", "no", "reduced"])#输入一条实例的属性为presbyopic(年龄不知道翻译成啥。),myope(普通近视),no(不散光),reduced(眼泪减少),输出是否需要佩戴隐藏眼睛,以及隐形眼睛的材质。
#结果有三类:hard(硬材质)、soft(软材质)和no lenses(不适合佩戴隐形眼镜)
print(result)
if __name__ == '__main__':
#3.1决策树构造
#3.1.1信息增益。计算给定数据集的香农熵
# calcShannonEnt_test()
#3.1.2划分数据集
#按照给定特征划分数据集
# splitDataSet_test()
#选择最好的数据集划分方式
# chooseBestFeatureToSplit_test()
#3.1.3递归构建决策树
#其实到本节为止,整棵决策树已经画出来了,只是不太直观而已。
# createTree_test()
#3.2节使用mattplot把决策树直观的展示出来。可参见“treePlotter.py”
#3.3测试和存储分类器
#3.3.1使用决策数进行分类。
# classify_test()
#3.3.2决策树的存储。
# 因为创建一棵决策树会很慢,所以可以先把决策树保存到硬盘上,在用到的时候读取出来。
# store_trees_test()
#3.4.使用决策树预测隐形眼镜类型。
lenses_test()
决策树总结
学完以后才发现,决策树其实也是很简单的,它无非就是两个关键概念信息熵和信息增益,而实际上把信息熵的概念理解了,信息增益就好理解的。这两个概念的理解可看本大单节开始部分的内容。
MLiA第3章决策树内容,就是对给定的数据按类别遍历划分,然后计算出划分后的信息熵,再计算出信息增益,算出最好的分类类别,然后按此类别分类,最终构造一棵决策数。最后输入决策数、属性值(比如差别西瓜时有:挤部、根蒂和色泽)以及待预测的样本的属性值,返回该本的类别(比如是好瓜还是坏瓜)