基于决策树的基本思想(ID3算法),学习信息增益的计算,决策树的构建、使用、存储。
例子来自《Machine Learning in Action》 Peter Harrington
熵值计算
计算数据集合中分类的数量与概率,根据公式求得熵。
from math import log
"""计算熵值"""
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {} # 用于储存分类标签的种类和个数
for featVec in dataSet:
currentLabel = featVec[-1] # 当前数据点的分类标签
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2) #以2为底求对数
return shannonEnt
测试数据
提供一个如下的数据集合,用过两个特征对生物是否属于鱼类进行确认。
不浮出水面是否可以生存 | 是否有脚蹼 | 属于鱼类 |
---|---|---|
1 | 是 | 是 |
2 | 是 | 是 |
3 | 是 | 否 |
4 | 否 | 是 |
5 | 否 | 是 |
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no'],
labels = ['no surfacing','flippers']
return dataSet, labels
测试
def testShannonEnt():
myDat,labels = createDataSet()
print (calcShannonEnt(myDat))
结果
0.9709505944546686
当在数据集中再加入一种分类
"""创建测试数据集合"""
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no'],
[1, 1, 'maybe'],]
重新计算熵,可得结果:
1.4591479170272448
熵增大,即混乱度(不确定性)增大。值的变化符合熵的定义。
划分数据集
以下代码包含三个输入变量,具体含义见注释。其中dataSet中所包含的数据点,每一个数据点都有多个特征。axis表示接下来按照第几个特征进行划分数据,value表示返回的数据集第axis特征的特征值等于多少。
"""划分数据集"""
''' dataSet:带划分数据集 axis:划分数据集的特征(第axis个,从零开始计数) value:需要返回的特征的值 '''
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec) # 当前数据点 除去当前特征后保存
return retDataSet
测试数据划分
"""测试划分数据集"""
def testSplitData():
myDat,labels = trees.createDataSet()
print(trees.splitDataSet(myDat,0,1))
print(trees.splitDataSet(myDat,0,0))
结果
[[1, 'yes'], [1, 'yes'], [0, 'no']]
[[1, 'no'], [1, 'no']]
第一行的分类结果表示,按照第1个特征对数据集进行划分,返回的结果是第一个特征值为1的数据点。第二行返回的是第一个特征值等于0的数据点。划分结果和预想的一致。
寻找最好的划分方式
依照算法,寻找信息增益最大的分类方式作为最好的分类方式
"""寻找最好的划分方式"""
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 获得特征个数
baseEntropy = calcShannonEnt(dataSet) # 原始的熵值
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures): # 对于每一个特征都进行迭代
featList = [example[i] for example in dataSet] # 提取当前特征在每个数据点中的值
uniqueVals = set(featList) #转换为一个set集合(没有重复元素)
newEntropy = 0.0
for value in uniqueVals:
# 针对数据集合,对第i个特征进行分类,返回值是特征值为value的
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain): # 比较每次分类信息增益
bestInfoGain = infoGain # 如果大,就替换当前的值
bestFeature = i
return bestFeature
测试
"""测试最好的划分方式"""
def testChooseBestFeatureToSplit():
myDat,labels = trees.createDataSet()
print(trees.chooseBestFeatureToSplit(myDat))
结果
当前数据利用第0个特征分类信息增益最大。
0
递归构建决策树
由递归构成树停止的条件有两个:
1. 所有的标签的类都相同
2. 所有的特征都用完了
具体实现见代码
"""创建树"""
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet] # 分类标签的值
if classList.count(classList[0]) == len(classList):
return classList[0] # 所有的标签的类都相同 返回这个类标签
if len(dataSet[0]) == 1: # 如果所有的特征都用完了,则停止
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) # 获得信息增益最大的分类特征
bestFeatLabel = labels[bestFeat] # 获得当前特征的具体含义
myTree = {bestFeatLabel:{}}
del(labels[bestFeat]) # 删除已分类的特征
featValues = [example[bestFeat] for example in dataSet] # 当前分类特征下的数据点特征值
uniqueVals = set(featValues) # 转换为list类型
for value in uniqueVals:
subLabels = labels[:] # 拷贝标签
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree
"""返回出现次数最多的分类名称"""
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
测试
"""测试创建树"""
def testCreateTree():
myDat,labels = trees.createDataSet()
myTree = trees.createTree(myDat,labels)
print(myTree)
结果
以字典的形式返回决策树
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
利用决策树判断新数据点
依照数据点的每一个特征根据决策树进行划分,知道得出类型。
"""利用决策树判断新数据点"""
def classify(inputTree,featLabels,testVec):
firstSides = list(inputTree.keys()) # 第一个分类特征
firstStr = firstSides[0] #找到输入的第一个元素
secondDict = inputTree[firstStr] # 二级字典
featIndex = featLabels.index(firstStr) # 当前特征值在数据集的位置,返回时索引
key = testVec[featIndex] # 拿到新数据点的当前特征的特征值
valueOfFeat = secondDict[key] # 根据特征值 划分数据点
if isinstance(valueOfFeat, dict): # 如果不是叶节点,迭代;
classLabel = classify(valueOfFeat, featLabels, testVec)
else: classLabel = valueOfFeat # 如果是叶节点,返回标签类
return classLabel
测试
"""测试决策树判断新数据点"""
def testClassify():
myDat,labels = trees.createDataSet()
myTree = trees.createTree(myDat,labels)
myDat,labels = trees.createDataSet()
print(trees.classify(myTree,labels,[1,0]))
print(trees.classify(myTree,labels,[1,1]))
结果
返回判断结果
no
yes
存储决策树
决策树的创建比较耗时,为了方便一次创建多次使用。可以把创建的决策树序列化,保存到磁盘上,需要的时候再读取使用。
"""序列化并写入磁盘"""
def storeTree(inputTree,filename):
fw = open(filename,'wb+') # 要以二进制格式打开文件
pickle.dump(inputTree,fw)
fw.close()
"""读取磁盘并反序列化"""
def grabTree(filename):
fr = open(filename,'rb') # 要以二进制格式打开文件
return pickle.load(fr)
测试
"""测试决策树保存"""
def testStoreAndGrabTree():
myDat,labels = trees.createDataSet()
myTree = trees.createTree(myDat,labels)
trees.storeTree(myTree,'trees.txt')
reloadMyTree = trees.grabTree('trees.txt')
print(reloadMyTree)
结果
可以从磁盘得到之前的决策树
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
以上完整代码见GitHub。