机器学习算法:决策树

时间:2022-12-20 12:01:09

决策树(Decision Tree)的核心思想是:根据训练样本构建这样一棵树,使得其叶节点是分类标签,非叶节点是判断条件,这样对于一个未知样本,能在树上找到一条路径到达叶节点,就得到了它的分类。


举个简单的例子,如何识别有毒的蘑菇?如果能够得到一棵这样的决策树,那么对于一个未知的蘑菇就很容易判断出它是否有毒了。

                                                    它是什么颜色的?
                               |
                 -------鲜艳---------浅色----
                |                           |
              有毒                      有什么气味?
                                            |
                              -----刺激性--------无味-----
                             |                           |
                            有毒                        安全


构建决策树有很多算法,常用的有ID3、C4.5等。本篇以ID3为研究算法。


构建决策树的关键在于每一次分支时选择哪个特征作为分界条件。这里的原则是:选择最能把数据变得有序的特征作为分界条件。所谓有序,是指划分后,每一个分支集合的分类尽可能一致。用信息论的方式表述,就是选择信息增益最大的方式划分集合。


所谓信息增益(information gain),是指变化前后熵(entropy)的增加量。为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:

机器学习算法:决策树

其中H为熵,n为分类数目,p(xi)是选择该分类的概率。


根据公式,计算一个集合熵的方式为:

计算每个分类出现的次数foreach(每一个分类){    计算出现概率    根据概率计算熵    累加熵}return 累加结果


判断如何划分集合,方式为:

foreach(每一个特征){    计算按此特征切分时的熵    计算与切分前相比的信息增益    保留能产生最大增益的特征为切分方式}return 选定的特征


构建树节点的方法为:

if(集合没有特征可用了){    按多数原则决定此节点的分类}else if(集合中所有样本的分类都一致){    此标签就是节点分类}else{    以最佳方式切分集合    每一种可能形成当前节点的一个分支    递归}


OK,上C#版代码,DataVector和上篇文章一样,不放了,只放核心算法:

using System;using System.Collections.Generic;namespace MachineLearning{    /// <summary>    /// 决策树节点    /// </summary>    public class DecisionNode    {        /// <summary>        /// 此节点的分类标签,为空表示此节点不是叶节点        /// </summary>        public string Label { get; set; }        /// <summary>        /// 此节点的划分特征,为-1表示此节点是叶节点        /// </summary>        public int FeatureIndex { get; set; }        /// <summary>        /// 分支        /// </summary>        public Dictionary<string, DecisionNode> Child { get; set; }        public DecisionNode()        {            this.FeatureIndex = -1;            this.Child = new Dictionary<string, DecisionNode>();        }    }}


using System;using System.Collections.Generic;using System.Linq;namespace MachineLearning{    /// <summary>    /// 决策树(ID3算法)    /// </summary>    public class DecisionTree    {        private DecisionNode m_Tree;        /// <summary>        /// 训练        /// </summary>        /// <param name="trainingSet"></param>        public void Train(List<DataVector<string>> trainingSet)        {            var features = new List<int>(trainingSet[0].Dimension);            for(int i = 0;i < trainingSet[0].Dimension;++i)                features.Add(i);                            //生成决策树            m_Tree = CreateTree(trainingSet, features);        }        /// <summary>        /// 分类        /// </summary>        /// <param name="vector"></param>        /// <returns></returns>        public string Classify(DataVector<string> vector)        {            return Classify(vector, m_Tree);        }        /// <summary>        /// 分类        /// </summary>        /// <param name="vector"></param>        /// <param name="node"></param>        /// <returns></returns>        private string Classify(DataVector<string> vector, DecisionNode node)        {            var label = string.Empty;                        if(!string.IsNullOrEmpty(node.Label))            {                //是叶节点,直接返回结果                label = node.Label;            }            else            {                //取需要分类的字段,继续深入                var key = vector.Data[node.FeatureIndex];                if(node.Child.ContainsKey(key))                    label = Classify(vector, node.Child[key]);                else                    label = "[UNKNOWN]";            }            return label;        }                /// <summary>        /// 创建决策树        /// </summary>        /// <param name="dataSet"></param>        /// <param name="features"></param>        /// <returns></returns>        private DecisionNode CreateTree(List<DataVector<string>> dataSet, List<int> features)        {            var node = new DecisionNode();                        if(dataSet[0].Dimension == 0)            {                //所有字段已用完,按多数原则决定Label,结束分类                node.Label = GetMajorLabel(dataSet);            }            else if(dataSet.Count == dataSet.Count(d => string.Equals(d.Label, dataSet[0].Label)))            {                //如果数据集中的Label相同,结束分类                node.Label = dataSet[0].Label;            }            else            {                //挑选一个最佳分类,分割集合,递归                int featureIndex = ChooseBestFeature(dataSet);                node.FeatureIndex = features[featureIndex];                var uniqueValues = GetUniqueValues(dataSet, featureIndex);                features.RemoveAt(featureIndex);                foreach(var value in uniqueValues)                {                    node.Child[value.ToString()] = CreateTree(SplitDataSet(dataSet, featureIndex, value), new List<int>(features));                }            }                        return node;        }                /// <summary>        /// 计算给定集合的香农熵        /// </summary>        /// <param name="dataSet"></param>        /// <returns></returns>        private double ComputeShannon(List<DataVector<string>> dataSet)        {            double shannon = 0.0;                        var dict = new Dictionary<string, int>();            foreach(var item in dataSet)            {                if(!dict.ContainsKey(item.Label))                    dict[item.Label] = 0;                dict[item.Label] += 1;            }                        foreach(var label in dict.Keys)            {                double prob = dict[label] * 1.0 / dataSet.Count;                shannon -= prob * Math.Log(prob, 2);            }                        return shannon;        }                /// <summary>        /// 用给定的方式切分出数据子集        /// </summary>        /// <param name="dataSet"></param>        /// <param name="splitIndex"></param>        /// <param name="value"></param>        /// <returns></returns>        private List<DataVector<string>> SplitDataSet(List<DataVector<string>> dataSet, int splitIndex, string value)        {            var newDataSet = new List<DataVector<string>>();                        foreach(var item in dataSet)            {                //只保留指定维度上符合给定值的项                if(item.Data[splitIndex] == value)                {                    var newItem = new DataVector<string>(item.Dimension - 1);                    newItem.Label = item.Label;                    Array.Copy(item.Data, 0, newItem.Data, 0, splitIndex - 0);                    Array.Copy(item.Data, splitIndex + 1, newItem.Data, splitIndex, item.Dimension - splitIndex - 1);                    newDataSet.Add(newItem);                }            }                        return newDataSet;        }        /// <summary>        /// 在给定的数据集上选择一个最好的切分方式        /// </summary>        /// <param name="dataSet"></param>        /// <returns></returns>        private int ChooseBestFeature(List<DataVector<string>> dataSet)        {            int bestFeature = 0;            double bestInfoGain = 0.0;            double baseShannon = ComputeShannon(dataSet);                        //遍历每一个维度来寻找            for(int i = 0;i < dataSet[0].Dimension;++i)            {                var uniqueValues = GetUniqueValues(dataSet, i);                double newShannon = 0.0;                //遍历此维度下的每一个可能值,切分数据集并计算熵                foreach(var value in uniqueValues)                {                    var subSet = SplitDataSet(dataSet, i, value);                    double prob = subSet.Count * 1.0 / dataSet.Count;                    newShannon += prob * ComputeShannon(subSet);                }                //计算信息增益,保留最佳切分方式                double infoGain = baseShannon - newShannon;                if(infoGain > bestInfoGain)                {                    bestInfoGain = infoGain;                    bestFeature = i;                }            }                        return bestFeature;        }        /// <summary>        /// 数据去重        /// </summary>        /// <param name="dataSet"></param>        /// <param name="index"></param>        /// <returns></returns>        private List<string> GetUniqueValues(List<DataVector<string>> dataSet, int index)        {            var dict = new Dictionary<string, int>();            foreach(var item in dataSet)            {                dict[item.Data[index]] = 0;            }            return dict.Keys.ToList<string>();        }        /// <summary>        /// 取多数标签        /// </summary>        /// <param name="dataSet"></param>        /// <returns></returns>        private string GetMajorLabel(List<DataVector<string>> dataSet)        {            var dict = new Dictionary<string, int>();            foreach(var item in dataSet)            {                if(!dict.ContainsKey(item.Label))                    dict[item.Label] = 0;                dict[item.Label]++;            }            string label = string.Empty;            int count = -1;            foreach(var key in dict.Keys)            {                if(dict[key] > count)                {                    label = key;                    count = dict[key];                }            }                        return label;        }    }}



拿个例子实际检验一下,还是以毒蘑菇的识别为例,从这里找了点数据,http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data ,它整理了8000多个样本,每个样本描述了蘑菇的22个属性,比如形状、气味等等,然后给出了这个蘑菇是否可食用。


比如一行数据:p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u

第0个元素p表示poisonous(有毒),其它22个元素分别是蘑菇的属性,可以参见agaricus-lepiota.names的描述,但实际上根本不用关心具体含义。以此构建样本并测试错误率:

public void TestDecisionTree(){    var trainingSet = new List<DataVector<string>>();    //训练数据集    var testSet = new List<DataVector<string>>();        //测试数据集        //读取数据    var file = new StreamReader("agaricus-lepiota.data", Encoding.Default);    string line = string.Empty;    int count = 0;    while((line = file.ReadLine()) != null)    {        var parts = line.Split(',');                var p = new DataVector<string>(22);        p.Label = parts[0];        for(int i = 0;i < p.Dimension;++i)            p.Data[i] = parts[i + 1];                    //前7000作为训练样本,其余作为测试样本        if(++count <= 7000)            trainingSet.Add(p);        else            testSet.Add(p);    }    file.Close();    //检验    var dt = new DecisionTree();    dt.Train(trainingSet);    int error = 0;    foreach(var p in testSet)    {        //做猜测分类,并与实际结果比较        var label = dt.Classify(p);        if(label != p.Label)            ++error;    }    Console.WriteLine("Error = {0}/{1}, {2}%", error, testSet.Count, (error * 100.0 / testSet.Count));}


使用7000个样本做训练,1124个样本做测试,只有4个猜测出错,错误率仅有0.35%,相当不错的结果。


生成的决策树是这样的:

{    "FeatureIndex": 4,              //按第4个特征划分    "Child": {        "p": {"Label": "p"},        //如果第4个特征是p,则分类为p        "a": {"Label": "e"},        //如果第4个特征是a,则分类是e        "l": {"Label": "e"},        "n": {            "FeatureIndex": 19,            //如果第4个特征是n,要继续按第19个特征划分            "Child": {                "n": {"Label": "e"},                "k": {"Label": "e"},                "w": {                    "FeatureIndex": 21,                    "Child": {                        "w": {"Label": "e"},                        "l": {                            "FeatureIndex": 2,                            "Child": {                                "c": {"Label": "e"},                                "n": {"Label": "e"},                                "w": {"Label": "p"},                                "y": {"Label": "p"}                            }                        },                        "d": {                            "FeatureIndex": 1,                            "Child": {                                "y": {"Label": "p"},                                "f": {"Label": "p"},                                "s": {"Label": "e"}                            }                        },                        "g": {"Label": "e"},                        "p": {"Label": "e"}                    }                },                "h": {"Label": "e"},                "r": {"Label": "p"},                "o": {"Label": "e"},                "y": {"Label": "e"},                "b": {"Label": "e"}            }        },        "f": {"Label": "p"},        "c": {"Label": "p"},        "y": {"Label": "p"},        "s": {"Label": "p"},        "m": {"Label": "p"}    }}

可以看到,实际只使用了其中的5个特征,就能做出比较精确的判断了。


决策树还有一个很棒的优点就是能告诉我们多个特征中哪个对判别最有用,比如上面的树,根节点是特征4,参考agaricus-lepiota.names得知这个特征是指气味(odor),只要有气味,就可以直接得出结论,如果是无味的(n=none),下一个重要特征是19,即孢子印的颜色(spore-print-color)。




本文出自 “兔子窝” 博客,请务必保留此出处http://boytnt.blog.51cto.com/966121/1569763