机器学习算法(2) 决策树

时间:2022-12-20 12:01:15

基于决策树的基本思想(ID3算法),学习信息增益的计算,决策树的构建、使用、存储。


例子来自《Machine Learning in Action》 Peter Harrington

熵值计算

计算数据集合中分类的数量与概率,根据公式求得熵。

from math import log

"""计算熵值"""
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}                            # 用于储存分类标签的种类和个数
    for featVec in dataSet:
        currentLabel = featVec[-1]              # 当前数据点的分类标签
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2)        #以2为底求对数
    return shannonEnt

测试数据

提供一个如下的数据集合,用过两个特征对生物是否属于鱼类进行确认。

不浮出水面是否可以生存 是否有脚蹼 属于鱼类
1
2
3
4
5
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no'],
    labels = ['no surfacing','flippers']

    return dataSet, labels

测试

def testShannonEnt():
    myDat,labels = createDataSet()
    print (calcShannonEnt(myDat))

结果

0.9709505944546686

当在数据集中再加入一种分类

"""创建测试数据集合"""
dataSet = [[1, 1, 'yes'],
           [1, 1, 'yes'],
           [1, 0, 'no'],
           [0, 1, 'no'],
           [0, 1, 'no'],
           [1, 1, 'maybe'],]

重新计算熵,可得结果:

1.4591479170272448

熵增大,即混乱度(不确定性)增大。值的变化符合熵的定义。


划分数据集

以下代码包含三个输入变量,具体含义见注释。其中dataSet中所包含的数据点,每一个数据点都有多个特征。axis表示接下来按照第几个特征进行划分数据,value表示返回的数据集第axis特征的特征值等于多少。

"""划分数据集"""
''' dataSet:带划分数据集 axis:划分数据集的特征(第axis个,从零开始计数) value:需要返回的特征的值 '''
def splitDataSet(dataSet, axis, value):
    retDataSet = []             
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]    
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)          # 当前数据点 除去当前特征后保存
    return retDataSet

测试数据划分

"""测试划分数据集"""
def testSplitData():
    myDat,labels = trees.createDataSet()
    print(trees.splitDataSet(myDat,0,1))
    print(trees.splitDataSet(myDat,0,0))

结果

[[1, 'yes'], [1, 'yes'], [0, 'no']]
[[1, 'no'], [1, 'no']]

第一行的分类结果表示,按照第1个特征对数据集进行划分,返回的结果是第一个特征值为1的数据点。第二行返回的是第一个特征值等于0的数据点。划分结果和预想的一致。


寻找最好的划分方式

依照算法,寻找信息增益最大的分类方式作为最好的分类方式

"""寻找最好的划分方式"""

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1         # 获得特征个数 
    baseEntropy = calcShannonEnt(dataSet)     # 原始的熵值
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):             # 对于每一个特征都进行迭代
        featList = [example[i] for example in dataSet]   # 提取当前特征在每个数据点中的值
        uniqueVals = set(featList)           #转换为一个set集合(没有重复元素)
        newEntropy = 0.0
        for value in uniqueVals:
            # 针对数据集合,对第i个特征进行分类,返回值是特征值为value的
            subDataSet = splitDataSet(dataSet, i, value)  
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)     
        infoGain = baseEntropy - newEntropy     
        if (infoGain > bestInfoGain):       # 比较每次分类信息增益
            bestInfoGain = infoGain         # 如果大,就替换当前的值
            bestFeature = i
    return bestFeature  

测试

"""测试最好的划分方式"""
def testChooseBestFeatureToSplit():
    myDat,labels = trees.createDataSet()
    print(trees.chooseBestFeatureToSplit(myDat))

结果

当前数据利用第0个特征分类信息增益最大。

0

递归构建决策树

由递归构成树停止的条件有两个:
1. 所有的标签的类都相同
2. 所有的特征都用完了

具体实现见代码

"""创建树"""
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]    # 分类标签的值
    if classList.count(classList[0]) == len(classList):
        return classList[0]      # 所有的标签的类都相同 返回这个类标签
    if len(dataSet[0]) == 1:     # 如果所有的特征都用完了,则停止
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 获得信息增益最大的分类特征
    bestFeatLabel = labels[bestFeat]              # 获得当前特征的具体含义
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])   # 删除已分类的特征
    featValues = [example[bestFeat] for example in dataSet]   # 当前分类特征下的数据点特征值
    uniqueVals = set(featValues)     # 转换为list类型
    for value in uniqueVals:
        subLabels = labels[:]       # 拷贝标签
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree

"""返回出现次数最多的分类名称"""
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

测试

"""测试创建树"""
def testCreateTree():
    myDat,labels = trees.createDataSet()
    myTree = trees.createTree(myDat,labels)
    print(myTree)

结果

以字典的形式返回决策树

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

利用决策树判断新数据点

依照数据点的每一个特征根据决策树进行划分,知道得出类型。

"""利用决策树判断新数据点"""
def classify(inputTree,featLabels,testVec):
    firstSides = list(inputTree.keys())   # 第一个分类特征
    firstStr = firstSides[0]        #找到输入的第一个元素
    secondDict = inputTree[firstStr]      # 二级字典
    featIndex = featLabels.index(firstStr)   # 当前特征值在数据集的位置,返回时索引
    key = testVec[featIndex]             # 拿到新数据点的当前特征的特征值
    valueOfFeat = secondDict[key]        # 根据特征值 划分数据点
    if isinstance(valueOfFeat, dict):  # 如果不是叶节点,迭代;
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat     # 如果是叶节点,返回标签类
    return classLabel

测试

"""测试决策树判断新数据点"""
def testClassify():
    myDat,labels = trees.createDataSet()
    myTree = trees.createTree(myDat,labels)
    myDat,labels = trees.createDataSet()
    print(trees.classify(myTree,labels,[1,0]))
    print(trees.classify(myTree,labels,[1,1]))

结果

返回判断结果

no
yes

存储决策树

决策树的创建比较耗时,为了方便一次创建多次使用。可以把创建的决策树序列化,保存到磁盘上,需要的时候再读取使用。

"""序列化并写入磁盘"""
def storeTree(inputTree,filename):
    fw = open(filename,'wb+')   # 要以二进制格式打开文件
    pickle.dump(inputTree,fw)
    fw.close()

"""读取磁盘并反序列化"""   
def grabTree(filename):
    fr = open(filename,'rb')    # 要以二进制格式打开文件
    return pickle.load(fr)

测试

"""测试决策树保存"""   
def testStoreAndGrabTree():
    myDat,labels = trees.createDataSet()
    myTree = trees.createTree(myDat,labels)
    trees.storeTree(myTree,'trees.txt')
    reloadMyTree = trees.grabTree('trees.txt')
    print(reloadMyTree)  

结果

可以从磁盘得到之前的决策树

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

以上完整代码见GitHub