torch.gather
先看看定义
torch.gather(input, dim, index, out=None) → Tensor
沿给定轴dim,将输入索引张量index指定位置的值进行聚合。
对一个3维张量,输出可以定义为:
out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3
例子
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
首先可以确定的
- 输入的维度数与index的维度数量相等[即input为n维[a,b…c],则index必须为n维[x,y…z]](注意没有规定a和x,b和y,…c和z一定要相等)
- 输出的元素数量与index的元素数量相等[即index [a,b…c],out[a,b…c]]
确定了以上两点,之后我们要搞清楚的就是index在每个维度的意义
定义如下tensor
t = torch.randn(2,3,3)
print(t)
index_a=torch.LongTensor([[[1,2]],[[1,2]]])
index_b=torch.LongTensor([[[1,0]],[[1,0]]])
#第三维度即out[i][j][k]为input(t)相对行列的第index[i][j][k]个元素
print(index_a.shape,torch.gather(t, 2, index_a))
#第二维度即out[i][j][k]为input(t)相对i行index[i][j][k]列的第k个元素
print(torch.gather(t,1,index_a))
#第一维度即out[i][j][k]为input(t)相对index[i][j][k]行j列的第k个元素
print(torch.gather(t,0,index_b))
'''
tensor([[[-1.3513, -0.8054, -1.1973],
[-0.6869, 0.6490, 0.6097],
[-0.8863, 0.4745, 0.0845]],
[[ 0.0430, -1.1662, -1.6117],
[-0.5186, 0.4622, -0.0349],
[-0.3438, 1.4358, -0.6612]]])
torch.Size([2, 1, 2]) tensor([[[-0.8054, -1.1973]],
[[-1.1662, -1.6117]]])
tensor([[[-0.6869, 0.4745]],
[[-0.5186, 1.4358]]])
tensor([[[ 0.0430, -0.8054]],
[[ 0.0430, -0.8054]]])
'''