利用tf.gather_nd等一系列tf函数取出qebrain中的mis_matching feature

            shift_proj_inputs = self.emb_proj_layer(shift_inputs)
            _pre_qefv = tf.concat([shift_outputs, shift_proj_inputs], axis=-1)
            # _pre_qefv = shift_outputs + shift_proj_inputs
            # Notice, currently <s> is not to predict, but actually in our QE model, we can predict it.
            logits = self.output_layer(_pre_qefv)
            sample_id = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
            # Extract logits feature for mismatching
            shape = tf.shape(fw_target_input)
            idx0 = tf.expand_dims(tf.range(shape[0]), -1)
            idx0 = tf.tile(idx0, [1, shape[1]])
            idx0 = tf.cast(idx0, fw_target_input.dtype)
            idx1 = tf.expand_dims(tf.range(shape[1]), 0)
            idx1 = tf.tile(idx1, [shape[0], 1])
            idx1 = tf.cast(idx1, fw_target_input.dtype)
            indices_real = tf.stack([idx0, idx1, fw_target_input], axis=-1)
            logits_mt = tf.gather_nd(logits, indices_real)
            logits_max = tf.reduce_max(logits, axis=-1)
            logits_diff = tf.subtract(logits_max, logits_mt)
            logits_same = tf.cast(tf.equal(sample_id, fw_target_input), tf.float32)
            logits_fea = tf.stack([logits_mt, logits_max, logits_diff, logits_same], axis=-1)

如题,在上面一段代码中,关键要理解的是tf.gather_nd函数。这个函数接受两个参数,然后会根据第二个参数从第一个参数中取出来对应位置的值。

logits的大小是[batch_size, sequence_len, vocab_size], indices_real的大小是[batch_size, sequence_len, 3],这里的3又对应的是[batch_size, sequence_len, vocab_size], 最后会取出[batch_size, sequence_len]个值,组成一个[batch_size, sequence_len]大小的向量,这就是所谓的mt feature。

那么indices_real是如何得到的呢?关键在于idx0和idx1。我们希望idx0的每个元素对应的是batch_size,idx1的每个元素对应的是seq_len,所以希望idx0是[[0,0,0,...],[1,1,1,...],...],idx1是[[0,1,2,...],[0,1,2,...],...],这样最后得到的元素就是[0,0,x],[0,1,x],...,[1,0,x],[1,1,x],...。所以对于idx0,需要在第1维上重复,而对于idx1,需要在第0维上重复。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值