机器学习:详解半朴素贝叶斯分类TAN原理(附Python实现)

时间:2022-10-13 14:52:27

0 写在前面

机器学习强基计划聚焦深度和广度,加深对机器学习模型的理解与应用。“深”在详细推导算法模型背后的数学原理;“广”在分析多个机器学习模型:决策树、支持向量机、贝叶斯与马尔科夫决策、强化学习等。

????详情:机器学习强基计划(附几十种经典模型源码合集)


机器学习强基计划4-4:详解半朴素贝叶斯分类AODE原理(附Python实现)中,我们向朴素贝叶斯模型里引入了独依赖假设,AODE的基本思路是考虑每个属性与其他属性间的依赖性,做加权平均,而本文介绍的TAN则是采用另一种思路。

1 条件互信息

机器学习强基计划2-1:一文总结熵——交叉熵、相对熵、互信息(附例题分析)中,我们介绍了互信息的概念,互信息(Mutual Information)用来描述给定随机变量 X X X后随机变量 Y Y Y不确定性的减少量(或反之),度量了 X X X Y Y Y间的相互依赖性。特别地,相互独立的随机变量间互信息为0,因为知道其中一个变量的信息不会对了解另一变量产生贡献

I ( X ; Y ) = ∑ x ∈ X , y ∈ Y p ( x , y ) log ⁡ p ( x , y ) p ( x ) p ( y ) I\left( X;Y \right) =\sum_{x\in X,y\in Y}{p\left( x,y \right) \log \frac{p\left( x,y \right)}{p\left( x \right) p\left( y \right)}} I(X;Y)=xX,yYp(x,y)logp(x)p(y)p(x,y)

TAN算法使用的条件互信息是互信息的条件概率形式

I ( X i , X j ∣ Y ) = ∑ x i ∈ X i , x j ∈ X j ; C ∈ Y P ( x i , x j ∣ C ) log ⁡ P ( x i , x j ∣ C ) P ( x i ∣ C ) P ( x j ∣ C ) I\left( X_i,X_j|\mathcal{Y} \right) =\sum_{x_i\in X_i,x_j\in X_j;C\in \mathcal{Y}}{P\left( x_i,x_j|C \right) \log \frac{P\left( x_i,x_j|C \right)}{P\left( x_i|C \right) P\left( x_j|C \right)}} I(Xi,XjY)=xiXi,xjXj;CYP(xi,xjC)logP(xiC)P(xjC)P(xi,xjC)

因此可以衡量在给定标签的情况下,两个属性间的依赖关系,正好符合半朴素贝叶斯的思想

机器学习:详解半朴素贝叶斯分类TAN原理(附Python实现)

2 最大带权生成树

这里介绍图论中的经典算法Kruskal算法,Kruskal算法以贪心策略为核心,依次选择当前图中权重最小且不构成圈的边,直至构造出树,最大生成树则依次选择最大权重的边即可,算法流程如下

机器学习:详解半朴素贝叶斯分类TAN原理(附Python实现)

3 TAN算法原理

TAN的核心原理很简单:以 I ( X i , X j ∣ Y ) I\left( X_i,X_j|\mathcal{Y} \right) I(Xi,XjY)为权、属性 X i X_i Xi X j X_j Xj为节点构造完全图;最后通过最大生成树算法构造该完全图的最大带权生成树,并挑选根节点。

TAN算法通过分析属性内部结构,保留了强依赖属性间的联系

机器学习:详解半朴素贝叶斯分类TAN原理(附Python实现)

4 Python实现

4.1 计算条件互信息

分为两个步骤:

  • 计算属性间的互信息;

    def __I(self, xi, xj):
        assert len(xi)==len(xj), "the length of two attributes must be equal!"
        # 样本数
        n = len(xi)
        # 两个属性的分布
        Xi, Xj = np.unique(xi), np.unique(xj)
        # 互信息值
        mi = 0
        for i in Xi:
            pxi = sum(xi == i) / n
            for j in Xj:
                pxj = sum(xj == j) / n
                pxixj = sum(np.equal(xi, i) & np.equal(xj, j)) / n
                if not pxixj:
                    continue
                mi = mi + pxixj * np.log(pxixj / (pxi * pxj))
        return mi
    
  • 增加条件标签,转化为条件互信息

    def __cI(self, xi, xj):
        assert len(xi)==len(xj), "the length of two attributes must be equal!"
        # 可选的类别数
        label = np.unique(self.y)
        # 条件互信息
        cmi = 0
        for _label in label:
            # 获取标签取值_label的样本序号
            labelIndex = np.squeeze(np.argwhere(np.squeeze(self.y)==_label))
            # 计算在标签取为_label的条件下属性互信息
            cmi = cmi + self.__I(xi[labelIndex], xj[labelIndex])
        return cmi
    

4.2 构造属性最大生成树

def generate(self, points, edges, mode='max'):
        treeFlag = False if mode=='min' else True
        # 图的节点数量
        ptsNum = len(points)
        # 将邻接矩阵转化为带权图字典
        edgesDict = {}
        for i in range(ptsNum):
            for j in range(i, ptsNum):
                # 存在连边
                if edges[i][j] < np.inf:
                    edgesDict[(points[i], points[j])] = edges[i][j]
        
        # 点下标掩码向量,用于判断是否形成环路
        ptsMask = np.zeros(ptsNum)
        # 符合条件需要保留的边
        edgesLeft = []
        # 按权重大小排列
        edgesList = sorted(edgesDict.items(), key=lambda x: x[1], reverse=treeFlag)
        # 执行贪心策略
        for edge in edgesList:
            (p1, p2), _ = edge
            index1 = np.squeeze(np.argwhere(points==p1))
            index2 = np.squeeze(np.argwhere(points==p2))
            # 判断是否形成环路:当边的两个点下标都为1时,则形成了环路
            if (ptsMask[index1] == 1) and (ptsMask[index2] == 1):
                continue
            else:
                ptsMask[index1] = 1
                ptsMask[index2] = 1
                edgesLeft.append((p1, p2))

        # 根据留存的无向图计算生成树
        # 计算根节点
        ptsCnt = dict()
        for p0, p1 in edgesLeft:
            ptsCnt[p0] = 1 if p0 not in ptsCnt.keys() else ptsCnt[p0] + 1
            ptsCnt[p1] = 1 if p1 not in ptsCnt.keys() else ptsCnt[p1] + 1
        root = sorted(ptsCnt.items(), key=lambda x: x[1], reverse=True)[0][0]
        # 遍历特征,保存依赖关系
        tree = dict()
        tree[root] = self.__spanTree(root, edgesLeft, [root], [])

        return tree

4.3 属性依赖关系可视化

直接打印最大生成树即可

def visualization(self, name=None):
    # 根节点位置
    rootPos = (0, self.graphSize - 3)
    self.__plotTree(self.tree, rootPos, rootPos, name)
    plt.show()

机器学习:详解半朴素贝叶斯分类TAN原理(附Python实现)

4.4 预测

model = TAN(X, y)
# 训练模型
model.train()
# 模型预测
predictY = model.predict(X)
print("错误:", np.sum(predictY!=y.T), "个\n准确率为:", np.sum(predictY==y.T)/y.size)

>>> 错误: 2>>> 准确率为: 0.8823529411764706

完整代码请联系下方博主名片获取


???? 更多精彩专栏


????源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系????