from pylab import *
from matplotlib import pyplot as plt
from math import log
"""
ID3 decision tree algorithm
"""
def createDateSet():
"""
函数说明:读取数据集并做预处理
:return: 测试数据列表,属性名列表
"""
fo = open("", "r+", encoding="utf-8") # 打开文件
readlist = ().splitlines() # 去掉换行符
dataSet = []
for i in range(len(readlist)):
        line = readlist[i].split(',')  # split each line on commas
        dataSet.append(line)
    fo.close()
    featureName = dataSet[0][1:7]  # attribute names taken from the header row
    # featureName = ['OutLook', 'Temperature', 'Humidity', 'Wind']
    dataSet = dataSet[1:]  # drop the header row
dataSet1=[]
for line in dataSet:
        dataSet1.append(line[1:])  # drop the index column from every row
    return dataSet1, featureName
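# Expected file layout (an assumption, since the path passed to open() above is
# left blank in the source): a comma-separated text file whose first row is a
# header and whose first column is a running index; the remaining columns hold
# the attribute values, with the class label in the last column.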
# Shannon entropy
def calcShannonEnt(dataSet):
"""
计算给定数据集的期望信息量,即香农熵
:param dataSet: 要处理的数据集列表
:return: 期望信息量
"""
numEntries = len(dataSet) # 样本的个数
labelCounts = {}
for featVec in dataSet: # 遍历每个实例,统计标签的频次
currentLabel = featVec[-1]
# 当前标签不在labelCounts dic中,就让labelCounts加入该标签
labelCounts[currentLabel] = (currentLabel, 0) + 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
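# A quick sanity check for calcShannonEnt (the 9/5 class split mirrors the
# classic 14-sample weather data; the feature values in the toy rows are made
# up and ignored, only the last column matters):
# >>> toy = [['a', 'yes']] * 9 + [['b', 'no']] * 5
# >>> round(calcShannonEnt(toy), 3)
# 0.94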
# Split the data set
def splitDataSet(dataSet, axis, value):
"""
    Split the data set on a given attribute.
    :param dataSet: data set list
    :param axis: index of the attribute to split on
    :param value: attribute value to keep
    :return: subset of rows whose attribute equals value (with that attribute column removed)
"""
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]  # copy the columns before this attribute; do not mutate featVec in place (e.g. with pop()), since the rows are shared with the caller
            reduceFeatVec.extend(featVec[axis + 1:])  # append the columns after it
            retDataSet.append(reduceFeatVec)
return retDataSet
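# Usage sketch (the attribute values here are made up):
# >>> splitDataSet([['sunny', 'hot', 'yes'], ['rain', 'mild', 'no']], 0, 'sunny')
# [['hot', 'yes']]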
# Conditional entropy
def calcConditionalEntropy(dataSet, i, uniqueVals):
"""
    Conditional entropy of the attribute at index i.
    :param dataSet: input data set
    :param i: index of the attribute
    :param uniqueVals: the distinct values this attribute takes in dataSet
    :return: conditional entropy
"""
Ea = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
Ea += prob * calcShannonEnt(subDataSet)
return Ea
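# The loop above computes the usual weighted average
#   E_A(S) = sum over values v of (|S_v| / |S|) * H(S_v)
# where S_v is the subset of rows with attribute value v and H is calcShannonEnt.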
# Information gain
def calcInformationGain(dataSet, shannonEnt, i):
"""
    Compute the information gain of the attribute at index i.
    :param dataSet: given data set list
    :param shannonEnt: Shannon entropy of the whole data set
    :param i: attribute index
    :return: information gain of this attribute
"""
featList = [example[i] for example in dataSet]
    uniqueVals = set(featList)  # distinct values of this attribute
newEntropy = calcConditionalEntropy(dataSet, i, uniqueVals)
    infoGain = shannonEnt - newEntropy  # information gain = entropy - conditional entropy
return infoGain
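# For reference, on the classic 14-sample weather data the overall entropy is
# about 0.940 and the conditional entropy of Outlook about 0.694, so its gain
# is roughly 0.246, the largest of the four attributes; that is why Outlook
# usually ends up at the root of the tree.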
# Find the attribute with the highest information gain
def chooseBestFeatureToSplit(dataSet):
"""
    Find the attribute with the highest information gain.
    :param dataSet: data set list
    :return: index of the attribute with the highest information gain
"""
numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the whole data set
bestInfoGain = 0.0
bestFeature = -1
    for i in range(numFeatures):  # iterate over every attribute
        infoGain = calcInformationGain(dataSet, baseEntropy, i)  # information gain of attribute i
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature
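# Minimal sketch (made-up rows): attribute 0 separates the classes perfectly,
# so it is chosen.
# >>> rows = [['sunny', 'weak', 'no'], ['sunny', 'strong', 'no'],
# ...         ['rain', 'weak', 'yes'], ['rain', 'strong', 'yes']]
# >>> chooseBestFeatureToSplit(rows)
# 0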
# When no attributes remain but the classes still disagree, take the majority class
def majorityCnt(classList):
"""
    When no attributes remain but the classes still disagree, return the majority class.
    :param classList: list of class labels
    :return: the most frequent class label
"""
classCount = {}
for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=lambda item: item[1], reverse=True)  # sort by count, descending
    return sortedClassCount[0][0]  # the most frequent label in classList
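# Quick check:
# >>> majorityCnt(['yes', 'no', 'yes'])
# 'yes'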
# Build the tree
def createTree(dataSet, featureName, featurevalue):
"""
    Recursively build the decision tree.
    :param dataSet: data set list
    :param featureName: list of attribute names
    :param featurevalue: list that records the order in which attributes are chosen
    :return: the tree as a nested dict
"""
classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when every sample has the same class
    if len(dataSet[0]) == 1:  # only the class column is left: no attributes remain
        return majorityCnt(classList)  # return the majority class label
    bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the best attribute
    bestFeatLabel = featureName[bestFeat]  # name of the best attribute
    featurevalue.append(bestFeatLabel)  # record the order in which attributes are selected
    myTree = {bestFeatLabel: {}}  # the tree is a nested dict keyed by attribute and value
    del featureName[bestFeat]
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
    for value in uniqueVals:  # note: a set has no defined iteration order
        subLabels = featureName[:]  # copy so recursive calls do not share the list
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels, featurevalue)
return myTree
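# Shape sketch (the attribute and value names below are illustrative, not read
# from the blank input path above): on the classic weather data the result
# typically looks like
#   {'OutLook': {'Sunny': {'Humidity': {'High': 'no', 'Normal': 'yes'}},
#                'Overcast': 'yes',
#                'Rain': {'Wind': {'Strong': 'no', 'Weak': 'yes'}}}}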
"""
Plotting the decision tree.
Note that the x axis points right and the y axis points up.
"""
# Module-level node and arrow styles
decisionNode = dict(boxstyle="round", color='#CD853F')  # style of decision (internal) nodes
leafNode = dict(boxstyle="circle", color='#008B45')  # style of leaf nodes
arrow_args = dict(arrowstyle="<-", color='#CD661D')  # arrow style
# Count the leaf nodes
def get_num_leafs(myTree):
"""
    Count the leaf nodes of the tree.
    :param myTree: tree (nested dict) to measure
    :return: number of leaf nodes
"""
numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += get_num_leafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
# Compute the depth of the tree
def get_tree_depth(myTree):
"""
    Compute the maximum depth of the tree.
    :param myTree: given tree
    :return: maximum depth
"""
maxDepth = 0
    firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
    for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + get_tree_depth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
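# For the tree sketched after createTree above, get_num_leafs would return 5
# and get_tree_depth would return 2.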
# Annotate the edge between parent and child with text
def plotMidText(cntrPt, parentPt, txtString):
"""
    Write the attribute value text on the edge between a parent and a child node.
    :param cntrPt: child node coordinates
    :param parentPt: parent node coordinates
    :param txtString: text to draw
    :return:
"""
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center")
# Draw a node and the arrow pointing to it
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
    Draw a node and the arrow pointing to it.
    :param nodeTxt: text shown inside the node
    :param centerPt: child node coordinates
    :param parentPt: parent node coordinates
    :param nodeType: node style
    :return:
"""
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# Draw every node of the tree
def plotTree(myTree, parentPt, nodeTxt):
"""
    Recursively draw the tree.
    :param myTree: tree or subtree to draw
    :param parentPt: parent node coordinates
    :param nodeTxt: attribute value text for the incoming edge
    :return:
"""
numLeafs = get_num_leafs(myTree)
depth = get_tree_depth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # cntrPt = (plotTree.xOff + float(numLeafs) / 2.0 / plotTree.totalW, plotTree.yOff)  # x coordinate of the child node
    plotMidText(cntrPt, parentPt, nodeTxt)  # annotate the edge from the parent
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw this node and its arrow
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move the cursor one level down
    for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW  # advance one leaf slot to the right
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # move back up after finishing this subtree
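# Layout note: plotTree keeps its drawing cursor and the tree dimensions as
# function attributes (plotTree.xOff, plotTree.yOff, plotTree.totalW,
# plotTree.totalD), which createPlot below initializes. The x axis is divided
# into one slot per leaf and the y axis into one band per level, both expressed
# in axes-fraction coordinates.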
# Render the plot
def createPlot(inTree):
"""
    Draw the whole tree and show the figure.
    :param inTree: tree as a nested dict
    :return:
"""
    fig = plt.figure(1, facecolor='white')  # axes-fraction coordinates run from 0 to 1
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide the axis ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(get_num_leafs(inTree))  # total width, in leaf slots
    plotTree.totalD = float(get_tree_depth(inTree))  # total depth, in levels
    plotTree.xOff = -0.5 / plotTree.totalW  # x cursor starts half a slot left of the first leaf
    plotTree.yOff = 1.0  # y cursor starts at the top
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
if __name__ == '__main__':
    plt.rcParams['font.sans-serif'] = ['SimHei']  # default font so non-ASCII labels render
    plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign from showing as a box when saving figures
    # Build the decision tree
treelist = []
myDat, labels = createDateSet()
print(myDat)
print(labels)
myTree = createTree(myDat, labels, treelist)
print(myTree)
    # Draw the decision tree
createPlot(myTree)