kaldi中CD-DNN-HMM网络参数更新公式手写推导

时间:2021-04-05 12:24:45

在基于DNN-HMM的语音识别中,DNN的作用跟GMM是一样的,即它是取代GMM的,具体作用是算特征值对每个三音素状态的概率,算出来哪个最大这个特征值就对应哪个状态。只不过以前是用GMM算的,现在用DNN算了。这是典型的多分类问题,所以输出层用的激活函数是softmax,损失函数用的是cross entropy(交叉熵)。不用均方差做损失函数的原因是在分类问题上它是非凸函数,不能保证全局最优解(只有凸函数才能保证全局最优解)。Kaldi中也支持DNN-HMM,它还依赖于上下文(context dependent, CD),所以叫CD-DNN-HMM。在kaldi的nnet1中,特征提取用filterbank,每帧40维数据,默认取当前帧前后5帧加上当前帧共11帧作为输入,所以输入层维数是440(440 = 40*11)。同时默认有4个隐藏层,每层1024个网元,激活函数是sigmoid。今天我们看看网络的各种参数是怎么得到的(手写推导)。由于真正的网络比较复杂,为了推导方便这里对其进行了简化,只有一个隐藏层,每层的网元均为3,同时只有weight没有bias。这样网络如下图:

kaldi中CD-DNN-HMM网络参数更新公式手写推导

上图中输入层3个网元为i1/i2/i3(i表示input),隐藏层3个网元为h1/h2/h3(h表示hidden),输出层3个网元为o1/o2/o3(o表示output)。隐藏层h1的输入为kaldi中CD-DNN-HMM网络参数更新公式手写推导 (q11等表示输入层和隐藏层之间的权值),输出为kaldi中CD-DNN-HMM网络参数更新公式手写推导。输出层o1的输入为kaldi中CD-DNN-HMM网络参数更新公式手写推导(w11等表示隐藏层和输出层之间的权值),输出为kaldi中CD-DNN-HMM网络参数更新公式手写推导。其他可类似推出。损失函数用交叉熵。今天我们看看网络参数(以隐藏层和输出层之间的w11以及输入层和隐藏层之间的q11为例)在每次迭代训练后是怎么更新的。先看隐藏层和输出层之间的w11。

1,隐藏层和输出层之间的w11的更新

kaldi中CD-DNN-HMM网络参数更新公式手写推导

先分别求三个导数的值:

kaldi中CD-DNN-HMM网络参数更新公式手写推导

kaldi中CD-DNN-HMM网络参数更新公式手写推导

所以最终的w11更新公式如下图:

kaldi中CD-DNN-HMM网络参数更新公式手写推导

2,输入层和隐藏层之间的q11的更新

kaldi中CD-DNN-HMM网络参数更新公式手写推导

先分别求三个导数的值:

kaldi中CD-DNN-HMM网络参数更新公式手写推导

kaldi中CD-DNN-HMM网络参数更新公式手写推导

所以最终的q11更新公式如下图:

kaldi中CD-DNN-HMM网络参数更新公式手写推导

以上的公式推导中如有错误,烦请指出,非常感谢!