最近在学习pytorch时遇到gather函数,开始没怎么理解,后来查阅网上相关资料后大概明白了原理。
gather()函数
在pytorch中,gather()函数的作用是将数据从input中按index提出,我们看gather函数的的官方文档说明如下:
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out