.bidirectional_dynamic_rnn函数详解

时间:2025-04-10 08:11:36

最近在做一些文本分类问题过程中,频繁使用Bilstm,对于.bidirectional_dynamic_rnn()函数使用较多,笔者在之前介绍过.dynamic_rnn()函数,在此基础上,参考/wuzqChom/article/details/75453327/taolusi/article/details/81232210两篇博客,结合自身理解对.bidirectional_dynamic_rnn()函数进行详细解释。

首先我们了解一下函数的参数

  1. bidirectional_dynamic_rnn(
  2. cell_fw, # 前向RNN
  3. cell_bw, # 后向RNN
  4. inputs, # 输入
  5. sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
  6. initial_state_fw=None, # 前向的初始化状态(可选)
  7. initial_state_bw=None, # 后向的初始化状态(可选)
  8. dtype=None, # 初始化和输出的数据类型(可选)
  9. parallel_iterations=None,
  10. swap_memory=False,
  11. time_major=False,
  12. scope=None)

值得注意的是,当inputs的张量形状为[batch_size,max_len,embeddings_num]时,time_major = False。当inputs的形状为[max_len,batch_size,embeddings_num]时,time_major = True。一般我们将输入的格式为[batch_size,max_len,embeddings_num],因此time_major的默认值为False。

函数的输入与.dynamic_rnn()相似,由(outputs,outputs_states)组成。

  • outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。当time_major = False时,output_fw和output_bw的形状为[batch_size,max_len,hiddens_num]。在此情况下,最终的outputs可以用([output_fw, output_bw],-1)或([output_fw, output_bw],2),这里面的[output_fw, output_bw]可以直接用outputs进行代替。关于可以参考/leviopku/article/details/82380118
  • output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。 output_state_fw和output_state_bw的类型为LSTMStateTuple,由(c,h)组成,分别代表memory cell 和hidden state. 

笔者最近做的两个项目分别为基于Bilstm的文本分类和中文实体抽取。对于文本分类来说,需要最后一个time_step的输出,而中文实体抽取则需要最终的outputs,即所有time_step的输出。

  1. #文本分类可以由以下方式得到最后的输入状态
  2. outputs, outputs_state = .bidirectional_dynamic_rnn(lstm_fw_cell_m, lstm_bw_cell_m, embedding_inputs,time_major = False,dtype = 32)
  3. output_fw = outputs[0]
  4. output_bw = outputs[1]#原形状为[batch_size,max_len,hidden_num]
  5. output_fw = (output_fw,[1,0,2])#现在形状为[max_len,batch_size,hidden_num]
  6. output_bw = (output_bw,[1,0,2])
  7. outputs1 = [output_fw,output_bw]
  8. lstmoutputs = (outputs1, 2)#连接后形状为[max_len,batch_size,2*hidden_num]
  9. last = lstmoutputs[-1]#最后一个time_step的输出,为[batch_size,2*hidden_num]
  10. #中文实体抽取
  11. (output_fw_seq, output_bw_seq), _ = .bidirectional_dynamic_rnn(cell_fw=cell_fw,cell_bw=cell_bw,inputs=self.word_embeddings,sequence_length=self.sequence_lengths,dtype=32)
  12. output = ([output_fw_seq, output_bw_seq],axis=-1) # time_major = False,所以输入为[batch_size,time_step,embedding_dim],所以这样连接,相当于 axis = 2