def collect_final_step_of_lstm(lstm_representation, lengths):
    """Pick one time step per batch row from an LSTM output tensor.

    Args:
        lstm_representation: tensor of shape
            [batch_size, passage_length, dim].
        lengths: tensor of shape [batch_size]. Each entry is used directly
            as the time index to read for that row, after negative values
            are raised to 0. Presumably int32, since it is compared with an
            int32 zero tensor and stacked with the row numbers.

    Returns:
        Tensor of shape [batch_size, dim] holding the selected step of
        every row.
    """
    # NOTE(review): the index used is `lengths` itself, not `lengths - 1`.
    # To read the last valid step, callers presumably pass
    # `sequence_length - 1` -- confirm against the call sites.

    # Raise negative entries to 0 so they do not produce negative indices.
    floor = tf.zeros_like(lengths, dtype=tf.int32)
    step_index = tf.maximum(lengths, floor)

    # Pair every row number with its time index: shape (batch_size, 2).
    num_rows = tf.shape(step_index)[0]
    row_index = tf.range(num_rows)  # shape (batch_size,)
    coordinates = tf.stack([row_index, step_index], axis=1)

    # Each (row, step) pair selects one [dim] vector.
    return tf.gather_nd(
        lstm_representation, coordinates, name='last-forwar-lstm')
# Usage:
from tensorflow.python.ops import rnn
self.outputs, _ = rnn.bidirectional_dynamic_rnn(
lstm_cell_fw,
lstm_cell_bw,
self.inputs_emb