下面是代码实现的部分,写了一个基于CART的分类树,使用的样本就是上面提到的贷款数据,数据如下图:
是一个.txt文档,运行后得到了分类的结果,最终分类的几个集合都只有一个类别,也就是根据这些分类规则,可以完全将数据分开。
完整代码
# 基于CART的决策分类树复现(离散)
import collections
import queue
import numpy as np
from matplotlib import pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
def load_data():
with open('static/data.txt', mode='r', encoding='utf-8') as f:
data=f.read().split('\n')
title=data[0].split(' ')
x=[]
y=[]
for i in range(1,len(data)):
xy = data[i].split(' ')
x.append(xy[:-1])
y.append(xy[-1]) # 最后一个是标签
x = np.array(x)
y = np.array(y)
return title,x,y
class Node():
def __init__(self,node_id=0,deep=0,id_list=None,nxt_list=None,split=True):
self.node_id=node_id
self.deep = deep
if id_list is not None:
self.id_list = np.array(id_list,dtype=int) # 当前节点的索引集合
else:
self.id_list=[]
if nxt_list is not None:
self.nxt_list = np.array(nxt_list) # 当前节点的索引集合
else:
self.nxt_list=np.array([])
self.split=split # 是否需要继续分裂
class CART():
def fit(self,x,y,gini_thresh=0.1):
samples = x.shape[0]
features = x.shape[1]
root = Node(node_id=0,deep=1,id_list=np.arange(samples),nxt_list=[],split=True)
# 先统计y的相关信息
y_cag = collections.Counter(y)
# print('标签统计信息:',y_cag)
# label_list = list(y_cag.keys()) # y的所有类别
tree = [root] # 存储最终的树
q = queue.Queue() # 产生一个队列
q.put(root)
split_cnt = 0 # 记录分裂次数
while not q.empty(): # 取出一个节点
node = q.get() # 移除并返回数据
id_list = node.id_list # 得到当前集合的所有id
label_num = collections.Counter(y[id_list]) # 当前集合样本的所有对应标签的样本数
num_all = id_list.size # 单管集合的所有数据
min_gini = [0,None,0x3f3f3f3f,[],None] # 记录当前集合feature索引和特征名称(分裂信息),以及对应的gini指数,还有集合id
for i in range(features): # 对于每个feature选择
# 求出所有特征类别和对应的id
feat_dict = {}
# 这个地方有问题(不应该统计所有样本,而是当前对应的,应该可以在上面的id循环里面统计掉)
for idx in id_list:
if x[idx,i] not in feat_dict.keys():
feat_dict[x[idx,i]]=[] # 当前id(索引)
feat_dict[x[idx, i]].append(idx)
# 下面枚举将当前特征特征的每个取值作为分割点
for type in feat_dict.keys(): # type作为分割点(统计分割点内的个样本匹配数量)
res = {}
for idx in feat_dict[type]:
if y[idx] not in res:
res[y[idx]]=0
res[y[idx]]+=1
# 根据统计出的数量已经可以计算基尼系数 gini=1-∑p^2
num = len(feat_dict[type]) # 得到数量(为是的)
gini_D1 = 1
for key in res.keys():
gini_D1-=(res[key]/num)**2
gini_D2 = 1
if num_all!=num:
for key in label_num.keys(): # 利用集合总数来推算为否的集合gini
sub = 0 # 要减去的样本(在集合D1的)
if key in res.keys():
sub = res[key]
gini_D2-=((label_num[key]-sub)/(num_all-num))**2
gini = (num/num_all)*gini_D1+((num_all-num)/num_all)*gini_D2
if gini<min_gini[2]:
min_gini[0]=i
min_gini[1] = type # 第i个特征的类别type
min_gini[2]=gini
min_gini[3] = feat_dict[type] # 记录id
min_gini[4]= (gini_D1,gini_D2) # 记录两个集合的gini决定是否继续分裂
# 找到最小的gini进行分裂
split_cnt+=1
# print('总样本集:',id_list)
print('第 %d 次分裂,根据第 %d 个特征的 %s 类别'%(split_cnt,min_gini[0],min_gini[1]))
id_D1 = min_gini[3]
id_D2 = []
# print(id_list)
for id in id_list:
if id not in id_D1:
id_D2.append(id)
# 生成两个节点
id1 = len(tree)
id2 = len(tree)+1 # 即将插入的两个节点的id(也就是在tree中的索引)
tree[node.node_id].nxt_list=[id1,id2]
node1 = Node(node_id=id1,deep=node.deep+1,id_list=id_D1,nxt_list=[])
node2 = Node(node_id=id2,deep=node.deep+1,id_list=id_D2,nxt_list=[])
# 判断是否需要继续分裂(纯度,纯度也就是如果都是一个类别为0就不分裂,还有个用阈值计算,懒得算了)
if min_gini[4][0]<gini_thresh:
node1.split=False # 无需分裂
else:
q.put(node1)
if min_gini[4][1]<gini_thresh:
node2.split=False # 无需分裂
else:
q.put(node2)
tree.append(node1)
tree.append(node2)
# print(tree)
self.tree = tree
def printTree(self):
tree = self.tree
print('----------- CART -----------')
print('()中表示深度,根节点为1')
for subtree in tree:
if subtree.split:
print('(%d)'%(subtree.deep),subtree.id_list, end='')
print(' -> ',end='')
node1 = tree[subtree.nxt_list[0]]
print('(%d)'%(node1.deep),node1.id_list,end=' + ')
node2 = tree[subtree.nxt_list[1]]
print('(%d)'%(node2.deep),node2.id_list)
print('----------------------------')
if __name__ == '__main__':
title,x,y = load_data()
print('********* 特征 *********')
for i in range(len(title)):
print(i+1,title[i])
print('***********************')
dct_cart = CART()
dct_cart.fit(x,y)
dct_cart.printTree()