最近在做一些文本分类问题过程中,频繁使用Bilstm,对于.bidirectional_dynamic_rnn()函数使用较多,笔者在之前介绍过.dynamic_rnn()函数,在此基础上,参考/wuzqChom/article/details/75453327和/taolusi/article/details/81232210两篇博客,结合自身理解对.bidirectional_dynamic_rnn()函数进行详细解释。
首先我们了解一下函数的参数
-
bidirectional_dynamic_rnn(
-
cell_fw, # 前向RNN
-
cell_bw, # 后向RNN
-
inputs, # 输入
-
sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
-
initial_state_fw=None, # 前向的初始化状态(可选)
-
initial_state_bw=None, # 后向的初始化状态(可选)
-
dtype=None, # 初始化和输出的数据类型(可选)
-
parallel_iterations=None,
-
swap_memory=False,
-
time_major=False,
-
scope=None)
值得注意的是,当inputs的张量形状为[batch_size,max_len,embeddings_num]时,time_major = False。当inputs的形状为[max_len,batch_size,embeddings_num]时,time_major = True。一般我们将输入的格式为[batch_size,max_len,embeddings_num],因此time_major的默认值为False。
函数的输入与.dynamic_rnn()相似,由(outputs,outputs_states)组成。
- outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。当time_major = False时,output_fw和output_bw的形状为[batch_size,max_len,hiddens_num]。在此情况下,最终的outputs可以用([output_fw, output_bw],-1)或([output_fw, output_bw],2),这里面的[output_fw, output_bw]可以直接用outputs进行代替。关于可以参考/leviopku/article/details/82380118
- output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。 output_state_fw和output_state_bw的类型为LSTMStateTuple,由(c,h)组成,分别代表memory cell 和hidden state.
笔者最近做的两个项目分别为基于Bilstm的文本分类和中文实体抽取。对于文本分类来说,需要最后一个time_step的输出,而中文实体抽取则需要最终的outputs,即所有time_step的输出。
-
#文本分类可以由以下方式得到最后的输入状态
-
-
outputs, outputs_state = .bidirectional_dynamic_rnn(lstm_fw_cell_m, lstm_bw_cell_m, embedding_inputs,time_major = False,dtype = 32)
-
output_fw = outputs[0]
-
output_bw = outputs[1]#原形状为[batch_size,max_len,hidden_num]
-
output_fw = (output_fw,[1,0,2])#现在形状为[max_len,batch_size,hidden_num]
-
output_bw = (output_bw,[1,0,2])
-
outputs1 = [output_fw,output_bw]
-
lstmoutputs = (outputs1, 2)#连接后形状为[max_len,batch_size,2*hidden_num]
-
last = lstmoutputs[-1]#最后一个time_step的输出,为[batch_size,2*hidden_num]
-
-
-
#中文实体抽取
-
(output_fw_seq, output_bw_seq), _ = .bidirectional_dynamic_rnn(cell_fw=cell_fw,cell_bw=cell_bw,inputs=self.word_embeddings,sequence_length=self.sequence_lengths,dtype=32)
-
output = ([output_fw_seq, output_bw_seq],axis=-1) # time_major = False,所以输入为[batch_size,time_step,embedding_dim],所以这样连接,相当于 axis = 2