在tensorflow和pytorch中都有一个gather函数,其作用相似但是用法不同。关于tf.gather的用法可以参考知乎作者Towser的文章《TF 中的 indexing 和 slicing》。torch.gather函数的用法也很简单,就是给定indices获取tensor对应元素。给个例子就明白了。
tensor
下面直接给出torch.gather(针对二维情况)的tf实现,一维情况可以根据这个进行修改。
def
就酱。还是写tf太少了,不过mask真的是一种很有用的技巧。