循环神经网络2--LSTM

时间:2021-12-03 20:30:07

这周在看循环数据网络, 发现一个博客, 里面推导极其详细, 借此记录重点.

详细推导

强烈建议手推一遍, 虽然会花一点时间, 但便于理清思路.

长短时记忆网络

回顾BPTT算法里误差项沿时间反向传播的公式:

(1) δ k T = δ t T i = k t 1 d i a g [ f ( n e t i ) ] W

根据范数的性质, 来获取 δ k T 的模的上界:
(2) δ k T δ t T i = k t 1 d i a g [ f ( n e t i ) ] W (3) δ t T ( β f β W ) t k

可以看到, 误差项 δ 从t时刻传递到k时刻, 其值上界是 β f β w 的指数函数. β f β w 分别是对角矩阵 d i a g [ f ( n e t i ) ] 和矩阵W模的上界. 显然, 当t-k很大时, 会有 梯度爆炸, 当t-k很小时, 会有 梯度消失.

为了解决RNN的梯度爆炸和梯度消失的问题, 就出现了长短时记忆网络(Long Short Memory Network, LSTM). 原始RNN的隐藏层只有一个状态h, 它对于短期的输入非常敏感. 如果再增加一个状态c, 让它来保存长期的状态, 那么就可以解决原始RNN无法处理长距离依赖的问题.

循环神经网络2--LSTM

新增加的状态c, 称为单元状态(cell state). 上图按照时间维度展开:

循环神经网络2--LSTM

上图中, 在t时刻, LSTM的输入有三个: 当前时刻网络的输入值 x t , 上一时刻LSTM的输出值 h t 1 , 以及上一时刻的单元状态 c t 1 ; LSTM的输出有两个: 当前时刻的LSTM输出 h t , 当前时刻的状态 c t . 其中 x , h , c 都是向量.

LSTM的关键在于怎样控制长期状态c. 在这里, LSTM的思路是使用三个控制开关:

第一个开关, 负责控制继续保存长期状态c; (遗忘门)

第二个开关, 负责控制把即时状态输入到长期状态c; (输入门)

第三个开关, 负责控制是都把长期状态c作为当前的LSTM的输出. (输出门)

循环神经网络2--LSTM

接下来, 具体描述一下输出h和单元状态c的计算方法.

长短时记忆网络的前向计算

开关在算法中用门(gate)实现. 门实际上就是一层全连接层, 它的输入是一个向量, 输出是一个0~1的实数向量. 假设w是门的权重向量, b是偏置项, 门可以表示为:

g ( x ) = σ ( W x + b )

门的使用, 就是 用门的输出向量按元素乘以我们需要控制的那个向量. 当门的输出为0时, 任何向量与之相乘都会得到0向量, 相当于什么都不能通过; 当输出为1时, 任何向量与之相乘都为本身, 相当于什么都可以通过. 上式中 σ 是sigmoid函数, 值域为(0,1), 所以门的状态是半开半闭的.

LSTM用两个门来控制单元状态c的内容, 一个是遗忘门(forget gate), 它决定了上一时刻的单元状态 c t 1 有多少保留到当前时刻 c t ; 另一个是输入门(input gate), 它决定了当前时刻网络的输入 x t 有多少保存到单元状态 c t . LSTM用输出门(output gate)来控制单元状态 c t 有多少输出到LSTM的当前输出值 h t .

1. 遗忘门:

f t = σ ( W f [ h t 1 , x t ] + b f ) ( 1 )

上式中, W f 是遗忘门的权重矩阵, [ h t 1 , x t ] 表示把两个向量连接到一个更长的向量, b f 是遗忘门的偏置项, σ 是sigmoid函数. 如果输入的维度是 d h , 单元状态的维度是 d c (通常 d c = d h ), 则遗忘门的权重矩阵 W f 维度是 d c × ( d h + d x ) .

事实上, 权重矩阵 W f 都是两个矩阵拼接而成的: 一个是 W f h , 它对应着输入项 h t 1 , 其维度为 d c × d h ; 一个是 W f x , 它对应着输入项 x t , 其维度为 d c × d h . W f 可以写成:

(4) [ W f ] [ h t 1 x t ] = [ W f h W f x ] [ h t 1 x t ] (5) = W f h h t 1 + W f x x t

下图是遗忘门的计算:

循环神经网络2--LSTM

2. 输入门:

i t = σ ( W i [ h t 1 , x t ] + b i ) ( 2 )

上式中, W i 是输入门的权重矩阵, b i 是输入门的偏置项.

下图是输入门的计算:

循环神经网络2--LSTM

接下来, 计算用于描述当前输入的单元状态 c ~ t , 它是根据根据上一次的输出和本次的输入来计算的:

c ~ t = tanh ( W c [ h t 1 , x t ] + b c ) ( 3 )

下图是 c ~ t 的计算:

循环神经网络2--LSTM

现在, 我们计算当前时刻的单元状态 c t . 它是由上一次的单元状态 c t 1 按元素乘以遗忘门 f t , 再用当前输入的单元状态 c ~ t 按元素乘以输入门 i t , 再将两个积加和产生的:

c t = f t c t 1 + i t c ~ t ( 4 )

符号 表示 按元素乘. 下图是 c t 的计算:

循环神经网络2--LSTM

这样, 就把LSTM关于当前的记忆 c ~ t 和长期的记忆 c t 1 组合在一起, 形成了新的单元状态 c t . 由于遗忘门的控制, 它可以保存很久之前的信息, 由于输入门的控制, 它又可以避免当前无关紧要的内容进入记忆.

3. 输出门

o t = σ ( W o [ h t 1 , x t ] + b o ) ( 5 )

下图表示输出门的计算:

循环神经网络2--LSTM

LSTM最终的输出, 是由输出门和单元状态共同确定的:

h t = o t tanh ( c t ) ( 6 )

下图表示LSTM最终输出的计算:

循环神经网络2--LSTM

式1式6就是LSTM前向计算的全部公式.

长短时记忆网络的训练

训练部分比前向计算部分复杂, 具体推导如下.

LSTM训练算法框架

LSTM的训练算法仍然是反向传播算法, 主要是三个步骤:

  1. 前向计算每个神经元的输出值, 对于LSTM来说, 即 f t , i t , c t o t , h t 五个向量的值;
  2. 反向计算每个神经元的误差项 δ 值, 与RNN一样, LSTM误差项的反向传播也是包括两个方向: 一个沿时间的反向传播, 即从当前t时刻开始, 计算每个时刻的误差项; 一个是将误差项向上一层传播;
  3. 根据相应的误差项, 计算每个权重的梯度.

关于公式和符号的说明

接下来的推导, 设定gate的激活函数为sigmoid, 输出的激活函数为tanh函数. 他们的导数分别为:

(6) σ ( z ) = y = 1 1 + e z (7) σ ( z ) = y ( 1 y ) (8) tanh ( z ) = y = e z e z e z + e z (9) tanh ( z ) = 1 y 2

从上式知, sigmoid函数和tanh函数的导数都是原函数的函数, 那么计算出原函数的值, 导数便也计算出来.

LSTM需要学习的参数共有8组, 权重矩阵的两部分在反向传播中使用不同的公式, 分别是:

  1. 遗忘门的权重矩阵 W f 和偏置项 b t , W f 分开为两个矩阵 W f h W f x
  2. 输入门的权重矩阵 W i 和偏置项 b i , W i 分开为两个矩阵 W i h W x i
  3. 输出门的权重矩阵 W o 和偏置项 b o , W o 分开为两个矩阵 W o h W o x
  4. 计算单元状态的权重矩阵 W c 和偏置项 b c , W c 分开为两个矩阵 W c h W c x

按元素乘 符号. 当 作用于两个向量时, 运算如下:

a b = [ a 1 a 2 a 3 . . . a n ] [ b 1 b 2 b 3 . . . b n ] = [ a 1 b 1 a 2 b 2 a 3 b 3 . . . a n b n ]

作用于 一个向量一个矩阵时, 运算如下:
(10) a X = [ a 1 a 2 a 3 . . . a n ] [ x 11 x 12 x 13 . . . x 1 n x 21 x 22 x 23 . . . x 2 n x 31 x 32 x 33 . . . x 3 n . . . x n 1 x n 2 x n 3 . . . x n n ] (11) = [ a 1 x 11 a 1 x 12 a 1 x 13 . . . a 1 x 1 n a 2 x 21 a 2 x 22 a 2 x 23 . . . a 2 x 2 n a 3 x 31 a 3 x 32 a 3 x 33 . . . a 3 x 3 n . . . a n x n 1 a n x n 2 a n x n 3 . . . a n x n n ]

作用于 两个矩阵时, 两个矩阵对应位置的元素相乘. 按元素乘可以在某些情况下简化矩阵和向量运算.

例如, 当一个对角矩阵右乘一个矩阵时, 相当于用对角矩阵的对角线组成的向量按元素乘那个矩阵:

d i a g [ a ] X = a X

当一个行向量左乘一个对角矩阵时, 相当于这个行向量按元素乘那个矩阵对角组成的向量:
a T d i a g [ b ] = a b

在t时刻, LSTM的输出值为 h t . 我们定义t时刻的误差项 δ t 为:
δ t = d e f E h t

这里假设误差项是损失函数对输出值的导数, 而不是对加权输出 n e t t l 的导数. 因为LSTM有四个加权输入, 分别对应 f t , i t , c t , o t , 我们希望往上一层传递一个误差项而不是四个, 但需要定义这四个加权输入以及它们对应的误差项.
(12) n e t f , t = W f [ h t 1 , x t ] + b f (13) = W f h h t 1 + W f x x t + b f (14) n e t i , t = W i [ h t 1 , x t ] + b i (15) = W i h h t 1 + W i x x t + b i (16) n e t c ~ , t = W c [ h t 1 , x t ] + b c (17) = W c h h t 1 + W c x x t + b c (18) n e t o , t = W o [ h t 1 , x t ] + b o (19) = W o h h t 1 + W o x x t + b o (20) δ f , t = d e f E n e t f , t (21) δ i , t = d e f E n e t i , t (22) δ c ~ , t = d e f E n e t c ~ , t (23) δ o , t = d e f E n e t o , t

误差项沿时间的反向传递

沿时间反向传递误差项, 就是要计算出t-1时刻的误差项 δ t 1 .

(24) δ t 1 T = E h t 1 (25) = E h t h t h t 1 (26) = δ t T h t h t 1

其中, h t h t 1 是一个Jacobian矩阵, 为了求出它, 需要列出 h t 的计算公式, 即前面的 式6式4:
h t = o t tanh ( c t ) ( 6 ) c t = f t c t 1 + i t c ~ t ( 4 )

显然, o t , f t , i t , c ~ t 都是 h t 1 的函数, 那么, 利用全导数公式可得:
(27) δ t T h t h t 1 = δ t T h t o t o t n e t o , t n e t o , t h t 1 + δ t T h t c t c t f t f t n e t f , t n e t f , t h t 1 (28) + δ t T h t c t c t i t i t n e t i , t n e t i , t h t 1 + δ t T h t c t c t c ~ t c ~ t n e t c ~ , t n e t c ~ , t h t 1 (29) = δ o , t T n e t o , t h t 1 + δ f , t T n e t f , t h t 1 + δ i , t T n e t i , t h t 1 + δ c ~ , t T n e t c ~ , t h t 1 ( 7 )

下面, 要把 式7中的每个偏导数都求出来, 根据 式6, 可以求出:
(30) h t o t = d i a g [ tanh ( c t ) ] (31) h t c t = d i a g [ o t ( 1 tanh ( c t ) 2 ) ]

根据 式4, 可以求出:
(32) c t f t = d i a g [ c t 1 ] (33) c t i t = d i a g [ c ~ t ] (34) c t c ~ t = d i a g [ i t ]

因为:
(35) o t = σ ( n e t o , t ) (36) n e t o , t = W o h h t 1 + W o x x t + b o (37) (38) f t = σ ( n e t f , t ) (39) n e t f , t = W f h h t 1 + W f x x t + b f (40) (41) i t = σ ( n e t i , t ) (42) n e t i , t = W i h h t 1 + W i x x t + b i (43) (44) c ~ t = tanh ( n e t c ~ , t ) (45) n e t c ~ , t = W c h h t 1 + W c x x t + b c

可以得出:
(46) o t n e t o , t = d i a g [ o t ( 1 o t ) ] (47) n e t o , t h t 1 = W o h (48) f t n e t f , t = d i a g [ f t ( 1 f t ) ] (49) n e t f , t h t 1 = W f h (50) i t n e t i , t = d i a g [ i t ( 1 i t ) ] (51) n e t i , t h t 1 = W i h (52) c ~ t n e t c ~ , t = d i a g [ 1 c ~ t 2 ] (53) n e t c ~ , t h t 1 = W c h

将上述偏导数导入到 式7, 可以得到:
(54) δ t 1 = δ o , t T n e t o , t h t 1 + δ f , t T n e t f , t h t 1 + δ i , t T n e t i , t h t 1 + δ c ~ , t T n e t c ~ , t h t 1 (55) = δ o , t T W o h + δ f , t T W f h + δ i , t T W i h + δ c ~ , t T W c h ( 8 )

根据 δ o , t , δ f , t , δ i , t , δ c ~ , t 的定义, 可知:
(56) δ o , t T = δ t T tanh ( c t ) o t ( 1 o t ) ( 9 ) (57) δ f , t T = δ t T o t ( 1 tanh ( c t ) 2 ) c t 1 f t ( 1 f t ) ( 10 ) (58) δ i , t T = δ t T o t ( 1 tanh ( c t ) 2 ) c ~ t i t ( 1 i t ) ( 11 ) (59) δ c ~ , t T = δ t T o t ( 1 tanh ( c t ) 2 ) i t ( 1 c ~ 2 ) ( 12 )

式8式12就是将误差沿时间反向传播一个时刻的公式. 有了它, 便可以写出将误差项传递到任意k时刻的公式:
δ k T = j = k t 1 δ o , j T W o h + δ f , j T W f h + δ i , j T W i h + δ c ~ , j T W c h ( 13 )

将误差项传递到上一层

假设当前是第 l 层, 定义 l 1 层的误差项是误差函数对 l 1 加权输入的导数, 即:

δ t l 1 = d e f E n e t t l 1

本次LSTM的输入 x t 由下面的公式计算:
x t l = f l 1 ( n e t t l 1 )

上式中, f l 1 表示第 l 1 激活函数.

因为 n e t f , t l , n e t i , t l , n e t c ~ , t l , n e t o , t l 都是 x t 的函数, x t 又是 n e t t l 1 的函数, 因此, 要求出 E n e t t l 1 的导数, 就需要使用全导数公式:

(60) E n e t t l 1 = E n e t f , t l n e t f , t l x t l x t l n e t t l 1 + E n e t i , t l n e t i , t l x t l x t l n e t t l 1 (61) + E n e t c ~ , t l n e t c ~ , t l x t l x t l n e t t l 1 + E n e t o , t l n e t o , t l x t l x t l n e t t l 1 (62) = δ f , t T W f x f ( n e t t l 1 ) + δ i , t T W i x f ( n e t t l 1 ) + δ c ~ , t T W c x f ( n e t t l 1 ) + δ o , t T W o x f ( n e t t l 1 ) (63) = ( δ f , t T W f x + δ i , t T W i x + δ c ~ , t T W c x + δ o , t T W o x ) f ( n e t t l 1 ) ( 14 )

式14就是将误差传递到上一层的公式.

权重梯度的计算

对于 W f h , W i h , W c h , W o h 的权重梯度, 我们知道它的梯度是各个时刻梯度之和. 我们首先求出它们在t时刻的梯度, 然后再求出他们最终的梯度.

我们已经求得了误差项 δ o , t , δ f , t , δ i , t , δ c ~ , t , 很容易求出t时刻的 W o h , W f h , W i h , W c h :

(64) E W o h , t = E n e t o , t n e t o , t W o h , t (65) = δ o , t h t 1 T (66) (67) E W f h , t = E n e t f , t n e t f , t W f h , t (68) = δ f , t h t 1 T (69) (70) E W i h , t = E n e t i , t n e t i , t W i h , t (71) = δ i , t h t 1 T (72) (73) E W c h , t = E n e t c ~ , t n e t c ~ , t W c h , t (74) = δ c ~ , t h t 1 T

将各个时刻的梯度加在一起, 就能得到最终的梯度:

(75) E W o h = j = 1 t δ o , j h j 1 T (76) E W f h = j = 1 t δ f , j h j 1 T (77) E W i h = j = 1 t δ i , j h j 1 T (78) E W c h = j = 1 t δ c ~ , j h j 1 T

对于偏置项 b f , b i , b c , b o 的梯度, 先求出各个时刻的偏置项梯度:
(79) E b o , t = E n e t o , t n e t o , t b o , t (80) = δ o , t (81) (82) E b f , t = E n e t f , t n e t f , t b f , t (83) = δ f , t (84) (85) E b i , t = E n e t i , t n e t i , t b i , t (86) = δ i , t (87) (88) E b c , t = E n e t c ~ , t n e t c ~ , t b c , t (89) = δ c ~ , t

将各个时刻的偏置项梯度加在一起:
(90) E b o = j = 1 t δ o , j (91) E b i = j = 1 t δ i , j (92) E b f = j = 1 t δ f , j (93) E b c = j = 1 t δ c ~ , j

对于 W f x , W i x , W c x , W o x 的权重梯度, 只需要根据相应的误差项直接计算即可:
(94) E W o x = E n e t o , t n e t o , t W o x (95) = δ o , t x t T (96) (97) E W f x = E n e t f , t n e t f , t W f x (98) = δ f , t x t T (99) (100) E W i x = E n e t i , t n e t i , t W i x (101) = δ i , t x t T (102) (103) E W c x = E n e t c ~ , t n e t c ~ , t W c x (104) = δ c ~ , t x t T

以上就是LSTM的训练算法的全部公式

GRU

上面所述是一种普通的LSTM, 事实上LSTM存在很多变体, GRU就是其中一种最成功的变体. 它对LSTM做了很多简化, 同时保持和LSTM相同的效果.

GRU对LSTM做了两大改动:

  1. 将输入门, 遗忘门, 输出门变为两个门: 更新门(Update Gate) z t 和充值门(Reset Gate) r t .
  2. 将单元状态与输出合并为一个状态: h

GRU的前向计算公式为:

(105) z t = σ ( W z [ h t 1 , x t ] ) (106) r t = σ ( W r [ h t 1 , x t ] ) (107) h ~ t = tanh ( W [ r t h t 1 , x t ] ) (108) h = ( 1 z t ) h t 1 + z t h ~ t

下图是GRU的示意图:

循环神经网络2--LSTM