1.基本RNN结构
这几天想入门NLP,所以开始了解RNN以及一系列变体。首先RNN最原始的结构如下图(图是按自己的理解用visio画的,有错麻烦提一下),
首先我们来说明一下各个符号的定义:
各个变量之间的关系如下:
2.RNN推导
其实,在RNN中,框架并不大,整体的参数只有W和b,因为这些参数是共用的。下面说一下RNN的loss函数,它的loss是随你的需求变化,比如你的RNN是多对多,那么总的loss就是所有输出的loss之和,如果RNN的是多对一,那么总loss就是最后一个输出的loss。下面就具体说一说RNN的反向传播,这里叫BPTT (Back propagation through time).
- loss计算
其中
,为了便于推导,假设我们的训练batch_size=1,
- 误差推导
为了方便推导,我把变量之间的关系在这里再写一遍:
具体推导如下:
有一个小提示就是,注意这里求导的时候,由于a<t>与a<t+1>和y_hat<t>均有关系,所以链式求导的时候要算这两个部分。
然后我们把式子通过向量化简化一下:
- 梯度计算
当我我们把误差算出来,那么各个参数的梯度就很简单了~~~
剩下的就是通过迭代更新了~~其实整个推导也不是很难~~~只要把几个量的关系理清楚就可以了~~