深度学习之seq2seq模型以及Attention机制

时间:2023-01-18 10:42:04

RNN,LSTM,seq2seq等模型广泛用于自然语言处理以及回归预测,本期详解seq2seq模型以及attention机制的原理以及在回归预测方向的运用。

1. seq2seq模型介绍

  seq2seq模型是以编码(Encode)和解码(Decode)为代表的架构方式,seq2seq模型是根据输入序列X来生成输出序列Y,在翻译,文本自动摘要和机器人自动问答以及一些回归预测任务上有着广泛的运用。以encode和decode为代表的seq2seq模型,encode意思是将输入序列转化成一个固定长度的向量,decode意思是将输入的固定长度向量解码成输出序列。其中编码解码的方式可以是RNN,CNN等。

  深度学习之seq2seq模型以及Attention机制

图1. encode和decode框架

上图为seq2seq的encode和decode结构,采用CNN/LSTM模型。在RNN中,当前时间的隐藏状态是由上一时间的状态和当前时间的输入x共同决定的,即

深度学习之seq2seq模型以及Attention机制

【编码阶段】

  得到各个隐藏层的输出然后汇总,生成语义向量

深度学习之seq2seq模型以及Attention机制

  也可以将最后的一层隐藏层的输出作为语义向量C                      深度学习之seq2seq模型以及Attention机制

【解码阶段】 

  这个阶段,我们要根据给定的语义向量C和输出序列y1,y2,…yt−1来预测下一个输出的单词yt,即

深度学习之seq2seq模型以及Attention机制

  也可以写做                                                                         深度学习之seq2seq模型以及Attention机制

  其中g()代表的是非线性激活函数。在RNN中可写成 yt=g(yt1,ht,C) ,其中h为隐藏层的输出。

以上就是seq2seq的编码解码阶段,seq2seq模型的抽象框架可描述为下图:

深度学习之seq2seq模型以及Attention机制

图2. seq2seq抽象框架图

2.Attention机制在seq2seq模型中的运用

2.1 自然语言处理中的Attention机制

  由于encoder-decoder模型在编码和解码阶段始终由一个不变的语义向量C来联系着,编码器要将整个序列的信息压缩进一个固定长度的向量中去。这就造成了 (1)语义向量无法完全表示整个序列的信息,(2)最开始输入的序列容易被后输入的序列给覆盖掉,会丢失许多细节信息。在长序列上表现的尤为明显。

  Attention模型的引入:

  相比于之前的encoder-decoder模型,attention模型最大的区别就在于它不在要求编码器将所有输入信息都编码进一个固定长度的向量之中。相反,此时编码器需要将输入编码成一个向量的序列,而在解码的时候,每一步都会选择性的从向量序列中挑选一个子集进行进一步处理。这样,在产生每一个输出的时候,都能够做到充分利用输入序列携带的信息。而且这种方法在翻译任务中取得了非常不错的成果。

  下图为seq2seq模型加入了Attention注意力机制

深度学习之seq2seq模型以及Attention机制

图3. Attention注意力机制的seq2seq模型

  

【seq2seq的attention解码过程】

  现在定义条件概率:深度学习之seq2seq模型以及Attention机制

  上式 s表示解码器 i 时刻的隐藏状态。计算公式为:

深度学习之seq2seq模型以及Attention机制

  注意这里的条件概率与每个目标输出  yi相对应的内容向量  ci有关。在sea2seq模型中,只有一个语义向量C。‘s’为隐藏层输出,相当于上面提到的h。

  关键问题是语义向量 C 怎么得到?  

  c是由编码时的隐藏向量序列(h1,…,hTx)按权重相加得到的。

深度学习之seq2seq模型以及Attention机制

  将隐藏向量序列按权重相加,表示在生成第j个输出的时候的注意力分配是不同的。αij的值越高,表示第i个输出在第j个输入上分配的注意力越多,在生成第i个输出的时候受第j个输入的影响也就越大。

  这意味着在生成每个单词Yi的时候,原先都是相同的中间语义表示C会替换成根据当前生成单词而不断变化的Ci。理解AM模型的关键就是这里,即由固定的中间语义表示C换成了根据当前输出单词来调整成加入注意力模型的变化的Ci

  如何得到 αij 的权重值?

  由第i-1个输出隐藏状态 si−1 和输入中各个隐藏状态共同决定的,即:

深度学习之seq2seq模型以及Attention机制

  si−1 先跟每个h分别计算得到一个数值,然后使用softmax函数得到i时刻的输出在Tx个输入隐藏状态中的注意力分配向量。这个分配向量也就是计算ci的权重。

深度学习之seq2seq模型以及Attention机制

图4.   分配概率(权值)的计算

  图4 显示的是Attention模型在计算αij 的概率分配过程。

对于采用RNN的Decoder来说,如果要生成yi单词,在时刻i,我们是可以知道在生成Yi之前的隐层节点i时刻的输出值Hi的,而我们的目的是要计算生成Yi时的输入句子单词“Tom”、“Chase”、“Jerry”对Yi来说的注意力分配概率分布,那么可以用i时刻的隐层节点状态Hi去一一和输入句子中每个单词对应的RNN隐层节点状态hj进行对比,即通过函数F(hj,Hi)来获得目标单词Yi和每个输入单词对应的对齐可能性,这个F函数在不同论文里可能会采取不同的方法,然后函数F的输出经过Softmax进行归一化就得到了符合概率分布取值区间的注意力分配概率分布数值。图4显示的是当输出单词为“汤姆”时刻对应的输入句子单词的对齐概率。绝大多数AM模型都是采取上述的计算框架来计算注意力分配概率分布信息,区别只是在F的定义上可能有所不同。

公式汇总:深度学习之seq2seq模型以及Attention机制

【Attention机制类别】

  Attention机制大的方向可分为 Soft Attention 和 Hard Attention 。

Soft Attention通常是指以上我们描述的这种全连接(如MLP计算Attention 权重),对每一层都可以计算梯度和后向传播的模型;不同于Soft attention那样每一步都对输入序列的所有隐藏层hj(j=1….Tx) 计算权重再加权平均的方法,Hard Attention是一种随机过程,每次以一定概率抽样,以一定概率选择某一个隐藏层 hj*,在估计梯度时也采用蒙特卡罗抽样Monte Carlo sampling的方法。

深度学习之seq2seq模型以及Attention机制

图5. Soft Attention 模型

深度学习之seq2seq模型以及Attention机制

图6. Hard Attention

考虑到计算量,attention的另一种替代方法是用强化学习(Reinforcement Learning)来预测关注点的大概位置。这听起来更像是人的注意力,这也是Recurrent Models of Visual Attention文中的作法。然而,强化学习模型不能用反向传播算法端到端训练,因此它在NLP的应用不是很广泛(我本人反而觉得这里有突破点,数学上的不可求解必然会得到优化,attention model在RL领域的应用确实非常有趣)

参考资料:http://blog.csdn.net/u014595019/article/details/52826423

     http://blog.csdn.net/wuzqChom/article/details/75792501

     http://blog.csdn.net/mpk_no1/article/details/72862348

     http://www.deepnlp.org/blog/textsum-seq2seq-attention/

     http://blog.csdn.net/malefactor/article/details/50550211

     http://blog.csdn.net/xbinworld/article/details/54607525

2.2 计算机视觉中的Attention机制

未完待续...

http://blog.csdn.net/sinat_33761963/article/details/53521206     

深度学习之seq2seq模型以及Attention机制的更多相关文章

  1. 深度学习的seq2seq模型——本质是LSTM,训练过程是使得所有样本的p(y1,...,yT‘|x1,...,xT)概率之和最大

    from:https://baijiahao.baidu.com/s?id=1584177164196579663&wfr=spider&for=pc seq2seq模型是以编码(En ...

  2. 时间序列深度学习:seq2seq 模型预测太阳黑子

    目录 时间序列深度学习:seq2seq 模型预测太阳黑子 学习路线 商业中的时间序列深度学习 商业中应用时间序列深度学习 深度学习时间序列预测:使用 keras 预测太阳黑子 递归神经网络 设置.预处 ...

  3. Seq2Seq模型 与 Attention 策略

    Seq2Seq模型 传统的机器翻译的方法往往是基于单词与短语的统计,以及复杂的语法结构来完成的.基于序列的方式,可以看成两步,分别是 Encoder 与 Decoder,Encoder 阶段就是将输入 ...

  4. 【转】[caffe]深度学习之图像分类模型AlexNet解读

    [caffe]深度学习之图像分类模型AlexNet解读 原文地址:http://blog.csdn.net/sunbaigui/article/details/39938097   本文章已收录于: ...

  5. [caffe]深度学习之图像分类模型VGG解读

    一.简单介绍 vgg和googlenet是2014年imagenet竞赛的双雄,这两类模型结构有一个共同特点是go deeper.跟googlenet不同的是.vgg继承了lenet以及alexnet ...

  6. 深度学习之 seq2seq 进行 英文到法文的翻译

    深度学习之 seq2seq 进行 英文到法文的翻译 import os import torch import random source_path = "data/small_vocab_ ...

  7. 深度学习 vs. 概率图模型 vs. 逻辑学

    深度学习 vs. 概率图模型 vs. 逻辑学 摘要:本文回顾过去50年人工智能(AI)领域形成的三大范式:逻辑学.概率方法和深度学习.文章按时间顺序展开,先回顾逻辑学和概率图方法,然后就人工智能和机器 ...

  8. Seq2Seq模型与注意力机制

    Seq2Seq模型 基本原理 核心思想:将一个作为输入的序列映射为一个作为输出的序列 编码输入 解码输出 解码第一步,解码器进入编码器的最终状态,生成第一个输出 以后解码器读入上一步的输出,生成当前步 ...

  9. java web应用调用python深度学习训练的模型

    之前参见了中国软件杯大赛,在大赛中用到了深度学习的相关算法,也训练了一些简单的模型.项目线上平台是用java编写的web应用程序,而深度学习使用的是python语言,这就涉及到了在java代码中调用p ...

随机推荐

  1. 韩国手机游戏Elf Defense角色场景

    ! [复制链接] CG窝微博 签到天数: 36 天 连续签到: 1 天 [LV.5]常住居民I 22 主题 0 精华 2729 窝币 超级版主 积分 2546 收听TA 发消息 电梯直达 楼主     ...

  2. 【分析】Parcelable的作用

    一.介绍 1.Parcelable是一个接口,可以实现序列化. 2.序列化的作用体现在:可以使用Intent来传递数据,也可以在进程建传递数据(IPC). 3.Parcelable在使用的时候,有一个 ...

  3. DotNetBar v12.9.0.0 Fully Cracked

    更新信息: http://www.devcomponents.com/customeronly/releasenotes.asp?p=dnbwf&v=12.9.0.0 如果遇到破解问题可以与我 ...

  4. windows下实现微秒级的延时

    windowsintegeriostream汇编嵌入式任务 最近正在做一个嵌入式系统,是基于windows ce的,外接硬件的时序要微秒级的延时.1.微秒级的延时肯定不能基于消息(SetTimer函数 ...

  5. linux curses函数库

    fedora20,安装yum install ncurses-devel 编译时:-lncurses 头文件:#include<curses.h> 参考:man ncurses \linu ...

  6. SWT的选择文件和文件夹的函数

    org.eclipse.swt.widgets.DirectoryDialog//选择目录org.eclipse.swt.widgets.FileDialog//SWT.OPEN打开文件 SWT.SA ...

  7. leetcode先刷&lowbar;Maximum Subarray

    dp创始人级精英赛的冠军.最大的部分和. 扫从左至右,保持一个最佳值而当前部分和,在这一部分,并成为负值什么时候.再往下的积累后,也起到了负面作用,所以,放弃直销,然后部分和初始化为阅读的当前位置. ...

  8. 微信支付WxpayAPI&lowbar;php&lowbar;v3(一)sdk简介与错误修改

    经过断断续续将近一周的时间终于把微信支付调通了. 这里总结一下,算是给后来者有个指引.少踩坑!!!! 开发语言:php5.5 语言框架:laravel5.2 微信sdk:WxpayAPI_php_v3 ...

  9. OpenGL ES2&period;0入门详解

    引自:http://blog.csdn.net/wangyuchun_799/article/details/7736928  1.决定你要支持的OpenGL ES的版本.目前,OpenGL ES包含 ...

  10. 类似fabric主机管理demo

    类似于fabric的主机管理系统 可以批量对主机进行操作 批量上传文件 批量下载文件 批量执行命令 demo代码 #!/usr/bin/env python # -*- coding:utf-8 -* ...