torch.gather(input, dim, index, *, sparse_grad=False, out=None)
参数:
input:输入张量
dim:index按照哪个轴取值
index:取值用的索引张量
gather其实就是根据index中索引查找input中元素重排,数据都是原来的,只是重新查找形成新张量矩阵。
公式就是下面这样
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
这么看下来可能有点懵,我们举个栗子
import torch
t = torch.arange(0,32).view(1,2,4,4)
整个4维数据1x2x4x4
tensor([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]]])
index为1x1x2x4
index = torch.LongTensor([[[[0,1,2,3],[2,2,2,2]]]])
执行gather,axis设在第3个维度上
a = torch.gather(t, 2, index)
结果为
tensor([[[[ 0, 5, 10, 15],
[ 8, 9, 10, 11]]]])
index维度为1x1x2x4,所以gather输出也是这个dims
根据上面的公式,我们可以一个个来取值
a[0][0][0][0] = t[0][0][index[0][0][0][0]][0] = t[0][0][0][0] = 0
a[0][0][0][1] = t[0][0][index[0][0][0][1]][1] = t[0][0][1][1] = 5
a[0][0][0][2] = t[0][0][index[0][0][0][2]][2] = t[0][0][1][1] = 10
...
a[0][0][1][2] = t[0][0][index[0][0][1][2]][2] = t[0][0][2][2] = 10
a[0][0][1][3] = t[0][0][index[0][0][1][3]][3] = t[0][0][2][3] = 11