这周在看循环数据网络, 发现一个博客, 里面推导极其详细, 借此记录重点.
详细推导
强烈建议手推一遍, 虽然会花一点时间, 但便于理清思路.
长短时记忆网络
回顾BPTT算法里误差项沿时间反向传播的公式:
δTk=δTt∏i=kt−1diag[f′(neti)]W(1)
根据范数的性质, 来获取
δTk
的模的上界:
‖δTk‖⩽⩽‖δTt‖∏i=kt−1‖diag[f′(neti)]‖‖W‖‖δTt‖(βfβW)t−k(2)(3)
可以看到, 误差项
δ
从t时刻传递到k时刻, 其值上界是
βfβw
的指数函数.
βfβw
分别是对角矩阵
diag[f′(neti)]
和矩阵W模的上界. 显然, 当t-k很大时, 会有
梯度爆炸, 当t-k很小时, 会有
梯度消失.
为了解决RNN的梯度爆炸和梯度消失的问题, 就出现了长短时记忆网络(Long Short Memory Network, LSTM). 原始RNN的隐藏层只有一个状态h, 它对于短期的输入非常敏感. 如果再增加一个状态c, 让它来保存长期的状态, 那么就可以解决原始RNN无法处理长距离依赖的问题.
新增加的状态c, 称为单元状态(cell state). 上图按照时间维度展开:
上图中, 在t时刻, LSTM的输入有三个: 当前时刻网络的输入值
xt
, 上一时刻LSTM的输出值
ht−1
, 以及上一时刻的单元状态
ct−1
; LSTM的输出有两个: 当前时刻的LSTM输出
ht
, 当前时刻的状态
ct
. 其中
x,h,c
都是向量.
LSTM的关键在于怎样控制长期状态c. 在这里, LSTM的思路是使用三个控制开关:
第一个开关, 负责控制继续保存长期状态c; (遗忘门)
第二个开关, 负责控制把即时状态输入到长期状态c; (输入门)
第三个开关, 负责控制是都把长期状态c作为当前的LSTM的输出. (输出门)
接下来, 具体描述一下输出h和单元状态c的计算方法.
长短时记忆网络的前向计算
开关在算法中用门(gate)实现. 门实际上就是一层全连接层, 它的输入是一个向量, 输出是一个0~1的实数向量. 假设w是门的权重向量, b是偏置项, 门可以表示为:
g(x)=σ(Wx+b)
门的使用, 就是
用门的输出向量按元素乘以我们需要控制的那个向量. 当门的输出为0时, 任何向量与之相乘都会得到0向量, 相当于什么都不能通过; 当输出为1时, 任何向量与之相乘都为本身, 相当于什么都可以通过. 上式中
σ
是sigmoid函数, 值域为(0,1), 所以门的状态是半开半闭的.
LSTM用两个门来控制单元状态c的内容, 一个是遗忘门(forget gate), 它决定了上一时刻的单元状态
ct−1
有多少保留到当前时刻
ct
; 另一个是输入门(input gate), 它决定了当前时刻网络的输入
xt
有多少保存到单元状态
ct
. LSTM用输出门(output gate)来控制单元状态
ct
有多少输出到LSTM的当前输出值
ht
.
1. 遗忘门:
ft=σ(Wf⋅[ht−1,xt]+bf)(式1)
上式中,
Wf
是遗忘门的权重矩阵,
[ht−1,xt]
表示把两个向量连接到一个更长的向量,
bf
是遗忘门的偏置项,
σ
是sigmoid函数. 如果输入的维度是
dh
, 单元状态的维度是
dc
(通常
dc=dh
), 则遗忘门的权重矩阵
Wf
维度是
dc×(dh+dx)
.
事实上, 权重矩阵
Wf
都是两个矩阵拼接而成的: 一个是
Wfh
, 它对应着输入项
ht−1
, 其维度为
dc×dh
; 一个是
Wfx
, 它对应着输入项
xt
, 其维度为
dc×dh
.
Wf
可以写成:
[Wf][ht−1xt]=[WfhWfx][ht−1xt]=Wfhht−1+Wfxxt(4)(5)
下图是遗忘门的计算:
2. 输入门:
it=σ(Wi⋅[ht−1,xt]+bi)(式2)
上式中,
Wi
是输入门的权重矩阵,
bi
是输入门的偏置项.
下图是输入门的计算:
接下来, 计算用于描述当前输入的单元状态
c̃t
, 它是根据根据上一次的输出和本次的输入来计算的:
c̃t=tanh(Wc⋅[ht−1,xt]+bc)(式3)
下图是
c̃t
的计算:
现在, 我们计算当前时刻的单元状态
ct
. 它是由上一次的单元状态
ct−1
按元素乘以遗忘门
ft
, 再用当前输入的单元状态
c̃t
按元素乘以输入门
it
, 再将两个积加和产生的:
ct=ft∘ct−1+it∘c̃t(式4)
符号
∘
表示
按元素乘. 下图是
ct
的计算:
这样, 就把LSTM关于当前的记忆
c̃t
和长期的记忆
ct−1
组合在一起, 形成了新的单元状态
ct
. 由于遗忘门的控制, 它可以保存很久之前的信息, 由于输入门的控制, 它又可以避免当前无关紧要的内容进入记忆.
3. 输出门
ot=σ(Wo⋅[ht−1,xt]+bo)(式5)
下图表示输出门的计算:
LSTM最终的输出, 是由输出门和单元状态共同确定的:
ht=ot∘tanh(ct)(式6)
下图表示LSTM最终输出的计算:
式1到式6就是LSTM前向计算的全部公式.
长短时记忆网络的训练
训练部分比前向计算部分复杂, 具体推导如下.
LSTM训练算法框架
LSTM的训练算法仍然是反向传播算法, 主要是三个步骤:
- 前向计算每个神经元的输出值, 对于LSTM来说, 即
ft,it,ctot,ht
五个向量的值;
- 反向计算每个神经元的误差项
δ
值, 与RNN一样, LSTM误差项的反向传播也是包括两个方向: 一个沿时间的反向传播, 即从当前t时刻开始, 计算每个时刻的误差项; 一个是将误差项向上一层传播;
- 根据相应的误差项, 计算每个权重的梯度.
关于公式和符号的说明
接下来的推导, 设定gate的激活函数为sigmoid, 输出的激活函数为tanh函数. 他们的导数分别为:
σ(z)σ′(z)tanh(z)tanh′(z)=y=11+e−z=y(1−y)=y=ez−e−zez+e−z=1−y2(6)(7)(8)(9)
从上式知, sigmoid函数和tanh函数的导数都是原函数的函数, 那么计算出原函数的值, 导数便也计算出来.
LSTM需要学习的参数共有8组, 权重矩阵的两部分在反向传播中使用不同的公式, 分别是:
- 遗忘门的权重矩阵
Wf
和偏置项
bt
,
Wf
分开为两个矩阵
Wfh
和
Wfx
- 输入门的权重矩阵
Wi
和偏置项
bi
,
Wi
分开为两个矩阵
Wih
和
Wxi
- 输出门的权重矩阵
Wo
和偏置项
bo
,
Wo
分开为两个矩阵
Woh
和
Wox
- 计算单元状态的权重矩阵
Wc
和偏置项
bc
,
Wc
分开为两个矩阵
Wch
和
Wcx
按元素乘
∘
符号. 当
∘
作用于两个向量时, 运算如下:
a∘b=⎡⎣⎢⎢⎢⎢⎢a1a2a3...an⎤⎦⎥⎥⎥⎥⎥∘⎡⎣⎢⎢⎢⎢⎢b1b2b3...bn⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢a1b1a2b2a3b3...anbn⎤⎦⎥⎥⎥⎥⎥
当
∘
作用于
一个向量和
一个矩阵时, 运算如下:
a∘X=⎡⎣⎢⎢⎢⎢⎢a1a2a3...an⎤⎦⎥⎥⎥⎥⎥∘⎡⎣⎢⎢⎢⎢⎢x11x21x31xn1x12x22x32xn2x13x23x33...xn3............x1nx2nx3nxnn⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢a1x11a2x21a3x31anxn1a1x12a2x22a3x32anxn2a1x13a2x23a3x33...anxn3............a1x1na2x2na3x3nanxnn⎤⎦⎥⎥⎥⎥⎥(10)(11)
当
∘
作用于
两个矩阵时, 两个矩阵对应位置的元素相乘. 按元素乘可以在某些情况下简化矩阵和向量运算.
例如, 当一个对角矩阵右乘一个矩阵时, 相当于用对角矩阵的对角线组成的向量按元素乘那个矩阵:
diag[a]X=a∘X
当一个行向量左乘一个对角矩阵时, 相当于这个行向量按元素乘那个矩阵对角组成的向量:
aTdiag[b]=a∘b
在t时刻, LSTM的输出值为
ht
. 我们定义t时刻的误差项
δt
为:
δt=def∂E∂ht
这里假设误差项是损失函数对输出值的导数, 而不是对加权输出
netlt
的导数. 因为LSTM有四个加权输入, 分别对应
ft,it,ct,ot
, 我们希望往上一层传递一个误差项而不是四个, 但需要定义这四个加权输入以及它们对应的误差项.
netf,tneti,tnetc̃,tneto,tδf,tδi,tδc̃,tδo,t=Wf[ht−1,xt]+bf=Wfhht−1+Wfxxt+bf=Wi[ht−1,xt]+bi=Wihht−1+Wixxt+bi=Wc[ht−1,xt]+bc=Wchht−1+Wcxxt+bc=Wo[ht−1,xt]+bo=Wohht−1+Woxxt+bo=def∂E∂netf,t=def∂E∂neti,t=def∂E∂netc̃,t=def∂E∂neto,t(12)(13)(14)(15)(16)(17)(18)(19)(20)(21)(22)(23)
误差项沿时间的反向传递
沿时间反向传递误差项, 就是要计算出t-1时刻的误差项
δt−1
.
δTt−1=∂E∂ht−1=∂E∂ht∂ht∂ht−1=δTt∂ht∂ht−1(24)(25)(26)
其中,
∂ht∂ht−1
是一个Jacobian矩阵, 为了求出它, 需要列出
ht
的计算公式, 即前面的
式6和
式4:
ht=ot∘tanh(ct)(式6)ct=ft∘ct−1+it∘c̃t(式4)
显然,
ot,ft,it,c̃t
都是
ht−1
的函数, 那么, 利用全导数公式可得:
δTt∂ht∂ht−1=δTt∂ht∂ot∂ot∂neto,t∂neto,t∂ht−1+δTt∂ht∂ct∂ct∂ft∂ft∂netf,t∂netf,t∂ht−1+δTt∂ht∂ct∂ct∂it∂it∂neti,t∂neti,t∂ht−1+δTt∂ht∂ct∂ct∂c̃t∂c̃t∂netc̃,t∂netc̃,t∂ht−1=δTo,t∂neto,t∂ht−1+δTf,t∂netf,t∂ht−1+δTi,t∂neti,t∂ht−1+δTc̃,t∂netc̃,t∂ht−1(式7)(27)(28)(29)
下面, 要把
式7中的每个偏导数都求出来, 根据
式6, 可以求出:
∂ht∂ot∂ht∂ct=diag[tanh(ct)]=diag[ot∘(1−tanh(ct)2)](30)(31)
根据
式4, 可以求出:
∂ct∂ft∂ct∂it∂ct∂c̃t=diag[ct−1]=diag[c̃t]=diag[it](32)(33)(34)
因为:
otneto,tftnetf,titneti,tc̃tnetc̃,t=σ(neto,t)=Wohht−1+Woxxt+bo=σ(netf,t)=Wfhht−1+Wfxxt+bf=σ(neti,t)=Wihht−1+Wixxt+bi=tanh(netc̃,t)=Wchht−1+Wcxxt+bc(35)(36)(37)(38)(39)(40)(41)(42)(43)(44)(45)
可以得出:
∂ot∂neto,t∂neto,t∂ht−1∂ft∂netf,t∂netf,t∂ht−1∂it∂neti,t∂neti,t∂ht−1∂c̃t∂netc̃,t∂netc̃,t∂ht−1=diag[ot∘(1−ot)]=Woh=diag[ft∘(1−ft)]=Wfh=diag[it∘(1−it)]=Wih=diag[1−c̃2t]=Wch(46)(47)(48)(49)(50)(51)(52)(53)
将上述偏导数导入到
式7, 可以得到:
δt−1=δTo,t∂neto,t∂ht−1+δTf,t∂netf,t∂ht−1+δTi,t∂neti,t∂ht−1+δTc̃,t∂netc̃,t∂ht−1=δTo,tWoh+δTf,tWfh+δTi,tWih+δTc̃,tWch(式8)(54)(55)
根据
δo,t,δf,t,δi,t,δc̃,t
的定义, 可知:
δTo,tδTf,tδTi,tδTc̃,t=δTt∘tanh(ct)∘ot∘(1−ot)(式9)=δTt∘ot∘(1−tanh(ct)2)∘ct−1∘ft∘(1−ft)(式10)=δTt∘ot∘(1−tanh(ct)2)∘c̃t∘it∘(1−it)(式11)=δTt∘ot∘(1−tanh(ct)2)∘it∘(1−c̃2)(式12)(56)(57)(58)(59)
式8到
式12就是将误差沿时间反向传播一个时刻的公式. 有了它, 便可以写出将误差项传递到任意k时刻的公式:
δTk=∏j=kt−1δTo,jWoh+δTf,jWfh+δTi,jWih+δTc̃,jWch(式13)
将误差项传递到上一层
假设当前是第
l
层, 定义
l−1
层的误差项是误差函数对
l−1
层加权输入的导数, 即:
δl−1t=def∂Enetl−1t
本次LSTM的输入
xt
由下面的公式计算:
xlt=fl−1(netl−1t)
上式中,
fl−1
表示第
l−1
的
激活函数.
因为
netlf,t,netli,t,netlc̃,t,netlo,t
都是
xt
的函数,
xt
又是
netl−1t
的函数, 因此, 要求出
E
对
netl−1t
的导数, 就需要使用全导数公式:
∂E∂netl−1t=∂E∂netlf,t∂netlf,t∂xlt∂xlt∂netl−1t+∂E∂netli,t∂netli,t∂xlt∂xlt∂netl−1t+∂E∂netlc̃,t∂netlc̃,t∂xlt∂xlt∂netl−1t+∂E∂netlo,t∂netlo,t∂xlt∂xlt∂netl−1t=δTf,tWfx∘f′(netl−1t)+δTi,tWix∘f′(netl−1t)+δTc̃,tWcx∘f′(netl−1t)+δTo,tWox∘f′(netl−1t)=(δTf,tWfx+δTi,tWix+δTc̃,tWcx+δTo,tWox)∘f′(netl−1t)(式14)(60)(61)(62)(63)
式14就是将误差传递到上一层的公式.
权重梯度的计算
对于
Wfh,Wih,Wch,Woh
的权重梯度, 我们知道它的梯度是各个时刻梯度之和. 我们首先求出它们在t时刻的梯度, 然后再求出他们最终的梯度.
我们已经求得了误差项
δo,t,δf,t,δi,t,δc̃,t
, 很容易求出t时刻的
Woh,Wfh,Wih,Wch
:
∂E∂Woh,t∂E∂Wfh,t∂E∂Wih,t∂E∂Wch,t=∂E∂neto,t∂neto,t∂Woh,t=δo,thTt−1=∂E∂netf,t∂netf,t∂Wfh,t=δf,thTt−1=∂E∂neti,t∂neti,t∂Wih,t=δi,thTt−1=∂E∂netc̃,t∂netc̃,t∂Wch,t=δc̃,thTt−1(64)(65)(66)(67)(68)(69)(70)(71)(72)(73)(74)
将各个时刻的梯度加在一起, 就能得到最终的梯度:
∂E∂Woh∂E∂Wfh∂E∂Wih∂E∂Wch=∑j=1tδo,jhTj−1=∑j=1tδf,jhTj−1=∑j=1tδi,jhTj−1=∑j=1tδc̃,jhTj−1(75)(76)(77)(78)
对于偏置项
bf,bi,bc,bo
的梯度, 先求出各个时刻的偏置项梯度:
∂E∂bo,t∂E∂bf,t∂E∂bi,t∂E∂bc,t=∂E∂neto,t∂neto,t∂bo,t=δo,t=∂E∂netf,t∂netf,t∂bf,t=δf,t=∂E∂neti,t∂neti,t∂bi,t=δi,t=∂E∂netc̃,t∂netc̃,t∂bc,t=δc̃,t(79)(80)(81)(82)(83)(84)(85)(86)(87)(88)(89)
将各个时刻的偏置项梯度加在一起:
∂E∂bo∂E∂bi∂E∂bf∂E∂bc=∑j=1tδo,j=∑j=1tδi,j=∑j=1tδf,j=∑j=1tδc̃,j(90)(91)(92)(93)
对于
Wfx,Wix,Wcx,Wox
的权重梯度, 只需要根据相应的误差项直接计算即可:
∂E∂Wox∂E∂Wfx∂E∂Wix∂E∂Wcx=∂E∂neto,t∂neto,t∂Wox=δo,txTt=∂E∂netf,t∂netf,t∂Wfx=δf,txTt=∂E∂neti,t∂neti,t∂Wix=δi,txTt=∂E∂netc̃,t∂netc̃,t∂Wcx=δc̃,txTt(94)(95)(96)(97)(98)(99)(100)(101)(102)(103)(104)
以上就是LSTM的训练算法的全部公式
GRU
上面所述是一种普通的LSTM, 事实上LSTM存在很多变体, GRU就是其中一种最成功的变体. 它对LSTM做了很多简化, 同时保持和LSTM相同的效果.
GRU对LSTM做了两大改动:
- 将输入门, 遗忘门, 输出门变为两个门: 更新门(Update Gate)
zt
和充值门(Reset Gate)
rt
.
- 将单元状态与输出合并为一个状态:
h
GRU的前向计算公式为:
ztrth̃th=σ(Wz⋅[ht−1,xt])=σ(Wr⋅[ht−1,xt])=tanh(W⋅[rt∘ht−1,xt])=(1−zt)∘ht−1+zt∘h̃t(105)(106)(107)(108)
下图是GRU的示意图: