Tensorflow Seq2seq attention decode解析

时间:2022-09-06 16:05:58

tensorflow基于 Grammar as a Foreign Language实现,这篇论文给出的公式也比较清楚。

这里关注seq2seq.attention_decode函数,

  • 主要输入

decoder_inputs,

initial_state,

attention_states,

 
 

这里可以主要参考 models/textsum的应用,textsum采用的多层双向lstm,

假设只有一层,texsum将正向
最后输出的state作为 attention_decode的输入initial_state

(不过很多论文认为用逆向最后的state可能效果更好)

对应decocer_inputs就是标注的摘要的字符序列id对应查找到的embedding序列

而attention_states是正向负向输出concatenate的所有outputs(hidden注意output和hidden是等同概念)

 
 

  • 关于linear

首先注意到在attention_decode函数用到了一个linear这个定义在rnn_cell._linear函数

他的输入是
一个list 可能的输入是比如

[ [batch_size, lenght1], [batch_size_length2]]

对应一个list 2个数组

它的作用是内部定义一个数组
对应这个例子 [length1 + length2, output_size]

也就是起到将[batch_size, length1][batch_size, length2]的序列输入映射到 [batch_size, output_size]的输出

 
 

这个在attention机制最后会遇到

先看attention的公式

将encoder的hidden states表示为

(h 1 , . . . , h T A)

将decoder的hidden states表示为

(d 1 , . . . , d T B) :=
(h T A +1 , . . . , h T A +T B).

 
 

Tensorflow  Seq2seq attention decode解析

这里最后计算得到的

Tensorflow  Seq2seq attention decode解析

就是attention的结果
对应一个样本
就是长度为 atten_size的向量(就是所有attention输入向量按照第三个公式的线性叠加之后的结果)那么对应batch_size的输入
就是[batch_size, atten_size]的一个结果。

论文中提到后面会用到这个attention,

Tensorflow  Seq2seq attention decode解析

 
 

Tensorflow  Seq2seq attention decode解析

 
 

也就是说会concat attention的结果和原始hidden state的结果,那么如何使用呢,tf的做法

x = linear([inp] + attns, input_size, True)

# Run the RNN.

cell_output, state = cell(x, state)

就是说 inp是 [batch_size, input_size], attns [batch_size, attn_size] linear的输入对应 input_size

即在linear内部经过input和attns concate之后输出[batch_size, input_size]使得能够x作为输入继续进行rnn过程

 
 

  • attention公式

    继续看attention公式
    ,不要考虑batch_size就是按照一个样本来考虑

    第一个公式
    对应3个举止 W1,W2都是[attn_size, atten_size]的正方形矩阵,h,d对应 [attent_size, 1]的向量

    v对应[atten_size, 1]的矩阵,

    那么就是线性叠加之后做非线性变化tanh([attn_size, 1])->[attn_size, 1]最后和v做dot得到一个数值
    表示u(i,t)

    即对应第i个attention向量在decode的t时刻时候应该的权重大小,

    第二个公式表示使用softmax做归一化得到权重向量概率大小。

    第三个公式上面已经分析。

  • tensorflow中attention的实现
    • 步骤1

    这里第一个问题是我们按照batch操作所以对应处理的不是一个样本而是一批batch_size个样本。

    那么上面的操作就不能按照tf.matmul来执行了,因为[batch_size, x, y][y, 1]这样相乘是不行的

    tf的做法是使用1by1 convolution来完成,主要利用1by1 + num_channels + num_filters

    关于conv2d的使用特别是配合1by1,num_channels, num_filters 这里解释的非常清楚

    http://*.com/questions/34619177/what-does-tf-nn-conv2d-do-in-tensorflow

     
     

    # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.

    hidden = array_ops.reshape(

    attention_states, [-1, attn_length, 1, attn_size])

    hidden_features = []

    v = []

    attention_vec_size = attn_size # Size of query vectors for attention.

    for a in xrange(num_heads):

    k = variable_scope.get_variable("AttnW_%d" % a,

    [1, 1, attn_size, attention_vec_size])

    hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))

    v.append(

    variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))

     
     

     
     

atention_vec_szie == attn_size

 
 

attn_size 对应 num_channels (num_channels个位置相乘加和 dot)

attention_vec_size 对应 num_filters

刚好这个conv2d的对应就是batch_size版本的attention的第一个公式里面的
W1 * h_t

Conv2d输出[batch_size, atten_length, 1, attention_vec_size]

  • def attention(query)的分析

attention(query)的输入是rnn上一步输出的state

输出 attns = attention(state)对应 [batch_size, attn_size]的矩阵

对应当前步骤需要用到的attention

 
 

def
attention(query):

"""Put attention masks on hidden using hidden_features and query."""

ds = [] # Results of attention reads will be stored here.

if nest.is_sequence(query): # If the query is a tuple, flatten it.

query_list = nest.flatten(query)

for q in query_list: # Check that ndims == 2 if specified.

ndims = q.get_shape().ndims

if ndims:

assert ndims == 2

query = array_ops.concat(1, query_list)

for a in xrange(num_heads):

with variable_scope.variable_scope("Attention_%d" % a):

y = linear(query, attention_vec_size, True)

y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])

# Attention mask is a softmax of v^T * tanh(...).

s = math_ops.reduce_sum(

v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3])

a = nn_ops.softmax(s)

# Now calculate the attention-weighted vector d.

d = math_ops.reduce_sum(

array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden,

[1, 2])

ds.append(array_ops.reshape(d, [-1, attn_size]))

return ds

 
 

首先目前默认都是用state_is_tuple=True选项(这样效率更高,后面state_is_tupe=False将会depreciated)

前面已经说过tf实现的state对应两个(cell_state, hidden_state)

所以这里nest_issequence是True 对应最后处理后query 就是 [batch_size, 2 * input_size]

 
 

y = linear(query, attention_vec_size, True)

y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])

对应W2dt的计算

hidden_features[a] + y 则注意是 W2dt累加到
所有的hi(attn_length个)

 
 

a对应[batdh_size, attn_length]

Reshape[batch_size, atten_length, 1, 1]

Hidden [batch_size, atten_length, 1, atten_size]

 
 

最终返回 [batch_size, attn_size]

Tensorflow Seq2seq attention decode解析的更多相关文章

  1. 学习笔记CB014:TensorFlow seq2seq模型步步进阶

    神经网络.<Make Your Own Neural Network>,用非常通俗易懂描述讲解人工神经网络原理用代码实现,试验效果非常好. 循环神经网络和LSTM.Christopher ...

  2. seq2seq attention

    1.seq2seq:分为encoder和decoder a.在decoder中,第一时刻输入的是上encoder最后一时刻的状态,如果用了双向的rnn,那么一般使用逆序的最后一个时刻的输出(网上说实验 ...

  3. 深度学习中的序列模型演变及学习笔记(含RNN&sol;LSTM&sol;GRU&sol;Seq2Seq&sol;Attention机制)

    [说在前面]本人博客新手一枚,象牙塔的老白,职业场的小白.以下内容仅为个人见解,欢迎批评指正,不喜勿喷![认真看图][认真看图] [补充说明]深度学习中的序列模型已经广泛应用于自然语言处理(例如机器翻 ...

  4. DL4NLP —— seq2seq&plus;attention机制的应用:文档自动摘要(Automatic Text Summarization)

    两周以前读了些文档自动摘要的论文,并针对其中两篇( [2] 和 [3] )做了presentation.下面把相关内容简单整理一下. 文本自动摘要(Automatic Text Summarizati ...

  5. ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人&lbrack;中文文档&rsqb;

    ChatGirl 一个基于 TensorFlow Seq2Seq 模型的聊天机器人[中文文档] 简介 简单地说就是该有的都有了,但是总体跑起来效果还不好. 还在开发中,它工作的效果还不好.但是你可以直 ...

  6. ChatGirl is an AI ChatBot based on TensorFlow Seq2Seq Model

    Introduction [Under developing,it is not working well yet.But you can just train,and run it.] ChatGi ...

  7. Tensorflow的CNN教程解析

    之前的博客我们已经对RNN模型有了个粗略的了解.作为一个时序性模型,RNN的强大不需要我在这里重复了.今天,让我们来看看除了RNN外另一个特殊的,同时也是广为人知的强大的神经网络模型,即CNN模型.今 ...

  8. tensorflow seq2seq&period;py接口实例

    以简单英文问答问题为例测试tensorflow1.4 tf.contrib.legacy_seq2seq中seq2seq文件的几个seq2seq接口 github:https://github.com ...

  9. seq2seq&plus;attention解读

    1什么是注意力机制? Attention是一种用于提升Encoder + Decoder模型的效果的机制. 2.Attention Mechanism原理 要介绍Attention Mechanism ...

随机推荐

  1. tomcat在linux下自启动

    Linux下设置tomcat开机自启动  一.以root用户登录系统: 二.进入init.d文件夹 cd /etc/init.d/ 三.创建并打开tomcat文件 vi tomcat 四.tomcat ...

  2. PHP类的原理

    一.类的实现 类的内部存储结构: struct _zend_class_entry { char type; // 类型:ZEND_INTERNAL_CLASS / ZEND_USER_CLASS c ...

  3. Linux(9&period;14-9&period;20)学习笔记

    实验一 Linux系统简介 一.Linux 为何物 Linux 就是一个操作系统,Linux 也就是系统调用和内核那两层. 二.Linux 历史简介 操作系统始于二十世纪 50 年代,当时的操作系统能 ...

  4. Selenium - IWebDriver&period;SwitchTo&lpar;&rpar; frame 和 Window 的用法

    IWebDriver.SwitchTo().Frame(IWebElement frame) 如果一个页面是一个html元素, 只有一个head, 一个body, 那么使用IWebDriver.Fin ...

  5. tcpdump 使用实践

    tcpdump常用配置指导 参考:http://www.cnblogs.com/ggjucheng/archive/2012/01/14/2322659.html 1. 指定端口抓包  -i 说明: ...

  6. python的学习笔记01&lowbar;2变量 常量 注释 用户交互 格式化输出

    变量是什么? 变量的作用  Variables are used to store information to be referenced and manipulated in a computer ...

  7. &lbrack;原创&rsqb;Modbus协议学习笔记

    一.参考资料 1.老罗传奇的2篇博文,写的不错,通俗易懂.链接地址为:http://www.cnblogs.com/luomingui/tag/Modbus/ 2.阿莫论坛精华资料:http://ww ...

  8. IIS站点报拒绝访问Temporary ASP&period;NET Files的解决办法

    IIS站点本来运行的好好的,突然就出现了:Temporary ASP.NET Files拒绝访问的问题.遇到此类问题,请逐步排查,定可解决. 原因:Windows操作系统升级导致. 办法: 1.检查C ...

  9. Django服务端读取excel文件并且传输到接口

    path_name = "opboss_download_" + str(int(time.time())) + ".csv" print(path_name) ...

  10. 再谈应用环境下的TIME&lowbar;WAIT和CLOSE&lowbar;WAIT

    昨天解决了一个HttpClient调用错误导致的服务器异常,具体过程如下: http://blog.csdn.net/shootyou/article/details/6615051 里头的分析过程有 ...