决策树(ID3算法)--分别基于两种语言实现

时间:2025-03-27 10:58:45
  • from pylab import *
  • from matplotlib import pyplot as plt
  • from math import log
  • """
  • ID3决策树算法
  • """
  def createDateSet(path=""):
      """
      Read a comma-separated data file and return the samples and attribute names.

      The first line of the file is treated as the header row and the first
      column of every line as a sample number; both are dropped from the
      returned samples. The last column is kept in every sample because the
      rest of this module reads the last field of a sample as its class label.

      :param path: file to read. The default is the empty path that appears in
          the original code, which cannot be opened; pass the real data file.
      :return: tuple ``(samples, attribute_names)`` where ``samples`` is a list
          of rows (attribute values followed by the class label) and
          ``attribute_names`` is the header without its first and last column.
      :raises OSError: if the file cannot be opened.
      :raises ValueError: if the file contains no non-blank line.
      """
      # Read-only access is enough; the context manager closes the file even
      # when parsing fails.
      with open(path, "r", encoding="utf-8") as fo:
          readlist = fo.read().splitlines()  # split into lines without newlines

      # Blank lines would become rows without a class label, so skip them.
      dataSet = [line.split(',') for line in readlist if line.strip()]
      if not dataSet:
          raise ValueError("data file contains no rows")

      # Header minus the number column and the class-label column.
      featuerName = dataSet[0][1:-1]
      # Drop the header row, then the number column of every remaining row.
      dataSet1 = [line[1:] for line in dataSet[1:]]
      return dataSet1, featuerName
  # Shannon entropy
  def calcShannonEnt(dataSet):
      """
      Compute the Shannon entropy (in bits) of the class labels in a data set.

      :param dataSet: list of samples; the last field of each sample is its label
      :return: entropy of the label distribution, 0.0 for an empty data set
      """
      numEntries = len(dataSet)  # number of samples
      labelCounts = {}
      for featVec in dataSet:  # count how often each label occurs
          currentLabel = featVec[-1]
          # dict.get supplies 0 the first time a label is seen
          labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
      shannonEnt = 0.0
      for count in labelCounts.values():
          prob = float(count) / numEntries
          shannonEnt -= prob * log(prob, 2)
      return shannonEnt
  # Split off a subset
  def splitDataSet(dataSet, axis, value):
      """
      Select the samples whose attribute ``axis`` equals ``value``.

      :param dataSet: list of samples
      :param axis: index of the attribute to filter on
      :param value: attribute value to keep
      :return: new list of the matching samples with column ``axis`` removed
      """
      # Slicing builds a new row for every match, so the caller's rows are
      # never modified (an in-place pop()/del on featVec would alter them).
      return [
          featVec[:axis] + featVec[axis + 1:]
          for featVec in dataSet
          if featVec[axis] == value
      ]
  # Conditional entropy
  def calcConditionalEntropy(dataSet, i, uniqueVals):
      """
      Entropy of the data set after splitting on attribute ``i``.

      Each subset returned by splitDataSet contributes its Shannon entropy,
      weighted by the fraction of samples that fall into it.

      :param dataSet: list of samples
      :param i: index of the attribute to split on
      :param uniqueVals: the values of attribute ``i`` to split by
      :return: weighted sum of the subset entropies
      """
      total = float(len(dataSet))
      weighted = 0.0
      for val in uniqueVals:
          part = splitDataSet(dataSet, i, val)
          weight = len(part) / total
          weighted += weight * calcShannonEnt(part)
      return weighted
  # Information gain
  def calcInformationGain(dataSet, shannonEnt, i):
      """
      Information gain obtained by splitting the data set on attribute ``i``.

      :param dataSet: list of samples
      :param shannonEnt: Shannon entropy of the unsplit data set
      :param i: index of the attribute
      :return: ``shannonEnt`` minus the conditional entropy for attribute ``i``
      """
      # Distinct values that attribute i takes in this data set.
      distinct = {sample[i] for sample in dataSet}
      # information gain = entropy - conditional entropy
      return shannonEnt - calcConditionalEntropy(dataSet, i, distinct)
  # Find the attribute with the highest information gain
  def chooseBestFeatureToSplit(dataSet):
      """
      Pick the attribute whose split yields the largest information gain.

      :param dataSet: list of samples; the last field of each is the label
      :return: index of the first attribute with the largest strictly positive
          gain, or -1 when no attribute has a gain above 0.0
      """
      featureCount = len(dataSet[0]) - 1  # last column is the label
      wholeEntropy = calcShannonEnt(dataSet)  # entropy before any split
      winner, winnerGain = -1, 0.0
      for column in range(featureCount):
          gain = calcInformationGain(dataSet, wholeEntropy, column)
          if not (gain > winnerGain):
              continue
          winner, winnerGain = column, gain
      return winner
  # Majority vote, used when the attributes are exhausted but labels still differ
  def majorityCnt(classList):
      """
      Return the class label that occurs most often.

      :param classList: list of class labels
      :return: the most frequent label; on a tie, the one seen first
      :raises ValueError: if ``classList`` is empty
      """
      if not classList:
          raise ValueError("classList must not be empty")
      classCount = {}
      for vote in classList:
          classCount[vote] = classCount.get(vote, 0) + 1
      # max() keeps the first maximal key in insertion order, which is the
      # same winner a stable descending sort by count would put first.
      return max(classCount, key=classCount.get)
  # Build the tree
  def createTree(dataSet, featureName, featurevalue):
      """
      Recursively build an ID3 decision tree as nested dictionaries.

      :param dataSet: list of samples; the last field of each is the label
      :param featureName: attribute names, aligned with the sample columns.
          The list is not modified.
      :param featurevalue: list that receives, in order, the name of every
          attribute chosen for a split (filled in place)
      :return: a class label for a leaf, otherwise
          ``{attribute_name: {attribute_value: subtree_or_label}}``
      """
      classList = [example[-1] for example in dataSet]
      if classList.count(classList[0]) == len(classList):
          return classList[0]  # every sample has the same label: stop splitting
      if len(dataSet[0]) == 1:
          # Only the label column is left: fall back to the majority label.
          return majorityCnt(classList)
      bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the best attribute
      if bestFeat < 0:
          # No attribute has positive gain. Index -1 would select the label
          # column, so end the branch with the majority label instead.
          return majorityCnt(classList)
      bestFeatLabel = featureName[bestFeat]  # name of the best attribute
      featurevalue.append(bestFeatLabel)
      myTree = {bestFeatLabel: {}}  # nested-dictionary tree
      # Build the reduced name list by slicing rather than `del`, so the
      # caller's list and the sibling branches are not affected.
      subLabels = featureName[:bestFeat] + featureName[bestFeat + 1:]
      uniqueVals = {example[bestFeat] for example in dataSet}
      for value in uniqueVals:  # a set has no defined order
          myTree[bestFeatLabel][value] = createTree(
              splitDataSet(dataSet, bestFeat, value), subLabels, featurevalue)
      return myTree
  • """
  • 绘制生成树
  • 注意此处的X轴向右,Y轴向上
  • """
  # Module-level styles for the node boxes and the connecting arrow
  decisionNode = {"boxstyle": "round", "color": '#CD853F'}  # decision node box
  leafNode = {"boxstyle": "circle", "color": '#008B45'}  # leaf node box
  arrow_args = {"arrowstyle": "<-", "color": '#CD661D'}  # arrow
  # Count the leaf nodes
  def get_num_leafs(myTree):
      """
      Count the leaves of a tree built as nested dictionaries.

      :param myTree: tree of the form ``{attribute: {value: subtree_or_label}}``
      :return: number of leaves (branch values that are not dictionaries)
      """
      firstStr = list(myTree)[0]  # the single root key is the attribute name
      numLeafs = 0
      for child in myTree[firstStr].values():
          if isinstance(child, dict):
              numLeafs += get_num_leafs(child)  # subtree: add its leaves
          else:
              numLeafs += 1
      return numLeafs
  # Compute the depth of the tree
  def get_tree_depth(myTree):
      """
      Compute the number of decision levels on the longest path of a tree.

      :param myTree: tree of the form ``{attribute: {value: subtree_or_label}}``
      :return: 1 for a tree whose branches are all leaves, one more for each
          further nested level; 0 if the root has no branches
      """
      firstStr = list(myTree)[0]  # the single root key is the attribute name
      maxDepth = 0
      for child in myTree[firstStr].values():
          if isinstance(child, dict):
              thisDepth = 1 + get_tree_depth(child)
          else:
              thisDepth = 1
          if thisDepth > maxDepth:
              maxDepth = thisDepth
      return maxDepth
  # Label the edge between a parent and a child node
  def plotMidText(cntrPt, parentPt, txtString):
      """
      Write ``txtString`` at the midpoint between a child and its parent.

      Draws on ``createPlot.ax1``, the axes that createPlot stores on itself,
      so createPlot must have been called first.

      :param cntrPt: (x, y) of the child node
      :param parentPt: (x, y) of the parent node
      :param txtString: text to place on the edge
      :return: None
      """
      xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
      yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
      createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center")
  # Draw a node and the arrow leading to it
  def plotNode(nodeTxt, centerPt, parentPt, nodeType):
      """
      Draw one node box at ``centerPt`` with an arrow connecting it to ``parentPt``.

      Draws on ``createPlot.ax1``, the axes that createPlot stores on itself,
      so createPlot must have been called first. Both points are given in
      axes-fraction coordinates.

      :param nodeTxt: text shown inside the node
      :param centerPt: (x, y) of the node being drawn
      :param parentPt: (x, y) of its parent node
      :param nodeType: box style dictionary (decisionNode or leafNode)
      :return: None
      """
      createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                              xytext=centerPt, textcoords='axes fraction',
                              va="center", ha="center", bbox=nodeType,
                              arrowprops=arrow_args)
  # Draw every node of the tree
  def plotTree(myTree, parentPt, nodeTxt):
      """
      Recursively draw ``myTree`` below ``parentPt``.

      Layout state is kept in function attributes that createPlot initialises
      before the first call:
      ``plotTree.totalW`` (leaf count of the whole tree),
      ``plotTree.totalD`` (depth of the whole tree),
      ``plotTree.xOff`` (x of the most recently placed leaf) and
      ``plotTree.yOff`` (y of the level currently being drawn).
      The x axis points right and the y axis points up.

      :param myTree: tree or subtree to draw (nested dictionaries)
      :param parentPt: (x, y) of the parent node
      :param nodeTxt: attribute value written on the edge from the parent
      :return: None
      """
      numLeafs = get_num_leafs(myTree)
      firstStr = list(myTree)[0]
      # Centre this node horizontally over the leaves of its own subtree.
      cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
                plotTree.yOff)
      plotMidText(cntrPt, parentPt, nodeTxt)  # text on the edge to the parent
      plotNode(firstStr, cntrPt, parentPt, decisionNode)  # node and arrow
      secondDict = myTree[firstStr]
      # Step one level down before drawing the children.
      plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
      for key, child in secondDict.items():
          if isinstance(child, dict):
              plotTree(child, cntrPt, str(key))
          else:
              # Each leaf takes the next horizontal slot.
              plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
              plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
              plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
      # Step back up so siblings of this subtree are drawn at their own level.
      plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


  # Show the plot
  def createPlot(inTree):
      """
      Draw a decision tree in a matplotlib window and show it.

      Stores the axes on ``createPlot.ax1`` for plotNode and plotMidText, and
      initialises the layout attributes on plotTree.

      :param inTree: decision tree as nested dictionaries
      :return: None
      """
      fig = plt.figure(1, facecolor='white')
      fig.clf()
      axprops = dict(xticks=[], yticks=[])  # no axis ticks
      createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
      plotTree.totalW = float(get_num_leafs(inTree))
      plotTree.totalD = float(get_tree_depth(inTree))
      # Start half a slot left of the axes so the first leaf lands in the
      # middle of its slot.
      plotTree.xOff = -0.5 / plotTree.totalW
      plotTree.yOff = 1.0
      plotTree(inTree, (0.5, 1.0), '')
      plt.show()
  if __name__ == '__main__':
      # rcParams is reached through plt, which this file imports explicitly.
      plt.rcParams['font.sans-serif'] = ['SimHei']  # default font
      # Keep '-' from being rendered as a box when the figure is saved.
      plt.rcParams['axes.unicode_minus'] = False
      # Build the decision tree
      treelist = []  # passed to createTree as its third argument
      # NOTE(review): the data file path inside createDateSet is empty in this
      # copy of the code and must be filled in before this call can succeed.
      myDat, labels = createDateSet()
      print(myDat)
      print(labels)
      myTree = createTree(myDat, labels, treelist)
      print(myTree)
      # Draw the decision tree
      createPlot(myTree)