写在开头的话:在学习《机器学习实战》的过程中发现书中很多代码并没有注释,这对新入门的同学是一个挑战,特此贴出我对代码做出的注释,仅供参考,欢迎指正。
1、trees.py
#coding:gbk
from math import log
import operator
#作用:建立数据集
#输出:数据集,标签名称
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
#作用:计算香农熵
#输入:数据集列表,最后一列为类
#输出:数据集香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet)#数据集中实例总数
labelCounts = {}#创建字典,表示类标签出现次数
for featVec in dataSet:
currentLabel = featVec[-1]#最后一列,即类标签
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0#香农熵
for key in labelCounts:#对每个类标签来说
prob = float(labelCounts[key]) / numEntries#标签出现概率
shannonEnt -= prob * log(prob, 2)
return shannonEnt
#作用:按照给定特征划分数据集,去除axis对应列特征值等于value的值
#输入:待划分的数据集,划分数据集的特征即列数,需要返回的特征的值
#输入:划分后的数据集
def splitDataSet(dataSet, axis, value):
retDataSet = []#返回列表
for featVec in dataSet:#对数据集中每一行
if featVec[axis] == value:#如果相等
reducedFeatVec = featVec[:axis]#该行和下一行的作用是得到去除featVec[axis]的列表
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)#将去除featVec[axis]的列表添加到返回列表中
return retDataSet
#作用:得到最好的数据集划分方式
#输入:数据集列表,最后一列为类
#输出:最好的数据集划分方式对应的特征值
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1#dataSet特征数,-1表示最后一列为类别标签
baseEntropy = calcShannonEnt(dataSet)#dataSet的香农熵
bestInfoGain = 0.0;#最大信息熵
bestFeature = -1;#最佳特征值
for i in range(numFeatures):
featList = [example[i] for example in dataSet]#列表推导式,找到第i个特征对应的属性值,注意是列表,会有多个相同的属性值
uniqueVals = set(featList)#转换为集合,消除相同的属性值,集合里只能存在不相同的属性值
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))#注意float,不能两个int值相除,只能得int值
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):#如果新特征值拥有更大的信息熵
bestInfoGain = infoGain
bestFeature = i
return bestFeature
#作用:返回出现最多的分类名称
#输入:分类名称的列表
#输出:出现最多的分类名称
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1#出现频率加1
sortedClassCount = sorted(classCount.iteritems,#iteritems()表示将classCount以一个迭代器对象返回
key = operator.itemgetter(1), reverse = true)#operator.itemgetter(1)表示第2维数据即值,reverse = True表示从大大小排列
return sortedClassCount
#作用:创建树
#输入:数据集,标签名称
#输出:树的字典形式
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]#列表推导式,得类别标签列表
#类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):#classList.count(classList[0])表示将计算第一个类别出现的次数
return classList[0]
#遍历完所有特征时返回出现次数最多的类别
#该程序用到了递归,此为递归退出条件
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)#最好的数据集划分方式对应的特征值
bestFeatLabel = labels[bestFeat]#最好的数据集划分方式对应的特征值对应的类别名称
myTree = {bestFeatLabel:{}}#建立字典,键为根节点类别名称
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]#列表推导式,得特征值对应值的列表
uniqueVals = set(featValues)#建立集合
for value in uniqueVals:
subLabels = labels[:]#使用新变量代替原始列表
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)#创建子节点
return myTree
#作用:使用决策树的分类函数
#输入:树的字典形式,分类标签,待分类矢量
#输出:分类标签
def classify(inputTree, featLabels, testVec):
firstStr = inputTree.keys()[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)#得firstStr在分类标签中的索引
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
#作用:储存决策树
#输入:树的字典形式,文件名字
#输出:无
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'w')
pickle.dump(inputTree, fw)
fw.close()
#作用:提取决策树
#输入:文件名字
#输出:树的字典形式
def grabTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
2、treePlotter.py
#coding:gbk
import matplotlib.pyplot as plt
#定义文本框和箭头格式
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")
#作用:绘制带箭头的注解
#输入:注解,注解坐标,起点坐标,注解类型
#输出:无
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction', xytext = centerPt, textcoords = 'axes fraction',
va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)
#作用:绘制图像
#输入:
#输出:无
def createPlot(inTree):
fig = plt.figure(1, facecolor = 'white')
fig.clf()
axprops = dict(xticks = [], yticks = [])
createPlot.ax1 = plt.subplot(111, frameon = False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))#树的宽度
plotTree.totalD = float(getTreeDepth(inTree))#数的高度
plotTree.x0ff = -0.5/plotTree.totalW;#根节点x值?
plotTree.y0ff = 1.0;#根节点y值,为1.0表示放在最高点
plotTree(inTree, (0.5, 1.0), '')#绘制根节点,0.5表示在x方向的中间,1.0表示在y方向的最上面,''表示为根节点,不用标记子节点属性值
#plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
#作用:获取叶节点的数目
#输入:树的字典形式
#输出:叶节点的数目
def getNumLeafs(myTree):
numLeafs = 0#叶节点数目
firstStr = myTree.keys()[0]#根节点键
secondDict = myTree[firstStr]#根节点值
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
#作用:获取树的层数
#输入:树的字典形式
#输出:树的层数
def getTreeDepth(myTree):
maxDepth = 0#数的层数
firstStr = myTree.keys()[0]#根节点键
secondDict = myTree[firstStr]#根节点值
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
#作用:输出树的字典形式
#输入:需要的数
#输出:树的字典形式
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
return listOfTrees[i]
#作用:在父子节点间填充文本信息
#输入:子节点位置,父节点位置,文本信息
#输出:无
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)#叶节点数目
depth = getTreeDepth(myTree)#树的层数
firstStr = myTree.keys()[0]#根节点键
cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 /plotTree.totalW, plotTree.y0ff)
plotMidText(cntrPt, parentPt, nodeTxt)#绘制文字
plotNode(firstStr, cntrPt, parentPt, decisionNode)#绘制根节点
secondDict = myTree[firstStr]
plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD