在一般的rnn模型中,rnn一般输出的形式如下[batch,seq_len,hidden_size],如果用做分类,一般是取最后一个状态[batch,hidden_size],如果用于做词性标注和分词则取全部的状态[batch,seq_len,hidden_size],下面介绍下用于文本分类取最后状态的两种方法,一种是直接transpose,取[-1]最后一个状态,大小变为[batch,hidden_size],另外一个是直接[:,-1,:]把中间的seq_len删除,直接变成[batch,hidden_size],直接看代码:
In [2]: import tensorflow as tf
In [3]: output=tf.get_variable(name='out',shape=[10,30,128])
In [4]: output.get_shape()
Out[4]: TensorShape([Dimension(10), Dimension(30), Dimension(128)])
In [7]: output_trans = tf.transpose(output, [1, 0, 2])
In [8]: output_trans.get_shape()
Out[8]: TensorShape([Dimension(30), Dimension(10), Dimension(128)])
In [9]: output_trans[-1].get_shape()
Out[9]: TensorShape([Dimension(10), Dimension(128)])
In [10]: output[:,-1,:].get_shape()
Out[10]: TensorShape([Dimension(10), Dimension(128)])