dynamic_rnn获取最后一步输出

tf.dynamic的输入参数中包含了一个sequence_lengths参数,传递的是一个batch中序列的真是长度,这个参数默认为None,如果输入的batch中每个样本的序列长度不相同,那么得到的通过dynamic_rnn得到的outputs每一个时间步的输出都不全为0(意思是和static_rnn一样,把padding部分得到的输出也算进来了),如果这时候我想取到真是长度位置的输出要怎么办?

def extract_axis_1(data, ind):
    """
    Get specified elements along the first axis of tensor.
    :param data: Tensorflow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """

    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.gather_nd(data, indices)

    return res

outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        sequence_length=None,
        inputs=X)
 
outputs2 = extract_axis_1(outputs, [5, 3, 5]) #真是长度是[6,4,6]

output2就满足了你想要的输出

当你把真是的长度传递给dynamic_rnn时,它会根据真实长度大小刚好计算到真实长度为止,这时,只需要取last_states.h即可。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值