Python实现ID3算法

时间:2023-01-01 06:30:51

  这是我自己用 Python 编写的数据挖掘中的 ID3 算法。现在觉得 Python 是实现算法的最好工具:

  先贴出 ID3 算法的介绍地址: http://wenku.baidu.com/view/cddddaed0975f46527d3e14f.html

  下面是我自己写的 ID3 算法(Python 2 代码):

 # ID3 demo: imports, attribute metadata and the training set.
from __future__ import division

import math

# Every possible value of each attribute, including the decision attribute.
table = {
    'age': {'young', 'middle', 'old'},
    'income': {'high', 'middle', 'low'},
    'student': {'yes', 'no'},
    'credit': {'good', 'superior'},
    'buy computer': {'yes', 'no'},
}

# Column position of each attribute inside a data row.
attrIndex = {
    name: position
    for position, name in enumerate(
        ['age', 'income', 'student', 'credit', 'buy computer'])
}

# Candidate splitting attributes (every column except the decision column).
attrList = ['age', 'income', 'student', 'credit']

# Training rows: [age, income, student, credit, buy computer].
allDataSet = [
    ['young', 'high', 'no', 'good', 'no'],
    ['young', 'high', 'no', 'superior', 'no'],
    ['middle', 'high', 'no', 'superior', 'yes'],
    ['old', 'middle', 'no', 'good', 'yes'],
    ['young', 'middle', 'no', 'good', 'no'],
    ['young', 'low', 'yes', 'good', 'yes'],
    ['middle', 'high', 'yes', 'good', 'yes'],
    ['old', 'middle', 'no', 'superior', 'no'],
    ['young', 'high', 'yes', 'good', 'yes'],
    ['middle', 'middle', 'no', 'good', 'no'],
]


# Compute entropy
def entropy(attr, dataSet):
    """Return the conditional entropy Info_attr(D) of the decision column.

    For every value v of `attr` (taken from table[attr]) this computes the
    entropy, in bits, of the 'buy computer' labels among the rows whose
    `attr` column equals v. It then returns the average of those entropies,
    weighted by the fraction of rows that have value v.

    ID3's information gain is Info(D) - Info_attr(D), so a *smaller* value
    here means a better split.

    :param attr: attribute name; a key of both `table` and `attrIndex`.
    :param dataSet: list of rows, each indexed through `attrIndex`.
    :return: the conditional entropy as a float; 0.0 for an empty dataSet.
    :raises KeyError: if a row holds an `attr` value missing from table[attr].
    """
    total = len(dataSet)
    if total == 0:
        # Nothing to measure; this also avoids dividing by zero below.
        return 0.0
    attrCol = attrIndex[attr]
    decCol = attrIndex['buy computer']
    # labelCounts[v][label] = number of rows with `attr` value v and that label.
    # Counting whatever labels occur (not only 'yes'/'no') keeps this correct
    # for other decision values.
    labelCounts = {v: {} for v in table[attr]}
    for row in dataSet:
        counts = labelCounts[row[attrCol]]
        label = row[decCol]
        counts[label] = counts.get(label, 0) + 1
    info = 0.0
    for counts in labelCounts.values():
        n = sum(counts.values())
        if n == 0:
            # Value absent from dataSet: weight 0, contributes nothing.
            continue
        # Only labels actually seen are stored, so every p > 0 and log is safe.
        h = -sum(float(c) / n * math.log(float(c) / n, 2)
                 for c in counts.values())
        info += float(n) / total * h
    return info


# Decision-tree node data structure
class Node(object):
    """A node of the decision tree.

    Every instance, placeholders included, gets `attr`, `childNodes` and
    `result`, so reading any of them never raises AttributeError.
    """

    def __init__(self, attrName):
        # Attribute this node splits on. 'buy computer' marks a leaf;
        # '' marks a placeholder child that has not been filled in yet.
        self.attr = attrName
        # Predicted 'buy computer' label; only set to a label on leaves.
        self.result = None
        if attrName != '':
            # One placeholder child per possible value of the attribute.
            self.childNodes = {v: Node('') for v in table[attrName]}
        else:
            self.childNodes = {}


# Filter data
def filtrate(dataSet, condition):
    """Select the rows of dataSet that satisfy condition.

    :param dataSet: list of rows, each indexed through `attrIndex`.
    :param condition: dict with keys 'attr' (attribute name) and 'val'
                      (the value that column must equal).
    :return: a new list of the matching rows, in their original order.
    """
    def matches(row):
        # Lookups stay per-row, exactly as in a plain loop.
        return row[attrIndex[condition['attr']]] == condition['val']

    return [row for row in dataSet if matches(row)]
# Choose the attribute with the largest information gain
def maxEntropy(dataSet, attrList):
    """Return the attribute in attrList that gives the best ID3 split.

    ID3 splits on the attribute with maximum information gain,
    Gain(A) = Info(D) - Info_A(D). Info(D) is the same for every candidate,
    so maximizing the gain is the same as MINIMIZING the conditional entropy
    Info_A(D) returned by entropy(). The name is historical: the function
    maximizes information gain, not entropy.

    Ties go to the attribute listed first.

    :param dataSet: list of rows, each indexed through `attrIndex`.
    :param attrList: non-empty list of candidate attribute names.
    :raises ValueError: if attrList is empty.
    """
    if not attrList:
        raise ValueError('maxEntropy() needs at least one candidate attribute')
    if len(attrList) == 1:
        # Only one choice; skip the entropy computation.
        return attrList[0]
    # min() returns the first of several equal minima, keeping list order on ties.
    return min(attrList, key=lambda a: entropy(a, dataSet))
# Decide whether tree building can stop: once every row has the same
# decision value, there is no need to keep splitting.
def endBuild(dataSet):
    """Return True when dataSet needs no further splitting, else False.

    That is the case when dataSet has at most one row, or when every row
    carries the same 'buy computer' label.
    """
    if len(dataSet) <= 1:
        return True
    decCol = attrIndex['buy computer']
    first = dataSet[0][decCol]
    # Always return an explicit bool; the pure case must yield True.
    return all(row[decCol] == first for row in dataSet)
# Build the decision tree
def _majorityLabel(dataSet):
    """Return the most frequent 'buy computer' label in dataSet.

    Ties go to the label seen first; returns None for an empty dataSet.
    """
    decCol = attrIndex['buy computer']
    counts = {}
    order = []  # labels in first-seen order, so tie-breaking is deterministic
    for row in dataSet:
        label = row[decCol]
        if label not in counts:
            counts[label] = 0
            order.append(label)
        counts[label] += 1
    if not order:
        return None
    # max() returns the first of several equal maxima.
    return max(order, key=lambda label: counts[label])


def _makeLeaf(node, label):
    """Turn node into a leaf that predicts label, and return it."""
    node.attr = 'buy computer'
    node.result = label
    node.childNodes = {}
    return node


def buildDecisionTree(dataSet, root, attrList):
    """Grow the ID3 subtree under root in place.

    root must already be a Node built for the attribute it splits on
    (e.g. Node(maxEntropy(dataSet, attrList))), so root.childNodes holds
    one slot per value of root.attr. Each slot is replaced by either a leaf
    or a recursively built subtree.

    Leaves are labeled with the majority 'buy computer' label of the
    training rows that reach them. A branch that no training row reaches
    gets the majority label of its parent's rows, as in standard ID3.

    :param dataSet: training rows reaching root, indexed through `attrIndex`.
    :param root: Node to fill in; modified in place.
    :param attrList: attributes still available for splitting at root.
    """
    if len(attrList) == 0 or endBuild(dataSet):
        # Stop: labels are pure or no attribute is left. With no attribute
        # left the labels may be mixed, hence the majority vote.
        _makeLeaf(root, _majorityLabel(dataSet))
        return
    attr = root.attr
    # An attribute is never reused further down the same branch.
    childAttrList = [a for a in attrList if a != attr]
    # Iterate over a snapshot because slots are reassigned inside the loop.
    for v in list(root.childNodes):
        childDataSet = filtrate(dataSet, {"attr": attr, "val": v})
        if len(childDataSet) == 0:
            root.childNodes[v] = _makeLeaf(Node('buy computer'),
                                           _majorityLabel(dataSet))
        elif len(childAttrList) == 0:
            root.childNodes[v] = _makeLeaf(Node('buy computer'),
                                           _majorityLabel(childDataSet))
        else:
            childAttr = maxEntropy(childDataSet, childAttrList)
            root.childNodes[v] = Node(childAttr)
            buildDecisionTree(childDataSet, root.childNodes[v], childAttrList)
# Predict the result
def predict(root, row):
    """Return the 'buy computer' label the tree under root predicts for row.

    Starting at root, follows the branch for row's value of each node's
    attribute until it reaches a leaf (attr == 'buy computer').

    :param root: a Node produced by buildDecisionTree.
    :param row: attribute values indexed through `attrIndex`; the decision
                column itself may be omitted.
    :raises KeyError: if row holds a value for which the tree has no branch.
    """
    node = root
    while node.attr != 'buy computer':
        node = node.childNodes[row[attrIndex[node.attr]]]
    return node.result


rootAttr = maxEntropy(allDataSet, attrList)
rootNode = Node(rootAttr)
# print(x) with one argument behaves the same on Python 2 and 3.
print(rootNode.attr)
buildDecisionTree(allDataSet, rootNode, attrList)
print(predict(rootNode, ['old', 'low', 'yes', 'good']))

欢迎大家提出建议