最近想用pytorch搭一个机器翻译模型,过程中涉及到从矩阵中取出某些指定元素组成新矩阵的步骤,然后在pytorch官网上找到对应功能的函数为 torch.gather 。
pytorch官网上介绍这个函数时就寥寥几句话:
###########################################
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k