机器学习算法之决策树

时间:2022-12-20 12:00:39

大家都知道二叉树,决策树算法就是利用二叉树的结构,利用数据特征对数据集进行分类,直到所有具有相同类型的数据在一个子数据集内。本文的决策树算法参照《机器学习实战第三章,使用ID3算法划分数据集。如何确定用于划分数据的数据特征呢,使用信息论中的信息熵和信息增益作为划分的度量方法。信息熵的概念源自物理热力学,在热力学中用熵表示分子状态的混乱程度,香农在信息论中用熵的来描述信息源的不确定度,可以通过以下公式定义:

机器学习算法之决策树

其中p(xi)为每个特征值的概率,I(xi)表示随机变量的信息,信息的定义是:对于一个事件i,它发生的概率是pi,那么它的信息就是对这个概率取对数的相反数:I(xi)=logbP(xi),其中b为底数,可以取2,10,e.

要明白信息增益,我们还要直到条件熵,我们都知道条件概率是给定条件下某个事件发生的概率,条件熵就是给定条件下的条件干率分布的熵对X的数学期望,在机器学习中可以理解为选定某个特征后的熵:

机器学习算法之决策树

在知道熵、条件熵的概念后,我们就可以得到信息增益:所有分类的熵 - 某个特征值对应的条件熵:

机器学习算法之决策树

信息增益越大,就代表信息不确定性减少的程度最大,就是说那一个特征的条件熵对熵的影响很大,那么这个特征值就是最好的特征值。

 

以下是具体的代码实现:

# 决策树算法的代码

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle='sawtooth',fc="0.8")
leafNode = dict(boxstyle='round4',fc="0.8")
arrow_args = dict(arrowstyle="<-")


# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    # 分别计算填充文文本位置的x,y坐标
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    # createPlot方法的ax1属性为一个plot视图,此处为视图添加文本
    createPlot.ax1.text(xMid,yMid,txtString)

# 计算树的宽和高
def plotTree(myTree, parentPt, nodeTxt):
    # 获取叶节点数
    numleafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    # 获取树的第一个key(根节点)
    firstStr = list(myTree.keys())[0]
    # 子节点的坐标计算
    # 子节点 X坐标=节点的x偏移量 + (叶节点数 )
    cntrPt = (plotTree.xOff + (1.0 + float(numleafs))/2.0/plotTree.totalW,plotTree.yOff)
    # 填充父子节点键的文本
    plotMidText(cntrPt, parentPt, nodeTxt)
    # 绘制树节点
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    # 通过第一个key取获取value
    secondDict = myTree[firstStr]
    # 树的Y坐标偏移量
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    # 对比value(所有节点名称,通过节点名称获取到对应的dict)
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # 如果遍历到字典,将调用本身绘制子节点
            plotTree(secondDict[key],cntrPt,str(key))
        else: # 已经遍历不到字典,此处已经是最后一个,将其画上
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            # 绘制子节点
            plotNode(secondDict[key],(plotTree.xOff, plotTree.yOff),cntrPt, leafNode)
            # 添加节点间的文本信息
            plotMidText((plotTree.xOff, plotTree.yOff),cntrPt, str(key))
    # 确定y的偏移量
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


# 创建视图
def createPlot(inTree):

    fig = plt.figure(1, facecolor='White')
    fig.clf()

    # 不需要设置x,y的刻度文本
    axprops = dict(xticks= [], yticks=[])
    # 添加子图
    createPlot.ax1 = plt.subplot(111,frameon=False, **axprops)
    # 设置plotTree方法中的变量
    # 总的宽度 = 叶子节点的数量
    plotTree.totalW = float(getNumLeafs(inTree))
    # 总的高度 = 树的层数
    plotTree.totalD = float(getTreeDepth(inTree))
    # 定义plotTree的xOff, yOff属性的初始值
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    # 调用plotTree方法
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()



# 绘制树节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# def create_plot():
#     fig = plt.figure(1,facecolor='white')
#     fig.clf()
#     create_plot.ax1 = plt.subplot(111, frameon=False)
#     plotNode('决策节点',(0.5, 0.1), (0.1, 0.5), decisionNode)
#     plotNode('叶节点',(0.8, 0.1),(0.3, 0.8), leafNode)
#     plt.show()


# 获取叶节点数
def getNumLeafs(myTree):
    # 初始化叶节点的计数
    numLeafs = 0
    # 从myTree的所有节点获取第一个节点(根节点)
    firstStr = list(myTree.keys())[0]
    # 通过跟节点的key取出根key对应的value
    secondDict = myTree[firstStr]
    # 遍历根key的value(value包含根key包含的余下所有的子节点)
    # 上一级的value包含下一级的key,因此通过递归,可以不断取到下一层的value
    for key in secondDict.keys():
        # 只要获取到的value的是字典的类型,就进行递归,接着往下取叶节点
        if type(secondDict[key]).__name__ == 'dict':
            # 每次递归调用该函数都会获取到该节点下的所有叶节点,并进行计数
            numLeafs += getNumLeafs(secondDict[key])
        # 如果获取的vlaue不再是字典,说明已经是最后一个子节点,进行一次加1操作
        else: numLeafs += 1
    return numLeafs


# 获取树的层数

def getTreeDepth(myTree):
    # 树的层数与获取叶节点的步骤相似,区别在于
    # 叶节点数每遍历一次,如果遍历到叶子节点,那么将计数加一,累计叶子节点的个数;
    # 树层数的计数在递归的过程中,如果遍历到叶子节点,就会将计数值置为1,只保留max的计数。
    # 将这一层的深度记为1


    # 初始化一个记录最大深度的变量
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
        # 每次递归都进行依次+1的计数操作
            thisDepth = 1 + getTreeDepth(secondDict[key])
        # 如果没有遍历到dict,只有只有一层
        else: thisDepth = 1
        # 每一个key对用的子节点串(每一条路径)都会有一个最大值,记录其中最大的那个
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


# 输出预先存储的树信息
def retriveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1:{'flippers':{0:'no',1:'yes'}}}},
                   {'no surfacing': {0: 'no', 1:{'flippers':{0: {'head':{0: 'no', 1:'yes'}}, 1: 'no'}}}}]

    return listOfTrees[i]

 

# 绘制决策树的代码

import matplotlib.pyplot as plt
from cha03_trees import trees

decisionNode = dict(boxstyle='sawtooth',fc="0.8")
leafNode = dict(boxstyle='round4',fc="0.8")
arrow_args = dict(arrowstyle="<-")


# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    # 分别计算填充文文本位置的x,y坐标
    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
    # createPlot方法的ax1属性为一个plot视图,此处为视图添加文本
    createPlot.ax1.text(xMid,yMid,txtString)

# 计算树的宽和高
def plotTree(myTree, parentPt, nodeTxt):
    # 获取叶节点数
    numleafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    # 获取树的第一个key(根节点)
    firstStr = list(myTree.keys())[0]
    # 子节点的坐标计算
    # 子节点 X坐标=节点的x偏移量 + (叶节点数 )
    cntrPt = (plotTree.xOff + (1.0 + float(numleafs))/2.0/plotTree.totalW,plotTree.yOff)
    # 填充父子节点键的文本
    plotMidText(cntrPt, parentPt, nodeTxt)
    # 绘制树节点
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    # 通过第一个key取获取value
    secondDict = myTree[firstStr]
    # 树的Y坐标偏移量
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    # 对比value(所有节点名称,通过节点名称获取到对应的dict)
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # 如果遍历到字典,将调用本身绘制子节点
            plotTree(secondDict[key],cntrPt,str(key))
        else: # 已经遍历不到字典,此处已经是最后一个,将其画上
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            # 绘制子节点
            plotNode(secondDict[key],(plotTree.xOff, plotTree.yOff),cntrPt, leafNode)
            # 添加节点间的文本信息
            plotMidText((plotTree.xOff, plotTree.yOff),cntrPt, str(key))
    # 确定y的偏移量
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


# 创建视图
def createPlot(inTree):

    fig = plt.figure(1, facecolor='White')
    fig.clf()

    # 不需要设置x,y的刻度文本
    axprops = dict(xticks= [], yticks=[])
    # 添加子图
    createPlot.ax1 = plt.subplot(111,frameon=False, **axprops)
    # 设置plotTree方法中的变量
    # 总的宽度 = 叶子节点的数量
    plotTree.totalW = float(getNumLeafs(inTree))
    # 总的高度 = 树的层数
    plotTree.totalD = float(getTreeDepth(inTree))
    # 定义plotTree的xOff, yOff属性的初始值
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    # 调用plotTree方法
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()



# 绘制树节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# def create_plot():
#     fig = plt.figure(1,facecolor='white')
#     fig.clf()
#     create_plot.ax1 = plt.subplot(111, frameon=False)
#     plotNode('决策节点',(0.5, 0.1), (0.1, 0.5), decisionNode)
#     plotNode('叶节点',(0.8, 0.1),(0.3, 0.8), leafNode)
#     plt.show()


# 获取叶节点数
def getNumLeafs(myTree):
    # 初始化叶节点的计数
    numLeafs = 0
    # 从myTree的所有节点获取第一个节点(根节点)
    firstStr = list(myTree.keys())[0]
    # 通过跟节点的key取出根key对应的value
    secondDict = myTree[firstStr]
    # 遍历根key的value(value包含根key包含的余下所有的子节点)
    # 上一级的value包含下一级的key,因此通过递归,可以不断取到下一层的value
    for key in secondDict.keys():
        # 只要获取到的value的是字典的类型,就进行递归,接着往下取叶节点
        if type(secondDict[key]).__name__ == 'dict':
            # 每次递归调用该函数都会获取到该节点下的所有叶节点,并进行计数
            numLeafs += getNumLeafs(secondDict[key])
        # 如果获取的vlaue不再是字典,说明已经是最后一个子节点,进行一次加1操作
        else: numLeafs += 1
    return numLeafs


# 获取树的层数

def getTreeDepth(myTree):
    # 树的层数与获取叶节点的步骤相似,区别在于
    # 叶节点数每遍历一次,如果遍历到叶子节点,那么将计数加一,累计叶子节点的个数;
    # 树层数的计数在递归的过程中,如果遍历到叶子节点,就会将计数值置为1,只保留max的计数。
    # 将这一层的深度记为1


    # 初始化一个记录最大深度的变量
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
        # 每次递归都进行依次+1的计数操作
            thisDepth = 1 + getTreeDepth(secondDict[key])
        # 如果没有遍历到dict,只有只有一层
        else: thisDepth = 1
        # 每一个key对用的子节点串(每一条路径)都会有一个最大值,记录其中最大的那个
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


# 输出预先存储的树信息
def retriveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1:{'flippers':{0:'no',1:'yes'}}}},
                   {'no surfacing': {0: 'no', 1:{'flippers':{0: {'head':{0: 'no', 1:'yes'}}, 1: 'no'}}}}]

    return listOfTrees[i]


if __name__ == '__main__':
    # myTree = retriveTree(0)
    # createPlot(myTree)
    fr = open("../cha03_trees/lenses.txt")
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensLabels = ['age', 'prescipt','astigmatic','tearRate']
    lensesTree = trees.createTree(lenses, lensLabels)
    createPlot(lensesTree)

 

运行结果:

机器学习算法之决策树

代码地址:https://github.com/ZhaoJiangJie/MLInAction/tree/master/cha03_trees

参考:1.《机器学习实战》peter Harrington 著

         2.https://www.cnblogs.com/fantasy01/p/4581803.html

         3.https://www.zhihu.com/question/22104055

         4.http://blog.csdn.net/aws3217150/article/details/49906389