Python机器学习系列博客以记录整个学习过程为主,一是为了加强监督,杜绝懒惰心理,二是方便日后查阅!
决策树构造
决策树的绘制
1. 图表框架的确定
1. 计算树的层次depth
2. 计算树的总叶子数sumleafs
3. 设定坐标轴的范围x(minx,maxx), y(miny,maxy)
4. 单元格的宽度W = maxx/sumleafs,高度H = maxy/depth
5. 单个结点位置,分两种情况:(1) 有子结点,(2) 没有子结点
(1)有子结点的结点位置
x = offsetx + (1 + nodeleaf) / 2 * W(nodeleaf 为该结点下的叶子数,offsetx 为上一个已放置叶子结点的 x 坐标;这样结点正好居中于它的所有叶子之上)
offsety = maxy
offsety -= H
p(offsetx, offsety)
(2)没有子结点的结点位置
offsetx = -W / 2
offsetx += W(每放置一个叶子结点右移一个单元格,第 n 个叶子的 x = -W/2 + n * W)
offsety = maxy
offsety -= H
p(offsetx, offsety)
2. matplotlib.pyplot库函数的使用
1. pyplot.axis()
2. pyplot.text()
3. pyplot.annotate()
4. pyplot.show()
matplotlib.pyplot库函数的具体使用方法可以查看:Python之matplotlib库
决策树分类和绘制代码
# coding=utf-8
# ID3 decision-tree construction (information gain) plus a matplotlib drawing
# of the resulting tree. A tree is a nested dict
#     {feature_label: {feature_value: subtree_or_class_label}}
# whose leaves are plain class labels.
from math import log
import operator

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; building a tree needs only the stdlib
    plt = None


def calcShannonEnt(dataset):
    """Return the base-2 Shannon entropy of the class labels in ``dataset``.

    Each row is a sequence whose last element is the class label.
    """
    numentries = len(dataset)
    labelCounts = {}
    for featvec in dataset:
        currentlabel = featvec[-1]
        labelCounts[currentlabel] = labelCounts.get(currentlabel, 0) + 1
    shannoent = 0.0
    for count in labelCounts.values():
        # Convert before dividing: float(a / b) truncates to 0 under Python 2.
        prob = count / float(numentries)
        shannoent -= prob * log(prob, 2)
    return shannoent


def SplitDataSet(dataset, axis, value):
    """Return the rows whose column ``axis`` equals ``value``, with that column removed.

    New row lists are built, so the rows of ``dataset`` are left untouched.
    """
    retdataset = []
    for featvect in dataset:
        if featvect[axis] == value:
            # extend (not append) so the remaining columns stay flat in one row
            reducefeatvec = list(featvect[:axis])
            reducefeatvec.extend(featvect[axis + 1:])
            retdataset.append(reducefeatvec)
    return retdataset


def ChooseBestFeatureToSplit(dataset):
    """Return the index of the feature with the largest information gain.

    Returns -1 when no feature yields a strictly positive gain.
    """
    numfeatures = len(dataset[0]) - 1  # last column is the class label
    baseentropy = calcShannonEnt(dataset)
    bestinfogain = 0.0
    bestfeature = -1
    for i in range(numfeatures):
        uniquevalues = set(example[i] for example in dataset)
        newentropy = 0.0
        for value in uniquevalues:
            subdataset = SplitDataSet(dataset, i, value)
            prob = len(subdataset) / float(len(dataset))
            newentropy += prob * calcShannonEnt(subdataset)
        infogain = baseentropy - newentropy
        if infogain > bestinfogain:
            bestinfogain = infogain
            bestfeature = i
    return bestfeature


def majorityCnt(classlist):
    """Return the most frequent label in ``classlist`` (majority vote for a leaf).

    On ties the label seen first wins (dicts keep insertion order on Python 3.7+,
    and ``sorted`` is stable even with ``reverse=True``).
    """
    classcount = {}
    for vote in classlist:
        classcount[vote] = classcount.get(vote, 0) + 1
    sortedclasscount = sorted(classcount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedclasscount[0][0]


def CreateTree(dataset, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataset: rows of feature values, each ending with the class label.
        labels: names of the feature columns, in column order. Not modified.

    Returns:
        A class label (leaf) or a nested dict ``{feature_label: {value: subtree}}``.
    """
    classlist = [example[-1] for example in dataset]
    if classlist.count(classlist[0]) == len(classlist):
        return classlist[0]  # every row has the same class: stop splitting
    if len(dataset[0]) == 1:
        return majorityCnt(classlist)  # no features left: majority vote
    bestfeat = ChooseBestFeatureToSplit(dataset)
    if bestfeat == -1:
        # No feature gives positive gain (e.g. identical features with
        # conflicting labels). Index -1 would split on the class column itself.
        return majorityCnt(classlist)
    labels = list(labels)  # private copy: do not mutate the caller's list
    bestfeatlabel = labels[bestfeat]
    mytree = {bestfeatlabel: {}}
    del labels[bestfeat]
    for value in set(example[bestfeat] for example in dataset):
        mytree[bestfeatlabel][value] = CreateTree(
            SplitDataSet(dataset, bestfeat, value), labels)
    return mytree


def getnumleafs(tree):
    """Return the number of leaves (non-dict values) in a nested-dict tree."""
    return sum(getnumleafs(child) if isinstance(child, dict) else 1
               for child in tree.values())


def gettreedepth(tree):
    """Return how many levels of dicts are nested below ``tree``.

    Every dict-valued entry adds one level, so both the feature level and the
    value level of ``{feature: {value: ...}}`` count. A dict holding only
    leaves has depth 0.
    """
    maxdepth = 0
    for child in tree.values():
        if isinstance(child, dict):
            maxdepth = max(maxdepth, 1 + gettreedepth(child))
    return maxdepth


def createdataset():
    """Return a small toy dataset and its feature labels."""
    dataset = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataset, labels


def plotnode(nodetxt, parentpt, nextnodept):
    """Draw ``nodetxt`` in a box centred at ``nextnodept``, with an arrow from ``parentpt``."""
    plt.annotate(nodetxt, parentpt, xytext=nextnodept, xycoords='data',
                 va='center', ha='center',
                 arrowprops={"arrowstyle": '<-'}, bbox={'facecolor': 'yellow'})


def plotmidtext(parentpt, nextnodept, text):
    """Write ``text`` (the edge's feature value) at the midpoint of an edge."""
    xmid = (parentpt[0] - nextnodept[0]) / 2.0 + nextnodept[0]
    ymid = (parentpt[1] - nextnodept[1]) / 2.0 + nextnodept[1]
    plt.text(xmid, ymid, text)


def plottree(tree, parentpt, nodetext, b):
    """Recursively draw ``tree`` below ``parentpt``.

    Layout state lives in function attributes that ``createplot`` initialises:
    ``xoff`` (x of the most recently placed leaf), ``yoff`` (y of the current
    level), ``cellwidth``, ``max_y`` and ``H``.

    Args:
        tree: ``{feature_label: {value: subtree_or_class}}``.
        parentpt: (x, y) of the parent node.
        nodetext: edge label (the parent's feature value), drawn mid-edge.
        b: True for the root, which is drawn without an incoming arrow.
    """
    numleafs = getnumleafs(tree)
    firststr = list(tree.keys())[0]
    # This subtree's leaves will land at xoff + k * cellwidth for
    # k = 1..numleafs; centre the node above their midpoint.
    nextnodept = (plottree.xoff + (1.0 + numleafs) / 2.0 * plottree.cellwidth,
                  plottree.yoff)
    plotmidtext(parentpt, nextnodept, nodetext)
    if not b:
        plotnode(firststr, parentpt, nextnodept)
    else:
        plt.text(nextnodept[0], nextnodept[1], firststr,
                 va='center', ha='center', bbox={'facecolor': 'yellow'})
    seconddict = tree[firststr]
    plottree.yoff = plottree.yoff - plottree.max_y / plottree.H  # children one level down
    for key, child in seconddict.items():
        if isinstance(child, dict):
            plottree(child, nextnodept, str(key), False)
        else:
            plottree.xoff = plottree.xoff + plottree.cellwidth
            plotnode(child, nextnodept, (plottree.xoff, plottree.yoff))
            plotmidtext(nextnodept, (plottree.xoff, plottree.yoff), str(key))
    # Restore the level so the next sibling subtree is drawn at the same height.
    plottree.yoff = plottree.yoff + plottree.max_y / plottree.H


def createplot(tree):
    """Draw ``tree`` (as returned by ``CreateTree``) and show the figure.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib is required for createplot()")
    plottree.max_x = 10
    plottree.max_y = 10
    # Fix the data range, then hide frame and ticks. axis() does not accept
    # visible=/xticks=/yticks= keywords (recent matplotlib raises TypeError).
    plt.axis([0, plottree.max_x, 0, plottree.max_y])
    plt.axis('off')
    if not isinstance(tree, dict):
        # CreateTree returns a bare class label when all rows share one class.
        plt.text(plottree.max_x / 2.0, plottree.max_y / 2.0, str(tree),
                 va='center', ha='center', bbox={'facecolor': 'yellow'})
        plt.show()
        return
    plottree.W = float(getnumleafs(tree))   # number of leaves
    plottree.H = float(gettreedepth(tree))  # tree depth
    plottree.cellwidth = plottree.max_x / plottree.W
    plottree.xoff = -plottree.cellwidth / 2
    plottree.yoff = plottree.max_y - 0.1 * plottree.max_y
    plottree(tree, (plottree.max_x / 2, plottree.max_y), '', True)
    plt.show()


if __name__ == "__main__":
    dataset, labels = createdataset()
    mytree = CreateTree(dataset, labels)
    createplot(mytree)