首先推荐李航的《统计机器学习》这本书,这个实现就是按照书上的算法来的。Python 用的是最新的3.3版的,和2.x不兼容,运行的时候需要注意。
'''
Created on 2012-12-18
@author: weisu.yxd
'''
class Node:
'''Represents a decision tree node.
'''
def __init__(self, parent = None, dataset = None):
self.dataset = dataset # 落在该结点的训练实例集
self.result = None # 结果类标签
self.attr = None # 该结点的分裂属性ID
self.childs = {} # 该结点的子树列表,key-value pair: (属性attr的值, 对应的子树)
self.parent = parent # 该结点的父亲结点
def entropy(props):
if (not isinstance(props, (tuple, list))):
return None
from math import log
log2 = lambda x:log(x)/log(2) # an anonymous function
e = 0.0
for p in props:
e -= p * log2(p)
return e
def info_gain(D, A, T = -1, return_ratio = False):
'''特征A对训练数据集D的信息增益 g(D,A)
g(D,A)=entropy(D) - entropy(D|A)
假设数据集D的每个元组的最后一个特征为类标签
T为目标属性的ID,-1表示元组的最后一个元素为目标'''
if (not isinstance(D, (set, list))):
return None
if (not type(A) is int):
return None
C = {} # 类别计数字典
DA = {} # 特征A的取值计数字典
CDA = {} # 类别和特征A的不同组合的取值计数字典
for t in D:
C[t[T]] = C.get(t[T], 0) + 1
DA[t[A]] = DA.get(t[A], 0) + 1
CDA[(t[T], t[A])] = CDA.get((t[T], t[A]), 0) + 1
PC = map(lambda x : x / len(D), C.values()) # 类别的概率列表
entropy_D = entropy(tuple(PC)) # map返回的对象类型为map,需要强制类型转换为元组
PCDA = {} # 特征A的每个取值给定的条件下各个类别的概率(条件概率)
for key, value in CDA.items():
a = key[1] # 特征A
pca = value / DA[a]
PCDA.setdefault(a, []).append(pca)
condition_entropy = 0.0
for a, v in DA.items():
p = v / len(D)
e = entropy(PCDA[a])
condition_entropy += e * p
if (return_ratio):
return (entropy_D - condition_entropy) / entropy_D
else:
return entropy_D - condition_entropy
def get_result(D, T = -1):
'''获取数据集D中实例数最大的目标特征T的值'''
if (not isinstance(D, (set, list))):
return None
if (not type(T) is int):
return None
count = {}
for t in D:
count[t[T]] = count.get(t[T], 0) + 1
max_count = 0
for key, value in count.items():
if (value > max_count):
max_count = value
result = key
return result
def devide_set(D, A):
'''根据特征A的值把数据集D分裂为多个子集'''
if (not isinstance(D, (set, list))):
return None
if (not type(A) is int):
return None
subset = {}
for t in D:
subset.setdefault(t[A], []).append(t)
return subset
def build_tree(D, A, threshold = 0.0001, T = -1, Tree = None, algo = "ID3"):
'''根据数据集D和特征集A构建决策树.
T为目标属性在元组中的索引 . 目前支持ID3和C4.5两种算法'''
if (Tree != None and not isinstance(Tree, Node)):
return None
if (not isinstance(D, (set, list))):
return None
if (not type(A) is set):
return None
if (None == Tree):
Tree = Node(None, D)
subset = devide_set(D, T)
if (len(subset) <= 1):
for key in subset.keys():
Tree.result = key
del(subset)
return Tree
if (len(A) <= 0):
Tree.result = get_result(D)
return Tree
use_gain_ratio = False if algo == "ID3" else True
max_gain = 0.0
for a in A:
gain = info_gain(D, a, return_ratio = use_gain_ratio)
if (gain > max_gain):
max_gain = gain
attr_id = a # 获取信息增益最大的特征
if (max_gain < threshold):
Tree.result = get_result(D)
return Tree
Tree.attr = attr_id
subD = devide_set(D, attr_id)
del(D[:]) # 删除中间数据,释放内存
Tree.dataset = None
A.discard(attr_id) # 从特征集中排查已经使用过的特征
for key in subD.keys():
tree = Node(Tree, subD.get(key))
Tree.childs[key] = tree
build_tree(subD.get(key), A, threshold, T, tree)
return Tree
def print_brance(brance, target):
odd = 0
for e in brance:
print(e, end = ('=' if odd == 0 else '∧'))
odd = 1 - odd
print("target =", target)
def print_tree(Tree, stack = []):
if (None == Tree):
return
if (None != Tree.result):
print_brance(stack, Tree.result)
return
stack.append(Tree.attr)
for key, value in Tree.childs.items():
stack.append(key)
print_tree(value, stack)
stack.pop()
stack.pop()
def classify(Tree, instance):
if (None == Tree):
return None
if (None != Tree.result):
return Tree.result
return classify(Tree.childs[instance[Tree.attr]], instance)
dataset = [
("青年", "否", "否", "一般", "否")
,("青年", "否", "否", "好", "否")
,("青年", "是", "否", "好", "是")
,("青年", "是", "是", "一般", "是")
,("青年", "否", "否", "一般", "否")
,("中年", "否", "否", "一般", "否")
,("中年", "否", "否", "好", "否")
,("中年", "是", "是", "好", "是")
,("中年", "否", "是", "非常好", "是")
,("中年", "否", "是", "非常好", "是")
,("老年", "否", "是", "非常好", "是")
,("老年", "否", "是", "好", "是")
,("老年", "是", "否", "好", "是")
,("老年", "是", "否", "非常好", "是")
,("老年", "否", "否", "一般", "否")
]
T = build_tree(dataset, set(range(0, len(dataset[0]) - 1)))
print_tree(T)
print(classify(T, ("老年", "否", "否", "一般")))
运行结果如下:
2=否∧1=否∧target = 否
2=否∧1=是∧target = 是
2=是∧target = 是
否