EM算法说起来很简单,给定一个要估计的参数的初值,计算隐含变量分布,再根据隐含变量的分布更新要估计的参数值,之后在这两个步骤之间进行迭代。但是其中的数学原理,GMM的推导等等其实并不简单,难想更难算。这篇博客主要基于翻译我看过的好材料,对其中做出些许的解释。以下便从最简单的例子说起
投硬币的例子
出自http://www.cmi.ac.in/~madhavan/courses/datamining12/reading/em-tutorial.pdf
EM算法实现的是在数据不完全的情况下的参数预测。我们用一个投硬币的例子来解释EM算法的流程。假设我们有A,B两枚硬币,其正面朝上的概率分别为θA,θB,这两个参数即为需要估计的参数。我们设计5组实验,每次实验投掷10次硬币(但不知道用哪一枚硬币进行这次实验),投掷结束后会得到一个数组x=(x1,x2,...,x5),来表示每组实验有几次硬币是正面朝上的,因此0≤xi≤10。
如果我们知道每一组实验中的xi是A硬币投掷的结果还是B硬币的结果,我们就很容易估计出θA,θB,只需要统计在所有的试验中两个硬币分别有几次是正面朝上的,除以他们各自投掷的总次数。数据不完全的意思在于,我们并不知道每一个数据是哪一个硬币产生的。EM算法就是适用于这种问题。
虽然我们不知道每组实验用的是哪一枚硬币,但如果我们用某种方法猜测每组实验是哪个硬币投掷的,我们就可以将数据缺失的估计问题转化成一个最大似然问题+完整参数估计问题。
我们将逐步讲解投硬币的例子。假设5次试验的结果如下(H是正面,T是反面):
试验序号 |
结果 |
1 |
H T T T H H T H T H |
2 |
H H H H T H H H H H |
3 |
H T H H H H H T H H |
4 |
H T H T T T H H T T |
5 |
T H H H T H H H T H |
首先,随机选取初值θA,θB,比如θA=0.6,θB=0.5。EM算法的E步骤,是计算在当前的预估参数下,隐含变量(是A硬币还是B硬币)的每个值出现的概率。也就是给定θA,θB和观测数据,计算这组数据出自A硬币的概率和这组数据出自B硬币的概率。对于第一组实验,5正面5背面。
A硬币得到这个结果的概率为0.65×0.45=0.000796
B硬币得到这个结果的概率为0.55×0.55=0.000977
因此,第一组实验是A硬币得到的概率为0.000796/(0.000796+0.000977)=0.45,第一组实验是B硬币得到的概率为0.000977/(0.000796+0.000977)=0.55。整个5组实验的A,B投掷概率如下:
试验序号 |
是A硬币概率 |
是B硬币概率 |
1 |
0.45 |
0.55 |
2 |
0.80 |
0.20 |
3 |
0.73 |
0.27 |
4 |
0.35 |
0.65 |
5 |
0.65 |
0.35 |
根据隐含变量的概率,可以计算出两组训练值的期望。依然以第一组实验来举例子:5正5反中,A硬币投掷出了0.45×5=2.2个正面和0.45×5=2.2个反面;B硬币投掷出了0.55×5=2.8个正面和0.55×5=2.8个反面。整个5组实验的期望如下表:
试验序号 |
A硬币 |
B硬币 |
1 |
2.2H, 2.2T |
2.8H, 2.8T |
2 |
7.2H, 0.8T |
1.8H, 0.2T |
3 |
5.9H, 1.5T |
2.1H, 0.5T |
4 |
1.4H, 2.1T |
2.6H, 3.9T |
5 |
4.5H, 1.9T |
2.5H, 1.1T |
SUM |
21.3H, 8.6T |
11.7H, 8.4T |
通过计算期望,我们把一个有隐含变量的问题变化成了一个没有隐含变量的问题,由上表的数据,估计θA,θB变得非常简单。
θA=21.3/(21.3+8.6)=0.71
θB=11.7/(11.7+8.4)=0.58
下图是原文中以上描述的示意图
当我们有了新的估计,便可以基于这个估计进行下一次迭代了。综上所述,EM算法的步骤是:
1. E步骤:根据观测值计算隐含变量的分布情况
2. M步骤:根据隐含变量的分布来估计新的模型参数
GMM的参数推导
总体思想来自PRML chapter 9.2
高斯混合模型是什么这里不再赘述。书上的公式相当简洁,当然多元高斯函数对于均值和方差求导你可以不会,然而这是一个练习矩阵求导的好机会,毕竟好久没有推过这么复杂的公式了;再者,关于这部分的求导细节网络上的资料很少。以下就分享一下我的推导过程。
根据极大似然的思想,在已知GMM模型产生的一系列数据点x1,x2,...xn (假定它们是列向量)时,我们需要知道一组最佳的参数μ1,μ2,...μk,Σ1,Σ2,...Σk,和π1,π2,...πk,在这种参数下生成这组数据点的可能性最大。求解GMM模型的参数,就是求以下的极大似然函数的极值点。
lnp(X|π,μ,Σ)=∑n=1Nln∑k=1KπkN(xn|μk,Σk)(1.1)
其中,多元高斯函数的公式为
N(xn|μk,Σk)=12πD/2|Σk|1/2exp(−12(xn−μk)TΣ−1k(xn−μk))(1.2)
我们的最终目的是对公式(1.1)进行对μk,Σk,πk求导,并求导数为零时它们分别对应的值。在对这个终极公式求导之前,为了描述的更清楚,我们先计算公式(1.2)对μk,Σk的导数。
ddμkN(xn|μk,Σk)=12πD/2|Σk|1/2exp(−12(xn−μk)TΣ−1k(xn−μk))ddμk(−12(xn−μk)TΣ−1k(xn−μk))=N(xn|μk,Σk)ddμk(−12(xn−μk)TΣ−1k(xn−μk))=N(xn|μk,Σk)dd(xn−μk)(−12(xn−μk)TΣ−1k(xn−μk))ddμk(xn−μk)=N(xn|μk,Σk)(−Σ−1k(xn−μk))(−1)=N(xn|μk,Σk)Σ−1k(xn−μk)
这里,−12(xn−μk)TΣ−1k(xn−μk)对于xn−μk的求导原理如下(包括一个简单的变量代换):
ddxxTAx=2Ax,当A为对称矩阵
公式来源是https://en.wikipedia.org/wiki/Matrix_calculus
再计算N(xn|μk,Σk)对协方差的求导
ddΣkN(xn|μk,Σk)=12πD/2{d|Σk|−1/2dΣkexp(−12(xn−μk)TΣ−1k(xn−μk))+dexp(−12(xn−μk)TΣ−1(xn−μk))dΣk|Σ|−1/2}=12πD/2{−12|Σk|−32|Σk|(Σ−1k)Texp(−12(xn−μk)TΣ−1k(xn−μk))+12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk|Σk|−1/2}=12πD/2|Σk|−1/2exp(−12(xn−μk)TΣ−1k(xn−μk)){−12(Σ−1k)T+12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk}=N(xn|μk,Σk){−12(Σ−1k)T+12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk}
这里求导的重点有两个,对行列式的求导公式和对逆矩阵trace的求导公式
首先,对行列式的求导公式为
d|X|dX=|X|(X−1)T
这个公式同样出自https://en.wikipedia.org/wiki/Matrix_calculus
因此,d|Σk|−1/2dΣk=−12|Σk|−32|Σk|(Σ−1k)T
接下来,对矩阵的trace的求导公式ddXTr(AX−1B)=−X−TATBTX−T
这个公式出自http://www2.imm.dtu.dk/pubdb/views/edoc_download.php/3274/pdf/imm3274.pdf
又因为12(xn−μk)TΣ−1(xn−μk)其实是一个实数,因此它等于它的trace,因此
d(−12(xn−μk)TΣ−1(xn−μk))dΣk=dtr(−12(xn−μk)TΣ−1k(xn−μk))dΣk=12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk
推完了一个高斯函数对其均值和方差的求导,我们开始进入主题:对极大似然函数对均值和方差求导
首先,对均值求导:
ddμklnp(X|π,μ,Σ)=∑n=1N1∑Kj=1πjN(xn|μj,Σj)ddμkπkN(xn|μk,Σk)=∑n=1N1∑Kj=1πjN(xn|μj,Σj)πkN(xn|μk,Σk)Σ−1k(xn−μk)=∑n=1NπkN(xn|μk,Σk)∑Kj=1πjN(xn|μj,Σj)Σ−1k(xn−μk)
为了表达的方便,我们令
γ(znk)=πkN(xn|μk,Σk)∑Kj=1πjN(xn|μj,Σj),
Nk=∑Nn=1γ(znk)则有:
ddμklnp(X|π,μ,Σ)=∑n=1Nγ(znk)Σ−1k(xn−μk)
我们让这个式子等于0,即
∑n=1Nγ(znk)Σ−1k(xn−μk)=0
可以得到
μk=1Nk∑n=1Nγ(znk)xn
终于我们看到书上的结果了!观察一下,这个结果其实很容易想象。
γ(znk)的实际含义是第n个观测数据分别属于第1,2,…,k个高斯函数的概率。每一个高斯函数的均值,将会是观测数据在用各个高斯函数上的概率加权后的计算。
现在我们再对方差求导。
ddΣklnp(X|π,μ,Σ)=∑n=1N1∑Kj=1πjN(xn|μj,Σj)ddΣkπkN(xn|μk,Σk)=∑n=1N1∑Kj=1πjN(xn|μj,Σj)πkN(xn|μk,Σk){−12(Σ−1k)T+12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk}=∑n=1NπkN(xn|μk,Σk)∑Kj=1πjN(xn|μj,Σj){−12(Σ−1k)T+12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk}
令此公式为零
∑n=1NπkN(xn|μk,Σk)∑Kj=1πjN(xn|μj,Σj){−12(Σ−1k)T+12Σ−Tk(xn−μk)(xn−μk)TΣ−Tk}=0
对此公式稍作化简,两边同乘2,由于协方差矩阵都是实对称的,可以去掉转置符号,再左右各乘Σk可以化简为
∑n=1NπkN(xn|μk,Σk)∑Kj=1πjN(xn|μj,Σj){−Σk+(xn−μk)(xn−μk)T}=0
我们终于得到了方差的迭代公式:
Σk=1Nk∑n=1Nγ(znk)(xn−μk)(xn−μk)T
最后,对于πk的求导就显得简单多了,其重要注意的一点是,由于∑Kk=1πk=1,求导的时候需要用拉格朗日乘子法。即为,对下面的函数求导:
lnp(X|π,μ,Σ)+λ(∑k=1Kπk−1)
这个求导过程很简单没什么好说的,上面这个式子对
πk的求导结果如下
∑n=1NN(xn|μk,Σk)∑Kj=1πjN(xn|μj,Σj)+λ=0
两边同时乘以
πk可以得到
Nk+λπk=0
对所有的k求和得到
N+λ=0
再带入前面的式子得到
πk=NkN
到这里,GMM的迭代公式就推导结束啦!我在学习的时候,只要有对方差求导的作业就没做出来过。想想其实这些公式看上去难推导,其实是因为对矩阵求导的不熟悉和畏惧。经验就是,多查查wiki就好。或者matrix cookbook会有更加详细的公式。