tf.contrib.seq2seq.dynamic_decode 返回值的shape 巨坑

tf.contrib.seq2seq.dynamic_decode 这个函数真的是巨坑啊,每一个batch,他的rnn_output的shape居然会是:
[batch_size, max_efficient_sentence across entire batch, num_classes] ,而不是我们想要的每个序列应该有的最大长度,尽管我们在TrainingHelper里面指定了Sequence_length的有效长度。。。。我日哦。
于是在计算sequence_loss的时候,

def sequence_loss(logits,
                  targets,
                  weights,
                  average_across_timesteps=True,
                  average_across_batch=True,
                  softmax_loss_function=None,
                  name=None):

我们传入的targets是每条序列都填充到某个固定的最大长度,然而logits的sequence大小却是又都是填充到 当前batch里面最大有效长度 ,这两个长度不相等.着实坑人···

举个例子:机器翻译里面,目标语词典大小1100,我们每次送5条目标语序列到模型,送入之前都会把这5个句子填充到指定的长度max_len(例如max_len=200),那么targets的shape就是[5,200,1100]。

然后我们也会保留目标语序列填充前的真实长度数组sequence_lengths,比如说sequence_lengths=[32,43,96,44,76] 。 坑爹的dynamic_decode执行后,当前batch的结果rnn_output的shape居然是[5,96,1100]!!!!!!!!!
!!!
!!!
!!!

解决方法:对targets进行截断,截断成和rnn_output一样的形状。

# 获取logits
logits = tf.contrib.seq2seq.dynamic_decode(xxx)[0].rnn_output

# 获取当前的长度,max_len 和 logits 的较小者.
current_ts = tf.to_int32(tf.minimum(tf.shape(target_input)[1], tf.shape(logits)[1]))
# 对 target 进行截取
target_sequence = tf.slice(target_input, begin=[0, 0], size=[-1, current_ts])
mask_ = tf.sequence_mask(lengths=sequence_lengths, maxlen=current_ts, dtype=logits.dtype)
logits = tf.slice(logits, begin=[0, 0, 0], size=[-1, current_ts, -1])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值