决策树方法最早产生于上世纪60年代,到70年代末。由J Ross Quinlan提出了ID3算法,此算法的目的在于减少树的深度。但是忽略了叶子数目的研究。C4.5算法在ID3算法的基础上进行了改进,对于预测变量的缺值处理、剪枝技术、派生规则等方面作了较大改进,既适合于分类问题,又适合于回归问题。分类与回归树CART 模型最早由Breiman 等人提出,也已经在统计领域和数据挖掘技术中普遍使用。本章将对这三种常见的决策树算法进行简单介绍。
八、信息增益选择属性-ID3
S是一个训练样本的集合,该样本中每个集合的类编号已知。每个样本为一个元组,有个属性用来判定某个训练样本的类编号。
假设S中有
一个有
因为此处熵是用来描述集合数据的混乱程度:数据越混乱,熵越大;数据越统一,熵越小。在等式Gain(A)中,被减数
九、ID3算法举例
以AllElectronics顾客数据库标记类的训练元组为例。该数据集D拥有4个特征:age、income、student和credit_rating,计算基于熵的度量——信息增益,作为样本划分的根据:
Gain(age)=0.246,Gain(income)=0.029,Gain(student)=0.151,Gain(credit_rating)=0.048.
然后,对测试属性每个已知的值,创建一个分支,并以此划分样本,得到第一次划分。
因数据子集D2均属于同一类别,无需再迭代。接下来对数据子集D1和D3执行ID3算法,得到最终的决策树。
十、信息增益率-C4.5
以信息增益作为划分准则训练数据集的特征,存在偏向于选择取值较多的特征的问题,从而导致训练的决策树分支较多。使用信息增益比可以对这一问题进行校正。这便是C4.5算法对ID3算法的改进优化。
特征
其中
与ID3算法相比,C4.5算法选择信息增益率最大的属性进行分支,整体上看,分支更明确,获得有用信息更多。因C4.5的套路跟ID3相差不大,只是评判标准稍有改变,因此不再举例说明。
十一、 Gini指标-CART
CART模型是Breiman等人在1984年提出,其假设决策树是二分树,内部节点特征的取值为“是”或“否”,左边分支是取值为“是”的分支,右边分支是取值为“否”的分支。在生成决策树的过程中会递归地二分每个特征。此时选用基尼指数选择最优特征,并决定该特征的最优二值切分点。具体的实现过程是:
设属性
(1)选一个属性
(2)递归处理,将上面得到的两部分按步骤(1)重新选取一个属性继续划分,直到把整个
基尼指数:对于给定的样本集S,假设有
如果样本集合S根据特征
基尼指数
十二、CART算法举例(分类)
下面举个简单的例子,样本集如下表所示:
ID | 有房者 | 婚姻状况 | 年收入 | 拖欠贷款 |
---|---|---|---|---|
1 | 是 | 单身 | 125K | 否 |
2 | 否 | 已婚 | 100K | 否 |
3 | 否 | 单身 | 70K | 否 |
4 | 是 | 已婚 | 120K | 否 |
5 | 否 | 离异 | 95K | 是 |
6 | 否 | 已婚 | 60K | 否 |
7 | 是 | 离异 | 220K | 否 |
8 | 否 | 单身 | 85K | 是 |
9 | 否 | 已婚 | 75K | 否 |
10 | 否 | 单身 | 90K | 是 |
在上述图中,属性有3个,分别是有房情况,婚姻状况和年收入,其中有房情况和婚姻 状况是离散的取值,而年收入是连续的取值。拖欠贷款者属于分类的结果。
假设现在来看有房情况这个属性,那么按照它划分后的Gini指数计算如下:
而对于婚姻状况属性来说,它的取值有3种,按照每种属性值分裂后Gini指标计算如下:
最后还有一个取值连续的属性:年收入,它的取值是连续的。对于连续值处理引进“分裂点”的思想,假设样本集中某个属性
在有房者、婚姻状况、年收入几个特征中,Gini(年收入 = 97)与Gini(婚姻状况 = 单身或离异)最小,所以可以选择特征“年收入”为最优特征,年收入为97K为其最优切分 点,于是根节点生成两个子节点,一个是叶节点。对另一个节点继续使用以上方法在有房 者、婚姻状况中选择最优特征及其最优切分点:
因为上表左边的数据均已属于一类,因此递归终止。基于右边数据,现在继续来看有房情况这个属性,那么按照它划分后的Gini指数计算如下:
而对于婚姻状况属性来说,它的取值有3种,按照每种属性值分裂后Gini指标计算如下:
在有房者、婚姻状况这两个特征中,Gini(婚姻状况 = 单身或离异)最小,所以可以选择 特征“婚姻状况”为最优特征,单身或离异为其最优切分点,于是生成两个子节点,一个是 叶节点。对另一个节点继续使用以上方法在有房者、婚姻状况中选择最优特征及其最优切分 点。根据这样的分裂规则CART算法就能完成建树过程。
十三、 决策树算法实战
因网上已有很多基于scikit-learn的Python代码(可参考: http://www.cnblogs.com/pinard/p/6056319.html),本期给大家换个口味,用R实验。
set.seed(1234)
index <-sample(1:nrow(iris),100)
iris.train <-iris[index,]
iris.test <-iris[-index,]
#第二步:加载包含CART算法的R包
library(rpart)
library(rpart.plot);
#第三步:构建CART模型
model.CART <-rpart(Species~.,data=iris.train)
#第四步:模型应用到测试集
results.CART <-predict(model.CART,newdata=iris.test, type="class")
#第五步:生成混淆矩阵
table(results.CART, iris.test$Species)
rpart.plot(model.CART, branch=1, branch.type=2, type=1, extra=102,
shadow.col="gray", box.col="green",
border.col="blue", split.col="red",
split.cex=1.2, main="CART-IRIS");