决策树系列(三)——ID3

时间:2022-06-15 19:49:09

预备知识:决策树

初识ID3

回顾决策树的基本知识,其构建过程主要有下述三个重要的问题:

(1)数据是怎么分裂的

(2)如何选择分类的属性

(3)什么时候停止分裂

从上述三个问题出发,以实际的例子对ID3算法进行阐述。

例:通过当天的天气、温度、湿度和季节预测明天的天气

                                  表1 原始数据

当天天气

温度

湿度

季节

明天天气

25

50

春天

21

48

春天

18

70

春天

28

41

夏天

8

65

冬天

18

43

夏天

24

56

秋天

18

76

秋天

31

61

夏天

6

43

冬天

15

55

秋天

4

58

冬天

 1.数据分割

对于离散型数据,直接按照离散数据的取值进行分裂,每一个取值对应一个子节点,以“当前天气”为例对数据进行分割,如图1所示。

决策树系列(三)——ID3

对于连续型数据,ID3原本是没有处理能力的,只有通过离散化将连续性数据转化成离散型数据再进行处理。

连续数据离散化是另外一个课题,本文不深入阐述,这里直接采用等距离数据划分的李算话方法。该方法先对数据进行排序,然后将连续型数据划分为多个区间,并使每一个区间的数据量基本相同,以温度为例对数据进行分割,如图2所示。

决策树系列(三)——ID3

 2. 选择最优分裂属性

ID3采用信息增益作为选择最优的分裂属性的方法,选择熵作为衡量节点纯度的标准,信息增益的计算公式如下:

决策树系列(三)——ID3

其中, 决策树系列(三)——ID3表示父节点的熵; 决策树系列(三)——ID3表示节点i的熵,熵越大,节点的信息量越多,越不纯; 决策树系列(三)——ID3表示子节点i的数据量与父节点数据量之比。 决策树系列(三)——ID3越大,表示分裂后的熵越小,子节点变得越纯,分类的效果越好,因此选择 决策树系列(三)——ID3最大的属性作为分裂属性。

对上述的例子的跟节点进行分裂,分别计算每一个属性的信息增益,选择信息增益最大的属性进行分裂。

天气属性:(数据分割如上图1所示)

  决策树系列(三)——ID3

温度:(数据分割如上图2所示)

决策树系列(三)——ID3

湿度:

决策树系列(三)——ID3

决策树系列(三)——ID3

季节:

决策树系列(三)——ID3

决策树系列(三)——ID3

由于决策树系列(三)——ID3最大,所以选择属性“季节”作为根节点的分裂属性。

3.停止分裂的条件

停止分裂的条件已经在决策树中阐述,这里不再进行阐述。

(1)最小节点数

  当节点的数据量小于一个指定的数量时,不继续分裂。两个原因:一是数据量较少时,再做分裂容易强化噪声数据的作用;二是降低树生长的复杂性。提前结束分裂一定程度上有利于降低过拟合的影响。

  (2)熵或者基尼值小于阀值。

由上述可知,熵和基尼值的大小表示数据的复杂程度,当熵或者基尼值过小时,表示数据的纯度比较大,如果熵或者基尼值小于一定程度时,节点停止分裂。

  (3)决策树的深度达到指定的条件

   节点的深度可以理解为节点与决策树跟节点的距离,如根节点的子节点的深度为1,因为这些节点与跟节点的距离为1,子节点的深度要比父节点的深度大1。决策树的深度是所有叶子节点的最大深度,当深度到达指定的上限大小时,停止分裂。

  (4)所有特征已经使用完毕,不能继续进行分裂。

被动式停止分裂的条件,当已经没有可分的属性时,直接将当前节点设置为叶子节点。

程序设计及源代码(C#版本)

(1)数据处理

用二维数组存储原始的数据,每一行表示一条记录,前n-1列表示数据的属性,第n列表示分类的标签。

   static double[][] allData;

   为了方便后面的处理,对离散属性进行数字化处理,将离散值表示成数字,并用一个链表数组进行存储,数组的第一个元素表示属性1的离散值。

   static List<String>[] featureValues;

那么经过处理后的表1数据可以转化为如表2所示的数据:

      表2 初始化后的数据

当天天气

温度

湿度

季节

明天天气

1

25

50

1

1

2

21

48

1

2

2

18

70

1

3

1

28

41

2

1

3

8

65

3

2

1

18

43

2

1

2

24

56

4

1

3

18

76

4

2

3

31

61

2

1

2

6

43

3

3

1

15

55

4

2

3

4

58

3

3

其中,对于当天天气属性,数字{1,2,3}分别表示{晴,阴,雨};对于季节属性{1,2,3,4}分别表示{春天、夏天、冬天、秋天};对于明天天气{1,2,3}分别表示{晴、阴、雨}。

(2)两个类:节点类和分裂信息

  a)节点类Node

该类存储了节点的信息,包括节点的数据量、节点选择的分裂属性、节点输出类、子节点的个数、子节点的分类误差等。

     class Node
{
/// <summary>
/// 各个子节点的取值
/// </summary>
public List<String> features { get; set; }
/// <summary>
/// 分裂属性的类型
/// </summary>
public String feature_Type { get; set; }
/// <summary>
/// 分裂的属性
/// </summary>
public String SplitFeature { get; set; }
/// <summary>
/// 节点对应各个分类的数目
/// </summary>
public double[] ClassCount { get; set; }
/// <summary>
/// 各个孩子节点
/// </summary>
public List<Node> childNodes { get; set; }
/// <summary>
/// 父亲节点(未用到)
/// </summary>
public Node Parent { get; set; }
/// <summary>
/// 占比最大的类别
/// </summary>
public String finalResult { get; set; }
/// <summary>
/// 数的深度
/// </summary>
public int deep { get; set; }
/// <summary>
/// 该节点占比最大的类标号
/// </summary>
public int result { get; set; }
/// <summary>
/// 节点的数量
/// </summary>
public int rowCount{ get; set; } public void setClassCount(double[] count)
{
this.ClassCount = count;
double max = ClassCount[];
int result = ;
for (int i = ; i < ClassCount.Length; i++)
{
if (max < ClassCount[i])
{
max = ClassCount[i];
result = i;
}
}
//wrong = Convert.ToInt32(nums.Count - ClassCount[result]);
this.result = result;
}
}

  b)分裂信息类SplitInfo

该类存储节点进行分裂的信息,包括各个子节点的行坐标、子节点各个类的数目、该节点分裂的属性、属性的类型等。

     class SplitInfo
{
/// <summary>
/// 分裂的列下标
/// </summary>
public int splitIndex { get; set; }
/// <summary>
/// 数据类型
/// </summary>
public int type { get; set; }
/// <summary>
/// 分裂属性的取值
/// </summary>
public List<String> features { get; set; }
/// <summary>
/// 各个节点的行坐标链表
/// </summary>
public List<int>[] temp { get; set; }
/// <summary>
/// 每个节点各类的数目
/// </summary>
public double[][] class_Count { get; set; }
}

(3)节点分裂方法findBestSplit(Node node,List<int> nums,int[] isUsed),该方法对节点进行分裂,返回值Node

其中:

node表示即将进行分裂的节点;

nums表示节点数据对应的行坐标列表;

isUsed表示到该节点位置所有属性的使用情况(1:表示该属性不能再次使用,0:表示该属性可以使用);

findBestSplit主要有以下几个组成部分:

1)节点分裂停止的判定

判断节点是否需要继续分裂,分裂判断条件如上文所述。源代码如下:

         public static Object[] ifEnd(Node node, double entropy,int[] isUsed)
{
try
{
double[] count = node.ClassCount;
int rowCount = node.rowCount;
int maxResult = ;
double maxRate = ;
#region 数达到某一深度
int deep = node.deep;
if (deep >= maxDeep)
{
maxResult = node.result + ;
node.feature_Type=("result");
node.features=(new List<String>() { maxResult + "" });
return new Object[] { true, node };
}
#endregion
#region 纯度(其实跟后面的有点重了,记得要修改)
//maxResult = 1;
//for (int i = 1; i < count.Length; i++)
//{
// if (count[i] / rowCount >= 0.95)
// {
// node.setFeatureType("result");
// node.setFeatures(new List<String> { "" + (i + 1) });
// return new Object[] { true, node };
// }
//}
//node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1]));
#endregion
#region 熵为0
if (entropy == )
{
maxRate = count[] / rowCount;
maxResult = ;
for (int i = ; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + ;
}
}
node.feature_Type=("result");
node.features=(new List<String> { maxResult + "" });
return new Object[] { true, node };
}
#endregion
#region 属性已经分完
//int[] isUsed = node.;
bool flag = true;
for (int i = ; i < isUsed.Length - ; i++)
{
if (isUsed[i] == )
{
flag = false;
break;
}
}
if (flag)
{
maxRate = count[] / rowCount;
maxResult = ;
for (int i = ; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + ;
}
}
node.feature_Type=("result");
node.features=(new List<String> { "" + (maxResult) });
return new Object[] { true, node };
}
#endregion
#region 数据量少于100
if (rowCount < Limit_Node)
{
maxRate = count[] / rowCount;
maxResult = ;
for (int i = ; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + ;
}
}
node.feature_Type=("result");
node.features=(new List<String> { "" + (maxResult) });
return new Object[] { true, node };
}
#endregion
return new Object[] { false, node };
}
catch (Exception e)
{
return new Object[] { false, node };
}
}

2)寻找最优的分裂属性

寻找最优的分裂属性需要计算每一个分裂属性分裂后的信息增益,计算公式上文已给出,其中熵的计算代码如下:

         public static double CalEntropy(double[] counts, int countAll)
{
try
{
double allShang = ;
for (int i = ; i < counts.Length; i++)
{
if (counts[i] == )
{
continue;
}
double rate = counts[i] / countAll;
allShang = allShang + rate * Math.Log(rate, );
}
return -allShang;
}
catch (Exception e)
{
return ;
}
}

3)进行分裂,同时子节点也执行相同的分类步骤

其实就是递归的过程,对每一个子节点执行findBestSplit方法进行分裂。

全部源代码:

         #region ID3核心算法
/// <summary>
/// 测试
/// </summary>
/// <param name="node"></param>
/// <param name="data"></param>
public static String findResult(Node node, String[] data)
{
List<String> featrues = node.features;
String type = node.feature_Type;
if (type == "result")
{
return featrues[];
}
int split = Convert.ToInt32(node.SplitFeature);
List<Node> childNodes = node.childNodes;
double[] resultCount = node.ClassCount;
if (type == "连续")
{ for (int i = ; i < featrues.Count; i++)
{
double value = Convert.ToDouble(featrues[i]);
if (Convert.ToDouble(data[split]) <= value)
{
return findResult(childNodes[i], data);
}
}
return findResult(childNodes[featrues.Count], data);
}
else
{
for (int i = ; i < featrues.Count; i++)
{
if (data[split] == featrues[i])
{
return findResult(childNodes[i], data);
}
if (i == featrues.Count - )
{
double count = resultCount[];
int maxInt = ;
for (int j = ; j < resultCount.Length; j++)
{
if (count < resultCount[j])
{
count = resultCount[j];
maxInt = j;
}
}
return findResult(childNodes[], data);
}
}
}
return null;
}
/// <summary>
/// 判断是否还需要分裂
/// </summary>
/// <param name="node"></param>
/// <returns></returns>
public static Object[] ifEnd(Node node, double entropy,int[] isUsed)
{
try
{
double[] count = node.ClassCount;
int rowCount = node.rowCount;
int maxResult = ;
double maxRate = ;
#region 数达到某一深度
int deep = node.deep;
if (deep >= maxDeep)
{
maxResult = node.result + ;
node.feature_Type=("result");
node.features=(new List<String>() { maxResult + "" });
return new Object[] { true, node };
}
#endregion
#region 纯度(其实跟后面的有点重了,记得要修改)
//maxResult = 1;
//for (int i = 1; i < count.Length; i++)
//{
// if (count[i] / rowCount >= 0.95)
// {
// node.setFeatureType("result");
// node.setFeatures(new List<String> { "" + (i + 1) });
// return new Object[] { true, node };
// }
//}
//node.setLeafWrong(rowCount - Convert.ToInt32(count[maxResult - 1]));
#endregion
#region 熵为0
if (entropy == )
{
maxRate = count[] / rowCount;
maxResult = ;
for (int i = ; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + ;
}
}
node.feature_Type=("result");
node.features=(new List<String> { maxResult + "" });
return new Object[] { true, node };
}
#endregion
#region 属性已经分完
//int[] isUsed = node.;
bool flag = true;
for (int i = ; i < isUsed.Length - ; i++)
{
if (isUsed[i] == )
{
flag = false;
break;
}
}
if (flag)
{
maxRate = count[] / rowCount;
maxResult = ;
for (int i = ; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + ;
}
}
node.feature_Type=("result");
node.features=(new List<String> { "" + (maxResult) });
return new Object[] { true, node };
}
#endregion
#region 数据量少于100
if (rowCount < Limit_Node)
{
maxRate = count[] / rowCount;
maxResult = ;
for (int i = ; i < count.Length; i++)
{
if (count[i] / rowCount >= maxRate)
{
maxRate = count[i] / rowCount;
maxResult = i + ;
}
}
node.feature_Type=("result");
node.features=(new List<String> { "" + (maxResult) });
return new Object[] { true, node };
}
#endregion
return new Object[] { false, node };
}
catch (Exception e)
{
return new Object[] { false, node };
}
}
#region 排序算法
public static void InsertSort(double[] values, List<int> arr, int StartIndex, int endIndex)
{
for (int i = StartIndex + ; i <= endIndex; i++)
{
int key = arr[i];
double init = values[i];
int j = i - ;
while (j >= StartIndex && values[j] > init)
{
arr[j + ] = arr[j];
values[j + ] = values[j];
j--;
}
arr[j + ] = key;
values[j + ] = init;
}
}
static int SelectPivotMedianOfThree(double[] values, List<int> arr, int low, int high)
{
int mid = low + ((high - low) >> );//计算数组中间的元素的下标 //使用三数取中法选择枢轴
if (values[mid] > values[high])//目标: arr[mid] <= arr[high]
{
swap(values, arr, mid, high);
}
if (values[low] > values[high])//目标: arr[low] <= arr[high]
{
swap(values, arr, low, high);
}
if (values[mid] > values[low]) //目标: arr[low] >= arr[mid]
{
swap(values, arr, mid, low);
}
//此时,arr[mid] <= arr[low] <= arr[high]
return low;
//low的位置上保存这三个位置中间的值
//分割时可以直接使用low位置的元素作为枢轴,而不用改变分割函数了
}
static void swap(double[] values, List<int> arr, int t1, int t2)
{
double temp = values[t1];
values[t1] = values[t2];
values[t2] = temp;
int key = arr[t1];
arr[t1] = arr[t2];
arr[t2] = key;
}
static void QSort(double[] values, List<int> arr, int low, int high)
{
int first = low;
int last = high; int left = low;
int right = high; int leftLen = ;
int rightLen = ; if (high - low + < )
{
InsertSort(values, arr, low, high);
return;
} //一次分割
int key = SelectPivotMedianOfThree(values, arr, low, high);//使用三数取中法选择枢轴
double inti = values[key];
int currentKey = arr[key]; while (low < high)
{
while (high > low && values[high] >= inti)
{
if (values[high] == inti)//处理相等元素
{
swap(values, arr, right, high);
right--;
rightLen++;
}
high--;
}
arr[low] = arr[high];
values[low] = values[high];
while (high > low && values[low] <= inti)
{
if (values[low] == inti)
{
swap(values, arr, left, low);
left++;
leftLen++;
}
low++;
}
arr[high] = arr[low];
values[high] = values[low];
}
arr[low] = currentKey;
values[low] = values[key];
//一次快排结束
//把与枢轴key相同的元素移到枢轴最终位置周围
int i = low - ;
int j = first;
while (j < left && values[i] != inti)
{
swap(values, arr, i, j);
i--;
j++;
}
i = low + ;
j = last;
while (j > right && values[i] != inti)
{
swap(values, arr, i, j);
i++;
j--;
}
QSort(values, arr, first, low - - leftLen);
QSort(values, arr, low + + rightLen, last);
}
#endregion
/// <summary>
/// 寻找最佳的分裂点
/// </summary>
/// <param name="num"></param>
/// <param name="node"></param>
public static Node findBestSplit(Node node, int lastCol,List<int> nums,int[] isUsed)
{
try
{
//判断是否继续分裂
double totalShang = CalEntropy(node.ClassCount, nums.Count);
Object[] check = ifEnd(node, totalShang, isUsed);
if ((bool)check[])
{
node = (Node)check[];
return node;
}
#region 变量声明
SplitInfo info = new SplitInfo();
//int[] isUsed = node.getUsed(); //连续变量or离散变量
//List<int> nums = node.getNum(); //样本的标号
int RowCount = nums.Count; //样本总数
double jubuMax = ; //局部最大熵
#endregion
for (int i = ; i < isUsed.Length - ; i++)
{
if (isUsed[i] == )
{
continue;
}
#region 离散变量
if (type[i] == )
{
int[] allFeatureCount = new int[]; //所有类别的数量
double[][] allCount = new double[allNum[i]][];
for (int j = ; j < allCount.Length; j++)
{
allCount[j] = new double[classCount];
}
int[] countAllFeature = new int[allNum[i]];
List<int>[] temp = new List<int>[allNum[i]];
for (int j = ; j < temp.Length; j++)
{
temp[j] = new List<int>();
}
for (int j = ; j < nums.Count; j++)
{
int index = Convert.ToInt32(allData[nums[j]][i]);
temp[index - ].Add(nums[j]);
countAllFeature[index - ]++;
allCount[index - ][Convert.ToInt32(allData[nums[j]][lieshu - ]) - ]++;
}
double allShang = ;
for (int j = ; j < allCount.Length; j++)
{
allShang = allShang + CalEntropy(allCount[j], countAllFeature[j]) * countAllFeature[j] / RowCount;
}
allShang = (totalShang - allShang);
if (allShang > jubuMax)
{
info.features=new List<String>();
info.type=;
info.temp=(temp);
info.splitIndex=(i);
info.class_Count=(allCount);
jubuMax = allShang;
allFeatureCount = countAllFeature;
}
}
#endregion
#region 连续变量
else
{
double[] leftCount = new double[classCount]; //做节点各个类别的数量
double[] rightCount = new double[classCount]; //右节点各个类别的数量
double[] values = new double[nums.Count];
List<String> List_Feature = new List<string>();
for (int j = ; j < values.Length; j++)
{
values[j] = allData[nums[j]][i];
}
QSort(values, nums, , nums.Count - );
int eachNum = nums.Count / ;
double lianxuMax = ; //连续型属性的最大熵
int index = ;
double[][] counts = new double[][];
List<int>[] temp = new List<int>[];
for (int j = ; j < ; j++)
{
counts[j] = new double[classCount];
temp[j] = new List<int>();
}
for (int j = ; j < nums.Count - ; j++)
{
if (j >= index * eachNum&&index<)
{
List_Feature.Add(allData[nums[j]][i]+"");
lianxuMax += eachNum*CalEntropy(counts[index - ], eachNum)/RowCount;
index++;
}
temp[index-].Add(nums[j]);
counts[index - ][Convert.ToInt32(allData[nums[j]][lieshu - ])-]++;
}
lianxuMax += ((eachNum + nums.Count % )*CalEntropy(counts[index - ], eachNum + nums.Count % ) / RowCount);
lianxuMax = totalShang - lianxuMax;
if (lianxuMax > jubuMax)
{
info.splitIndex=(i);
info.features=(List_Feature);
info.type=();
jubuMax = lianxuMax;
info.temp=(temp);
info.class_Count=(counts);
}
}
#endregion
}
#region 如何找不到最佳的分裂属性,则设为叶节点
if (info.splitIndex == -)
{
double[] finalCount = node.ClassCount;
double max = finalCount[];
int result = ;
for (int i = ; i < finalCount.Length; i++)
{
if (finalCount[i] > max)
{
max = finalCount[i];
result = (i + );
}
}
node.feature_Type=("result");
node.features=(new List<String> { "" + result });
return node;
}
#endregion
int deep = node.deep;
#region 分裂
node.SplitFeature=("" + info.splitIndex); List<Node> childNode = new List<Node>();
int[] used = new int[isUsed.Length];
for (int i = ; i < used.Length; i++)
{
used[i] = isUsed[i];
}
if (info.type == )
{
used[info.splitIndex] = ;
node.feature_Type=("离散");
}
else
{
used[info.splitIndex] = ;
node.feature_Type=("连续");
}
int sumLeaf = ;
int sumWrong = ;
List<int>[] rowIndex = info.temp;
List<String> features = info.features;
for (int j = ; j < rowIndex.Length; j++)
{
if (rowIndex[j].Count == )
{
continue;
}
if (info.type == )
features.Add(""+(j+));
Node node1 = new Node();
//node1.setNum(info.getTemp()[j]);
node1.setClassCount(info.class_Count[j]);
//node1.setUsed(used);
node1.deep=(deep + );
node1.rowCount = info.temp[j].Count;
node1 = findBestSplit(node1, info.splitIndex,info.temp[j], used);
childNode.Add(node1);
}
node.features=(features);
node.childNodes=(childNode); #endregion
return node;
}
catch (Exception e)
{
Console.WriteLine(e.StackTrace);
return node;
}
}
/// <summary>
/// 计算熵
/// </summary>
/// <param name="counts"></param>
/// <param name="countAll"></param>
/// <returns></returns>
public static double CalEntropy(double[] counts, int countAll)
{
try
{
double allShang = ;
for (int i = ; i < counts.Length; i++)
{
if (counts[i] == )
{
continue;
}
double rate = counts[i] / countAll;
allShang = allShang + rate * Math.Log(rate, );
}
return -allShang;
}
catch (Exception e)
{
return ;
}
}
#endregion

(注:上述代码只是ID3的核心代码,数据预处理的代码并没有给出,只要将预处理后的数据输入到主方法findBestSplit中,就可以得到最终的结果)

总结

ID3是基本的决策树构建算法,作为决策树经典的构建算法,其具有结构简单、清晰易懂的特点。虽然ID3比较灵活方便,但是有以下几个缺点:

 (1)采用信息增益进行分裂,分裂的精确度可能没有采用信息增益率进行分裂高

(2)不能处理连续型数据,只能通过离散化将连续性数据转化为离散型数据

(3)不能处理缺省值

(4)没有对决策树进行剪枝处理,很可能会出现过拟合的问题