Keras深度学习实战(29)——长短时记忆网络详解与实现
0. 前言
长短时记忆网络 (Long Short Term Memory
, LSTM
),顾名思义是具有记忆长短期信息能力的神经网络,解决了循环神经网络 (Recurrent neural networks
, RNN
) 梯度爆炸/消失的问题,是建立在循环神经网络上的一种新型深度学习的时间序列模型,它具有高度的学习能力与模拟能力,具有记忆可持续性的特点,且能预测未来的任意步长。本文首先介绍了 RNN
模型的局限性,从而引入介绍长短时记忆网络 (Long Short Term Memory
, LSTM
) 的基本原理,最后通过实现 LSTM
进行深入了解。
1. RNN 的局限性
我们首先可视化 RNN
在考虑多个时刻做出预测时的情况,如下所示,随着时间的增加,早期输入的影响会逐渐降低:
更具体的,我们也可以通过公式得到相同的结论,例如我们需要计算第 5
个时刻网络的中间状态:
h 5 = W X 5 + U h 4 = W X 5 + U W X 4 + U 2 W X 3 + U 3 W X 2 + U 4 W X 1 h_5 = WX_5 + Uh_4 = WX_5 + UWX_4 + U_2WX_3 + U_3WX_2 + U_4WX_1 h5=WX5+Uh4=WX5+UWX4+U2WX3+U3WX2+U4WX1
可以看到,随着时间的增加,如果
U
>
1
U>1
U>1,则网络中间状态的值高度依赖于
X
1
X_1
X1;而如果
U
<
1
U<1
U<1,则网络中间状态值对
X
1
X_1
X1 的依赖就少得多。对 U
矩阵的依赖性还可能在 U
值很小时导致梯度消失,而在 U
值很高时会导致梯度爆炸。
当在预测单词时存在长期依赖性时,RNN
的这种现象将导致无法学习长期依赖关系的问题。为了解决这个问题,我们将引入介绍长短期记忆 (Long Short Term Memory
, LSTM
) 体系结构。
2. LSTM 模型架构详解
在有关传统 RNN
的问题中,我们了解了 RNN
对于长期依赖问题无济于事。例如,假设输入句子如下:
I live in China. I speak ____.
可以通过关键字 China
来推测以上空中应填充的单词,但该关键字与我们要预测的单词距离 3
个时间戳。如果关键字远离要预测的单词,则需要解决消失/爆炸梯度问题。
2.1 LSTM 架构
在本节中,我们将学习 LSTM
如何帮助克服RNN体系结构的长期依赖缺点,并构建一个简单示例,以便了解 LSTM
的各个组成部分。LSTM
架构示意图如下所示:
可以看到,虽然每个时刻 (h
) 的输入 X
和输出保持不变,但是在网络中使用不同的计算方式和激活函数。
2.2 LSTM 各组成部分与计算流程
接下来,我们详细介绍在一个时间戳内的计算过程:
在上图中,
x
x
x 和
h
h
h 表示输入层和 LSTM
的输出向量,内部状态向量 Memory
存储在单元状态
c
c
c 中也就是说,相较于基础 RNN
而言,LSTM
将内部状态向量 Memory
和输出分开为两个变量,利用输入门 (Input Gate
)、遗忘门 (Forget Gate
)和输出门 (Output Gate
) 三个门控来控制内部信息的流动。门控机制是一种控制网络中数据流通量的手段,可以较好地控制数据流通的流量程度。
2.2.1 遗忘门
需要忘记的内容是通过“遗忘门
”获得的,用于控制上一个时间戳的记忆
c
t
−
1
c_{t-1}
ct−1 对当前时间戳的影响,遗忘门的控制变量
f
t
f_t
ft 由:
f t = σ ( W x f x ( t ) + W h f h ( t − 1 ) + b f ) f_t=\sigma(W_{xf}x^{(t)}+W_{hf}h^{(t-1)}+b_f) ft=σ(Wxfx(t)+Whfh(t−1)+bf)
sigmoid
激活函数使网络能够选择性地识别需要忘记的内容。在确定需要忘记的内容后,更新后的单元状态如下:
c t = ( c ( t − 1 ) ⊗ f ) c_t=(c_{(t-1)}\otimes f) ct=(c(t−1)⊗f)
其中,
⊗
\otimes
⊗ 表示逐元素乘法。例如,如果句子的输入序列是 I live in China. I speak ___
,可以根据输入的单词 China
来填充空格,在之后,我们可能并不再需要有关国家名称的信息。我们根据当前时间戳需要忘记的内容来更新单元状态。
2.2.2 输入门
输入门用于控制 LSTM
对输入的接受程度,根据当前时间戳提供的输入将其他信息添加到单元状态中,通过 tanh
激活函数获得更新,因此也称为更新门。首先通过对当前时间戳的输入和上一时间戳的输出作非线性变换:
i t = σ ( W x i x ( t ) + W h i h ( t − 1 ) + b i ) i_t=\sigma(W_{xi}x^{(t)}+W_{hi}h^{(t-1)}+b_i) it=σ(Wxix(t)+Whih(t−1)+bi)
输入门中,输入更新计算方法如下:
g t = t a n h ( W x g x ( t ) + W h g h ( t − 1 ) + b g ) g_t=tanh(W_{xg}x^{(t)}+W_{hg}h^{(t-1)}+b_g) gt=tanh(Wxgx(t)+Whgh(t−1)+bg)
在当前时间戳中需要忘记某些信息,并在其中添加一些其他信息,此时单元状态将按以下方式更新:
c ( t ) = ( c ( t 1 − ) ⊙ f t ) ⊕ ( i t ⊙ g t ) c^{(t)}=(c^{(t1-)}\odot f_t)\oplus(i_t\odot g_t) c(t)=(c(t1−)⊙ft)⊕(it⊙gt)
得到的新的状态向量 c ( t ) c^{(t)} c(t) 即为当前时间戳的状态向量。
2.2.3 输入门
最后一个门称为输出门,我们需要指定输入组合和单元状态的哪一部分需要传递到下一个时刻,输入组合包括当前时间戳的输入和前一时间戳的输出值:
o t = σ ( W x o x ( t ) + W h o h ( t − 1 ) + b o ) o_t=\sigma(W_{xo}x^{(t)}+W_{ho}h^{(t-1)}+b_o) ot=σ(Wxox(t)+Whoh(t−1)+bo)
最终的网络状态值表示如下:
h ( t ) = o t ⊙ t a n h ( c ( t ) ) h^{(t)}=o_t\odot tanh(c^{(t)}) h(t)=ot⊙tanh(c(t))
这样,我们就可以利用 LSTM
中的各个门来有选择地识别需要存储在存储器中的信息,从而克服了 RNN
的局限性。
3. 从零开始实现 LSTM
在本小节中,我们通过使用一个简单示例来了解 LSTM
的工作原理。
3.1 LSTM 模型实现
(1) 对输入数据进行预处理,该示例所用输入数据与预处理过程与在 RNN 模型中使用的完全相同:
# 定义输入与输出数据
docs = ['this is','is an']
# define class labels
labels = ['an','example']
from collections import Counter
counts = Counter()
for i,review in enumerate(docs+labels):
counts.update(review.split())
words = sorted(counts, key=counts.get, reverse=True)
vocab_size=len(words)
word_to_int = {word: i for i, word in enumerate(words, 1)}
encoded_docs = []
for doc in docs:
encoded_docs.append([word_to_int[word] for word in doc.split()])
encoded_labels = []
for label in labels:
encoded_labels.append([word_to_int[word] for word in label.split()])
from keras.utils import to_categorical
from keras.preprocessing.sequence import pad_sequences
# 数据预处理
max_length = 2
padded_docs = pad_sequences(encoded_docs, maxlen=max_length, padding='pre')
print(padded_docs)
one_hot_encoded_labels = to_categorical(encoded_labels, num_classes=5)
print(one_hot_encoded_labels)
(2) 定义 LSTM
模型:
from keras.layers import Dense, LSTM
from keras.models import Sequential
embed_length=1
max_length=2
model = Sequential()
model.add(LSTM(1,activation='tanh',return_sequences=False,recurrent_activation='sigmoid',input_shape=(max_length,embed_length),unroll=True))
model.add(Dense(5, activation='softmax'))
# 模型编译
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
model.summary()
该模型的简要信息输入如下:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm (LSTM) (None, 1) 12
_________________________________________________________________
dense (Dense) (None, 5) 10
=================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________
LSTM
层中的参数数量为 12
,其中包括输入门、遗忘门和输出门三部分,产生了 4
个权重和 4
个偏置将输入连接网络单元中。此外,还包含对应的 4
个权重,因此,总共 12
个参数。全连接层共有 10
个参数,因为输出包含 5
个可能的类别,因此有 5
个权重和 5
个偏置将 LSTM
的输出连接到到输出层中。
(3) 拟合模型:
import numpy as np
model.fit(padded_docs.reshape(2,2,1),np.array(one_hot_encoded_labels),epochs=100)
# 输出模型权重信息
print(model.weights)
该模型的权重信息输出如下:
[<tf.Variable 'lstm/lstm_cell/kernel:0' shape=(1, 4) dtype=float32, numpy=
array([[ 0.92491925, -0.93431926, -0.67187965, -0.00756256]],
dtype=float32)>, <tf.Variable 'lstm/lstm_cell/recurrent_kernel:0' shape=(1, 4) dtype=float32, numpy=
array([[ 0.27872017, -0.48063734, 0.31194845, -0.72623277]],
dtype=float32)>, <tf.Variable 'lstm/lstm_cell/bias:0' shape=(4,) dtype=float32, numpy=array([ 0.11236392, 0.93647027, -0.10823309, 0.11666972], dtype=float32)>, <tf.Variable 'dense/kernel:0' shape=(1, 5) dtype=float32, numpy=
array([[-0.28399688, 0.40721267, 0.17018904, 0.58124113, -0.5605382 ]],
dtype=float32)>, <tf.Variable 'dense/bias:0' shape=(5,) dtype=float32, numpy=
array([-0.09865286, -0.09742296, 0.09871767, -0.09712653, 0.09683956],
dtype=float32)>]
可以看到,在 LSTM
中的权重顺序如下:
- 输入权重 (
kenel
) - 对应于单元状态的权重 (
recurrent_kernel
) -
LSTM
层中的偏差 (bias
)
(4) 计算对于输入的预测,整形输入形状以适应 predict
方法,使其与 LSTM
预期的输入数据格式相同,即(批大小, 时间戳长度, 每个时间戳的特征数):
print(model.predict(padded_docs[0].reshape(1,2,1)))
使用 predict
方法预测得到的输入如下:
[[0.20947632 0.15290551 0.20733334 0.14125276 0.28903207]]
3.2 验证输出
我们已经使用训练后的模型中获得了一个预测概率,接下来,使用 NumPy
提取网络权重计算网络的前向传播过程来验证模型输出,以此加深 LSTM
的计算流程的理解。验证所构建模型的前向传播过程获取输出的步骤如下。
(1) 在时间戳 1
中更新遗忘门,此步需要利用输入,然后计算到目前为止的单元状态(或称内存)中需要被遗忘的信息:
input_t0 = padded_docs[0][0]
cell_state0 = 0
forget0 = input_t0*model.get_weights()[0][0][1] + model.get_weights()[2][1]
forget1 = 1/(1+np.exp(-(forget0)))
(2) 基于计算后的遗忘门更新单元状态,上一步的输出在此步中用于指示要从单元状态中忘记的信息:
cell_state1 = forget1 * cell_state0
(3) 更新时间戳 1
中的输入门的值,此步骤根据当前输入估计要向单元状态注入多少新信息:
input_t0_1 = input_t0*model.get_weights()[0][0][0] + model.get_weights()[2][0]
input_t0_2 = 1/(1+np.exp(-(input_t0_1)))
(4) 根据更新的输入值更新单元状态,此步需要使用上一步的输出来指示单元状态将发生的信息更新量:
input_t0_cell1 = input_t0*model.get_weights()[0][0][2] +model.get_weights()[2][2]
input_t0_cell2 = np.tanh(input_t0_cell1)
tanh
激活函数有利于确定输入的更新是否会增加或减少单元状态,如果某些信息已经在当前时间戳中传递,并且在将来的时间戳中没有用处,则最好将这些多余信息从单元状态中删除:
input_t0_cell3 = input_t0_cell2*input_t0_2
input_t0_cell4 = input_t0_cell3 + cell_state1
(5) 更新输出门,此步计算当前时间戳将向下一时间戳传送多少信息:
output_t0_1 = input_t0*model.get_weights()[0][0][3] + model.get_weights()[2][3]
output_t0_2 = 1/(1+np.exp(-output_t0_1))
(6) 输入计算时间戳 1
的网络中间状态,最终的网络中间状态值是当前时间戳中传送的单元状态和s门输出量的组合:
hidden_layer_1 = np.tanh(input_t0_cell4)*output_t0_2
通过以上步骤,我们完成计算了在第 1
个时间戳后的输出,接下来,我们利用在时戳 1
中更新的单元状态值和第 1
个时间戳的输出作为第 2
个时间戳输入的一部分。
(7) 传递第 2
个时间戳的输入值和进入第 2
个时间戳的单元状态值:
input_t1 = padded_docs[0][1]
cell_state1 = input_t0_cell4
更新遗忘门值:
forget21 = hidden_layer_1*model.get_weights()[1][0][1] + model.get_weights()[2][1] + input_t1*model.get_weights()[0][0][1]
forget_22 = 1/(1+np.exp(-(forget21)))
更新第 2
个时间戳时单元状态值:
cell_state2 = cell_state1 * forget_22
input_t1_1 = input_t1*model.get_weights()[0][0][0] + model.get_weights()[2][0] + hidden_layer_1*model.get_weights()[1][0][0]
input_t1_2 = 1/(1+np.exp(-(input_t1_1)))
input_t1_cell1 = input_t1*model.get_weights()[0][0][2] + model.get_weights()[2][2]+ hidden_layer_1*model.get_weights()[1][0][2]
input_t1_cell2 = np.tanh(input_t1_cell1)
input_t1_cell3 = input_t1_cell2*input_t1_2
input_t1_cell4 = input_t1_cell3 + cell_state2
根据更新的单元状态和网络中间状态,更新当前时间戳的输出值:
output_t1_1 = input_t1*model.get_weights()[0][0][3] + model.get_weights()[2][3]+ hidden_layer_1*model.get_weights()[1][0][3]
output_t1_2 = 1/(1+np.exp(-output_t1_1))
hidden_layer_2 = np.tanh(input_t1_cell4)*output_t1_2
(8) 通过全连接层传递 LSTM
输出:
final_output = hidden_layer_2 * model.get_weights()[3][0] +model.get_weights()[4]
在以上输出上执行 softmax
函数:
print(np.exp(final_output)/np.sum(np.exp(final_output)))
# [0.2094763 0.15290551 0.20733333 0.14125276 0.28903207]
小结
长短时记忆网络 (Long Short Term Memory
, LSTM
) 解决了循环神经网络 (Recurrent neural networks
, RNN
) 梯度爆炸/消失的问题,是建立在循环神经网络基础上的一种新型深度学习的时间序列模型,在大多数任务中,使用 LSTM
比使用传统 RNN
模型能够取得更好的效果。本节中,首先介绍了传统 RNN
模型的局限性,然后详细介绍了 LSTM
的基本原理与各组成部分的计算细节,最后,使用 Keras
实现了 LSTM
模型用以加深对 LSTM
背后运行机制的理解。
系列链接
Keras深度学习实战(1)——神经网络基础与模型训练过程详解
Keras深度学习实战(2)——使用Keras构建神经网络
Keras深度学习实战(3)——神经网络性能优化技术
Keras深度学习实战(4)——深度学习中常用激活函数和损失函数详解
Keras深度学习实战(5)——批归一化详解
Keras深度学习实战(6)——深度学习过拟合问题及解决方法
Keras深度学习实战(7)——卷积神经网络详解与实现
Keras深度学习实战(8)——使用数据增强提高神经网络性能
Keras深度学习实战(9)——卷积神经网络的局限性
Keras深度学习实战(10)——迁移学习详解
Keras深度学习实战(11)——可视化神经网络中间层输出
Keras深度学习实战(12)——面部特征点检测
Keras深度学习实战(13)——目标检测基础详解
Keras深度学习实战(14)——从零开始实现R-CNN目标检测
Keras深度学习实战(15)——从零开始实现YOLO目标检测
Keras深度学习实战(16)——自编码器详解
Keras深度学习实战(17)——使用U-Net架构进行图像分割
Keras深度学习实战(18)——语义分割详解
Keras深度学习实战(19)——使用对抗攻击生成可欺骗神经网络的图像
Keras深度学习实战(20)——DeepDream模型详解
Keras深度学习实战(21)——神经风格迁移详解
Keras深度学习实战(22)——生成对抗网络详解与实现
Keras深度学习实战(23)——DCGAN详解与实现
Keras深度学习实战(24)——从零开始构建单词向量
Keras深度学习实战(25)——使用skip-gram和CBOW模型构建单词向量
Keras深度学习实战(26)——文档向量详解
Keras深度学习实战(27)——循环神经详解与实现
Keras深度学习实战(28)——利用单词向量构建情感分析模型