3.RNN推导

时间:2023-02-02 23:43:39

1.基本RNN结构

 这几天想入门NLP,所以开始了解RNN以及一系列变体。首先RNN最原始的结构如下图(图是按自己的理解用visio画的,有错麻烦提一下),

  

3.RNN推导

首先我们来说明一下各个符号的定义:

3.RNN推导

各个变量之间的关系如下:

3.RNN推导

2.RNN推导

其实,在RNN中,框架并不大,整体的参数只有W和b,因为这些参数是共用的。下面说一下RNN的loss函数,它的loss是随你的需求变化,比如你的RNN是多对多,那么总的loss就是所有输出的loss之和,如果RNN的是多对一,那么总loss就是最后一个输出的loss。下面就具体说一说RNN的反向传播,这里叫BPTT (Back propagation through time).

  • loss计算

    其中

      3.RNN推导,为了便于推导,假设我们的训练batch_size=1,

        3.RNN推导

        3.RNN推导

  • 误差推导

    为了方便推导,我把变量之间的关系在这里再写一遍:

3.RNN推导

     具体推导如下:

      有一个小提示就是,注意这里求导的时候,由于a<t>与a<t+1>和y_hat<t>均有关系,所以链式求导的时候要算这两个部分。

3.RNN推导

    然后我们把式子通过向量化简化一下:

3.RNN推导

  • 梯度计算

    当我我们把误差算出来,那么各个参数的梯度就很简单了~~~

3.RNN推导

    剩下的就是通过迭代更新了~~其实整个推导也不是很难~~~只要把几个量的关系理清楚就可以了~~