决策树算法简介及其MATLAB实现代码

时间:2024-01-31 11:59:24

目录

决策树原理概述

决策树的经典算法:ID3算法

改进:C4.5算法

MATLAB实现决策树分类算法


决策树原理概述

  • 决策树通过把样本实例从根节点排列到某个叶子节点来对其进行分类。树上的每个非叶子节点代表对一个属性取值的测试, 其分支就代表测试的每个结果(yes no表示正类、负类);而树上的每个叶子节点均代表一个分类的类别,树的最高层节点是根节点。当所有叶子节点给出的分类结果都一样时,就结束生长,即已经可以判定样本的类别。

  • 根节点并没有什么实际的意义。

  • 简单地说,决策树就是一个类似流程图的树形结构,采用自顶向下的递归方式,从树的根节点开始,在它的内部节点上进行属性值的测试比较,然后按照给定实例的属性值确定对应的分支,最后在决策树的叶子节点得到结论。这个过程在以新的节点为根的子树上重复。直到所有新节点给出的结果一致或足以判断分类。

  • 决策树其实很好理解。举个例子,它就像我们玩的猜谜底游戏。B向A提问,每次可以问不同的问题,而A只能回答是或不是,对或不对。通过多次发问,B越来越接近正确答案。这里,每个问题实际上就是非叶子节点的属性测试,是或者不是就是给出测试结果yes or no。如果一个谜底符合你所有问题(属性),得到答案一致,那么你一定能肯定这个谜底是什么。

决策树的经典算法:ID3算法

信息增益越大代表这个属性中包含的信息量越多。因为它的定义式实际上是熵的变化。

改进:C4.5算法

针对ID3算法中可能存在的问题,学者提出了一些改进。

针对上述两种算法,具体解释和举例可以参考:《数据挖掘系列(6)决策树分类算法》,此处不再赘述。

决策树的优缺点

优点:

–  决策树易于理解和实现。 人们在通过解释后都有能力去理解决策树所表达的意义。

–  对于决策树,数据的准备往往是简单或者是不必要的。其他的技术往往要求先把数据归一化,比如去掉多余的 或者空白的属性。

–  能够同时处理数据型和常规型属性。 其他的技术往往要求数据属性的单一。

–  是一个白盒模型。如果给定一个观察的模型,那么根据所产生的决策树很容易推出相应的逻辑表达式。

缺点:

– 对于各类别样本数量不一致的数据,在决策树当中信息增益的结果偏向于那些具有更多数值的特征。

– 决策树内部节点的判别具有明确性,这种明确性可能会带来误导。

MATLAB实现决策树分类算法

%% I. 清空环境变量
clear all
clc
warning off

%% II. 导入数据
load data.mat

%%
% 1. 随机产生训练集/测试集
a = randperm(569);
Train = data(a(1:500),:);
Test = data(a(501:end),:);

%%
% 2. 训练数据
P_train = Train(:,3:end);
T_train = Train(:,2);

%%
% 3. 测试数据
P_test = Test(:,3:end);
T_test = Test(:,2);

%% III. 创建决策树分类器
ctree = ClassificationTree.fit(P_train,T_train);

%%
% 1. 查看决策树视图
view(ctree);
view(ctree,\'mode\',\'graph\');

%% IV. 仿真测试
T_sim = predict(ctree,P_test);

%% V. 结果分析
count_B = length(find(T_train == 1));
count_M = length(find(T_train == 2));
rate_B = count_B / 500;
rate_M = count_M / 500;
total_B = length(find(data(:,2) == 1));
total_M = length(find(data(:,2) == 2));
number_B = length(find(T_test == 1));
number_M = length(find(T_test == 2));
number_B_sim = length(find(T_sim == 1 & T_test == 1));
number_M_sim = length(find(T_sim == 2 & T_test == 2));
disp([\'病例总数:\' num2str(569)...
      \'  良性:\' num2str(total_B)...
      \'  恶性:\' num2str(total_M)]);
disp([\'训练集病例总数:\' num2str(500)...
      \'  良性:\' num2str(count_B)...
      \'  恶性:\' num2str(count_M)]);
disp([\'测试集病例总数:\' num2str(69)...
      \'  良性:\' num2str(number_B)...
      \'  恶性:\' num2str(number_M)]);
disp([\'良性乳腺肿瘤确诊:\' num2str(number_B_sim)...
      \'  误诊:\' num2str(number_B - number_B_sim)...
      \'  确诊率p1=\' num2str(number_B_sim/number_B*100) \'%\']);
disp([\'恶性乳腺肿瘤确诊:\' num2str(number_M_sim)...
      \'  误诊:\' num2str(number_M - number_M_sim)...
      \'  确诊率p2=\' num2str(number_M_sim/number_M*100) \'%\']);
  
%% VI. 叶子节点含有的最小样本数对决策树性能的影响
leafs = logspace(1,2,10);

N = numel(leafs);

err = zeros(N,1);
for n = 1:N
    t = ClassificationTree.fit(P_train,T_train,\'crossval\',\'on\',\'minleaf\',leafs(n));
    err(n) = kfoldLoss(t);
end
plot(leafs,err);
xlabel(\'叶子节点含有的最小样本数\');
ylabel(\'交叉验证误差\');
title(\'叶子节点含有的最小样本数对决策树性能的影响\')

%% VII. 设置minleaf为13,产生优化决策树
OptimalTree = ClassificationTree.fit(P_train,T_train,\'minleaf\',13);
view(OptimalTree,\'mode\',\'graph\')

%%
% 1. 计算优化后决策树的重采样误差和交叉验证误差
resubOpt = resubLoss(OptimalTree)
lossOpt = kfoldLoss(crossval(OptimalTree))

%%
% 2. 计算优化前决策树的重采样误差和交叉验证误差
resubDefault = resubLoss(ctree)
lossDefault = kfoldLoss(crossval(ctree))

%% VIII. 剪枝
[~,~,~,bestlevel] = cvLoss(ctree,\'subtrees\',\'all\',\'treesize\',\'min\')
cptree = prune(ctree,\'Level\',bestlevel);
view(cptree,\'mode\',\'graph\')

%%
% 1. 计算剪枝后决策树的重采样误差和交叉验证误差
resubPrune = resubLoss(cptree)
lossPrune = kfoldLoss(crossval(cptree))