目录
- RNN为何能记忆以及它面临的问题
- LSTM的网络结构
- LSTM的思想
- LSTM的详细网络结构
步骤1:存什么,丢什么(forget gate layer:忘记门)
步骤2:更新什么信息(输入门)
步骤3:开始更新信息
步骤4:当前时刻的输出 - 你可能会问关于LSTM的问题
- LSTM的变种(GRU)
- 总结
1.RNN为何能记忆以及它面临的问题
当你在看一部电影的时候,你使用传统的神经网络无法预测在下一时刻,这部电影的内容将是什么?循环神经网络解决了这个问题:因为循环神经网络能够记忆。
在上面这个例子中,像A这个的cell有一堆,然后他们连起来,当有一个输入信息
上一篇文章中(没有看过的点我)我们用RNN解决了序列型信息的之间的依赖问题,如图:
我们可以用之前出现的句子,比如根据“我是中国“这几个词预测下一个词是什么。
但有一个非常大的问题是:信息之间存在长期依赖问题。
比如有这么一段话:我出生在中国,但从小就去美国了,从小就接受美国式的教育,在美国有很多非常好的大学,·········,尽管我上的是美国最好的大学,但我不太会说普通话。你会发现这段信息很长,但真正重要的是第一句。
随着时间间隔的增大,RNN会丧失连接到很远的信息的能力,也就是当你上大四的要考研的时候,你发现你大一学的内容很多都不记得了,那该怎么办呢?
2.LSTM的网络结构
很高兴LSTM解决了这个问题。
LSTM和RNN的网络结构相差无几,区别是LSTM的cell被改造过了。它是被设计来解决长期依赖问题的,它有一种能力:把该记的东西记下来,不该记的东西忘掉。就像你准备考研,或者准备高考的时候一样,重点学那些老师说重点复习,可能会考的内容,而那些大家觉得不怎么重要的内容就直接不复习了。
在标准的RNN结构中,它有很多重复的cell,而且通过一个简单的tanh来**,如图:
与标准RNN不同的是,LSTM的cell是被改造过的,可以发现RNN只有一层网络,但LSTM有四层:
我们来一层一层看一下他们具体干了啥事,在此之前,先来定义一下我们的符号,其中Pointwise Operation表示逐点运算。
3.LSTM的思想
LSTM的本质是多了一条传送带之类的东西,我比较喜欢称呼它传送带,所以下文都以传送带称呼;还有细胞状态(cell state)指的是一个cell的变化。传送带会跟每个网络交互,可以存储信息,也可以删除信息,在传送带上记录着所有有用的信息,删除没用的信息,
它会通过门来选择是否让信息通过,在神经网络中,什么过滤信息的效果最好呀?当然是**函数呀,这里选用了sigmoid函数。
为什么选用sigmoid函数呢?因为sigmoid函数的函数值是0-1的,这就是一个概率嘛,它能衡量到底让信息过去多少。如果是0,就什么也不让它过去,如果是1,就把全部信息放过去。
让我们现在一步一步来了解它每一层都做了什么吧。
4.LSTM的详细网络结构
步骤1:存什么,丢什么(forget gate layer:忘记门)
首先应该考虑的是什么信息该存起来,什么信息不该存起来。起这个决策作用的是forget gate layer,它是一个sigmoid层,如果是1,意味着存下所有的信息,如果是0,意味着丢弃所有的信息。其中
比如上一时刻是“我是中国人“,但当前时刻是“我是中国共青团员“,我们就希望把‘人‘忘记,填上‘共青团员‘。也就是在传送带上用新的信息代替旧的信息。
步骤2:更新什么信息(输入门)
下一步就是看一下什么样的新信息应该存到传送带。这一步分为两部分,第一部分就是sigmoid层(input gate layer,不知道怎么翻译的好),它决定我们哪部分的信息应该更新。第二部分就是tanh层,它为每个能放到传送带上值(候选值)创建一个向量
步骤3:开始更新信息
上一步我们已经知道该存什么,该丢弃什么信息了,这一步就是直接做就OK了。
我们在步骤1的时候决定了要忘记存什么,忘记什么,得到的概率向量
举个小栗子来总结一下前三步:
假设你在大四的时候要去考研,我们肯定都有自己的方向是吧,你只要学习你考的内容就OK了,这个过程就是一个筛选的过程,把不考的丢掉,把考的留下,对应这里就是通过概率向量
步骤4:当前时刻的输出
最后,我们要决定当前时刻应该输出什么内容了。
这个输出是基于我们的传送带的,而且是过滤之后的。首先我们通过一个sigmoid层来决定传送带上的哪一部分信息应该被输出,得到决定当前时刻输出的概率向量
举个小栗子:还是考研的栗子,假设哥们现在已经都准备好,终于要考研了,在考数学的时候,发现了一道求解积分的题目,那我肯定得把以前关于积分的知识拿出来解题啦,那其他的在解这道题的时候就没什么用,可以暂时不用,对应这张图中
5.你可能会问关于LSTM的问题
问题1:有人可能会问:为什么sigmoid不能用relu替换,非要用sigmoid函数呢?
这个问题满分,我觉得最好的解释就是:LSTM使用sigmoid的本质作用是生成过滤信息的概率向量,概率都要在0-1的区间范围内,而relu不能生成概率。
问题2:最后输出结果的时候,tanh能不能换成其他的**函数?
这个问题不好说,因为根据实验结果没有一个非常确定的答案,而且LSTM的变种特别多,这里是标准的LSTM,它用tanh的主要作用是让每次输出的结果都维持在一个固定的范围内,不至于出现逐渐变大到最后膨胀爆炸的现象。
问题3:为什么LSTM能解决长时依赖的问题呀?
大家回想一下RNN的记忆是怎么来的呀?
那什么时候梯度会等于0呀?那就是链路非常非常长的时候呀,因为链路越长,哪怕其中有一个的偏导是0,那所有相乘都是0。
那LSTM是怎么解决的呢?
我看一下LSTM的记忆是怎么来的:
6.LSTM的变种
上面我们说的是一个标准版的LSTM,但并非所有的LSTM都是这个样子的,以下这个些都是比较常见的变种。
变种1
这个版本增加了一个叫peephole connections.的东西,其实最简单的理解就是在求解概率向量的时候,把记忆那一部分也加到了输入里。
变种2
之前是分开决定需要忘记哪部分,记住哪部分,现在他们是一起决定哪部分忘记,哪部分记住。它相当了忘记了多少,就把多少再补回来。比如你考研的时候知识总量为1,0.7是大三及以前学的,0.3是大四为了考研专门准备的。(例子可能不完全正确)
变种3:GRU(Gated Recurrent Unit)
更加著名的变种是GRU,它把忘记门和输入门都变成了更新门,同时融合了细胞状态和隐藏状态,使得最后的结果比标准的LSTM更为简单。
7.总结
虽然有那么多变种,到底哪个最好,其实没有定论,这也要看数据是什么样的。也有人说他们的效果都差不多,在此不深究,有兴趣可以看一下这篇文章 Greff, et al. (2015)。
这个人做了一万多个RNN的工程测试,发现在一些场合下,有部分模型是可以做的比LSTM更好的 Jozefowicz, et al. (2015)。
这篇文章从RNN面临的问题开始引出LSTM,它解决了长时依赖问题,然后痛殴介绍LSTM的详细网络结构,通过与RNN对比,全面的诠释了为什么LSTM能解决长时依赖问题,而RNN不能解决,在最后还顺便提了LSTM的三个变种,虽然没有非常全面的介绍,但其原理相差无几。
LSTM为我们研究RNN迈进了一大步,你可能会问:有没有更大的一步呢?答案是:有。那就是注意力模型,这部分我将会在下一篇文章介绍。