Python机器学习之决策树

时间:2022-05-09 23:51:46

   Python机器学习之决策树

Python机器学习系列博客以记录整个学习过程为主,一是为了加强监督,杜绝懒惰心理,二是方便日后查阅!

决策树构造

  Python机器学习之决策树

决策树的绘制

1. 图表框架的确定

  1. 计算树的层次depth

  2. 计算树的总叶子数sumleafs

  3. 设定坐标轴的范围x(minx,maxx), y(miny,maxy)

  4. 单元格的宽度W = maxx/sumleafs,高度H = maxy/depth

  5. 单个结点位置,分两种情况,1是没有子结点,2 是有子结点

    (1)有子结点的结点位置

      offsetx = offsetx + nodeleaf/sumleafs*W/2

      offsety = maxy

      offsety -= H

      p(offsetx, offsety)

    (2)没有子结点的结点位置

      offsetx = -W / 2

      offsetx += nodeleaf * W

      offsety = maxy

      offsety -= H

      p(offsetx, offsety)

2. matplotlib.pyplot库函数的使用

  1. pyplot.axis()

  2. pyplot.text()

  3. pyplot.annotate()

  4. pyplot.show

  matplotlib.pyplot库函数具体使用方法可以查看:Python之matplotlib库

决策树分类和绘制代码

Python机器学习之决策树Python机器学习之决策树
#coding=utf-8
from math import log

#计算香农熵
#Compute the Shannon entropy of a dataset
def calcShannonEnt(dataset):
    """Return the Shannon entropy of *dataset*.

    dataset: list of records; the class label is the LAST element of each
    record.  Entropy = -sum(p * log2(p)) over the label distribution.
    Returns 0.0 for an empty dataset.
    """
    numentries = len(dataset)
    if numentries == 0:
        # Guard: the original divided by zero on an empty dataset.
        return 0.0
    labelCounts = {}
    # Tally how often each class label occurs.
    for featvec in dataset:
        currentlabel = featvec[-1]
        labelCounts[currentlabel] = labelCounts.get(currentlabel, 0) + 1
    shannoent = 0.0
    for count in labelCounts.values():
        # BUG FIX: convert to float BEFORE dividing; the original wrapped
        # float() around the quotient, which truncates to 0 under
        # Python 2's integer division.
        prob = float(count) / numentries
        shannoent -= prob * log(prob, 2)
    return shannoent

#根据axis和value进行数据分类
#Select the rows whose feature at *axis* equals *value*, dropping that column
def SplitDataSet(dataset, axis, value):
    """Return the subset of *dataset* matching *value* at column *axis*,
    with that column removed from every returned row."""
    retdataset = []
    for row in dataset:
        if row[axis] != value:
            continue
        # Concatenating slices builds a NEW row, so the caller's
        # dataset rows are never mutated.
        retdataset.append(row[:axis] + row[axis + 1:])
    return retdataset

#选择最好的数据集划分方式
#Pick the feature index whose split yields the highest information gain
def ChooseBestFeatureToSplit(dataset):
    """Return the index of the best feature to split on, or -1 if no
    split improves on the base entropy (ID3 information gain)."""
    numfeatures = len(dataset[0]) - 1  # last column is the class label
    baseentroy = calcShannonEnt(dataset)
    bestinfogain, bestfeature = 0.0, -1
    total = float(len(dataset))
    for feat in range(numfeatures):
        # Weighted entropy of the partition induced by this feature.
        newentrop = 0.0
        for value in {example[feat] for example in dataset}:
            subset = SplitDataSet(dataset, feat, value)
            newentrop += (len(subset) / total) * calcShannonEnt(subset)
        gain = baseentroy - newentrop
        if gain > bestinfogain:
            bestinfogain, bestfeature = gain, feat
    return bestfeature

#多数表决决定该叶子节点的分类
import operator
def majorityCnt(classlist):
    """Return the class label occurring most often in *classlist*.

    Used as the leaf decision when all features are exhausted but the
    remaining records still carry mixed labels.
    """
    classcount = {}
    for vote in classlist:
        # BUG FIX: the original incremented only in the 'else' branch, so
        # the first occurrence of each label was never counted.
        classcount[vote] = classcount.get(vote, 0) + 1
    # BUG FIX: 'reserve=True' was a typo for 'reverse=True' and raised a
    # TypeError whenever this function was reached.
    sortedclasscount = sorted(classcount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedclasscount[0][0]

def CreateTree(dataset, labels):
    """Recursively build an ID3 decision tree.

    dataset: list of records, class label last.
    labels:  feature names aligned with the feature columns.
    Returns either a class label (leaf) or a nested dict of the form
    {feature_label: {feature_value: subtree_or_label, ...}}.
    """
    # BUG FIX: work on a copy so the caller's label list is not mutated
    # by the del() below.
    labels = labels[:]
    classlist = [example[-1] for example in dataset]
    # All records share one class: this branch is pure, stop splitting.
    # (Compared against len(classlist), not len(dataset), for clarity --
    # the values are equal by construction.)
    if classlist.count(classlist[0]) == len(classlist):
        return classlist[0]
    # Only the label column remains: fall back to majority vote.
    if len(dataset[0]) == 1:
        return majorityCnt(classlist)
    bestfeat = ChooseBestFeatureToSplit(dataset)   # best feature index
    bestfeatlabel = labels[bestfeat]               # its display name
    mytree = {bestfeatlabel: {}}
    del labels[bestfeat]  # consumed by this split level
    featvalues = [example[bestfeat] for example in dataset]
    for value in set(featvalues):
        sublabels = labels[:]  # each branch gets its own copy
        mytree[bestfeatlabel][value] = CreateTree(
            SplitDataSet(dataset, bestfeat, value), sublabels)
    return mytree

#获取叶子节点的数目
#Count the leaf nodes of the tree
def getnumleafs(tree):
    """Return the number of leaf (non-dict) nodes in a nested-dict tree.

    NOTE(review): this definition is shadowed by an identical one later
    in the file; kept here to preserve the script's structure.
    """
    numleafs = 0
    for key in tree:
        # BUG FIX: removed stray debug print(key) left in by the author.
        if isinstance(tree[key], dict):
            numleafs += getnumleafs(tree[key])  # internal node: recurse
        else:
            numleafs += 1  # leaf value
    return numleafs
            
def gettreedepth(tree):
    treedepth = 0
    maxdepth = 0
    keys = tree.keys()
    for key in keys:
        if type(tree[key]).__name__ == 'dict':            
            treedepth = 1 + gettreedepth(tree[key])
        else:
            treedepth = 0
        
    if treedepth > maxdepth:
        maxdepth = treedepth
    return maxdepth  
    
def createdataset():
    """Return the toy "is it a fish?" dataset and its feature labels.

    Each record is [no surfacing, flippers, class], class label last.
    """
    labels = ['no surfacing', 'flippers']
    dataset = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return dataset, labels

import matplotlib.pyplot as plt

def getnumleafs(tree):
    """Count leaf nodes: child values that are not nested dicts.

    NOTE(review): duplicate of the earlier definition; this one wins at
    import time.
    """
    total = 0
    for subtree in tree.values():
        if type(subtree).__name__ == 'dict':
            # Internal node: its leaves are counted recursively.
            total += getnumleafs(subtree)
        else:
            total += 1
    return total
            
def gettreedepth(tree):
    treedepth = 0
    maxdepth = 0
    keys = tree.keys()
    for key in keys:
        if type(tree[key]).__name__ == 'dict':            
            treedepth = 1 + gettreedepth(tree[key])
        else:
            treedepth = 0
        
    if treedepth > maxdepth:
        maxdepth = treedepth
    return maxdepth  

#tree= {'no surfacing':{0:'no', 88:"99",1:{'flippers':{0:'no',1:'yes', 2:"maybe",3:"sure"}}}}

def plotnode(nodetxt, parentpt, nextnodept):
    """Draw a boxed node at *nextnodept* with an arrow back to *parentpt*."""
    arrowstyle = {"arrowstyle": '<-'}
    boxstyle = {'facecolor': 'yellow'}
    plt.annotate(nodetxt, parentpt, xytext=nextnodept,
                 xycoords='data', arrowprops=arrowstyle, bbox=boxstyle)
    
def plotmidtext(parentpt, nextnodept, text):
    """Label a branch: write *text* halfway between the two endpoints."""
    xmid, ymid = ((p - c) / 2.0 + c for p, c in zip(parentpt, nextnodept))
    plt.text(xmid, ymid, text)
    
def plottree(tree, parentpt, nodetext, b):
    """Recursively draw *tree* with matplotlib.

    tree:     nested-dict decision tree ({label: {value: subtree_or_leaf}}).
    parentpt: (x, y) of the parent node this subtree hangs from.
    nodetext: text drawn on the branch coming from the parent (split value).
    b:        True only for the root call, which draws a plain text box
              instead of an arrow-annotated node.

    Relies on attributes stashed on the function object by createplot():
    xoff/yoff (drawing cursor), W/H (total leaf count / tree depth),
    max_x/max_y (axis extents) and cellwidth.
    """
    numleafs = getnumleafs(tree)
    depth = gettreedepth(tree)  # NOTE(review): computed but never used below
    firststr = list(tree.keys())[0]
    # Position of this subtree's root node, centred above its leaves.
    #cntrpt = (plottree.xoff +(1.0 + float(numleafs)) / 2.0  / plottree.W,plottree.yoff)
    nextnodept = ((plottree.xoff + float(numleafs)/ plottree.W * plottree.max_x / 2), plottree.yoff)
    print (parentpt)   # NOTE(review): debug output left in by the author
    print (nextnodept)
    plotmidtext(parentpt, nextnodept,nodetext)
    if not b:
        plotnode(firststr, parentpt,nextnodept)
    else:
        # Root call: no parent to point at, so draw only the boxed label.
        plt.text(nextnodept[0], nextnodept[1], firststr,bbox={'facecolor':'yellow'})
    seconddict=tree[firststr]
    # Move the vertical cursor down one layer before drawing the children.
    plottree.yoff = plottree.yoff - plottree.max_y/plottree.H
    for key in seconddict.keys():
        if type(seconddict[key]).__name__ == "dict":
           plottree(seconddict[key], nextnodept, str(key), False) 
        else:
           # Leaf: advance the horizontal cursor one cell, then draw the
           # leaf node and its branch label.
           plottree.xoff = plottree.xoff + plottree.cellwidth
           plotnode(seconddict[key], nextnodept,(plottree.xoff,plottree.yoff))
           plotmidtext(nextnodept, (plottree.xoff,plottree.yoff), str(key))
    # NOTE(review): yoff is never restored after recursion, so sibling
    # subtrees at the same level are drawn progressively lower.
    #plottree.yoff = plottree.yoff + 1.0/plottree.H       
    
def createplot(tree):
    """Set up the matplotlib axes, seed plottree's state, and draw *tree*.

    Layout: the x axis is divided into one cell per leaf, the y axis into
    one row per tree level; plottree() consumes xoff/yoff as a cursor.
    """
    plottree.max_x = 10
    plottree.max_y = 10
    axprops = dict(xticks=[],yticks=[])
    plt.axis([0,plottree.max_x,0,plottree.max_y], visible=False, **axprops )
    plottree.W = float(getnumleafs(tree)) # total leaf count -> column count
    plottree.H = float(gettreedepth(tree)) # tree depth -> row count
    plottree.cellwidth = plottree.max_x / plottree.W
    # Start half a cell to the LEFT of the axis so the first leaf's
    # cursor advance lands it centred in column 0.
    plottree.xoff = -  plottree.cellwidth/ 2
    plottree.yoff = plottree.max_y-0.1*plottree.max_y
    
    plottree(tree, (plottree.max_x/2,plottree.max_y), '', True)   
    plt.show()           
        
if __name__ == "__main__":
    # Demo: build a decision tree from the toy fish dataset and render it.
    dataset,labels =createdataset()
    mytree = CreateTree(dataset, labels)
    createplot(mytree)
View Code

Python机器学习之决策树