【参考:torch.gather — PyTorch 1.12 documentation】
out = torch.gather(input, dim, index)
out[i][j][k] = input[ index[i][j][k] ] [j][k] # if dim == 0
out[i][j][k] = input[i] [ index[i][j][k] ] [k] # if dim == 1
out[i][j][k] = input[i][j] [ index[i][j][k] ] # if dim == 2
规则:需要构造一个和index一样shape的索引数组,然后把第dim维换成index[i][j]所得到的值,得到新的索引数组,然后根据该索引数组去input取值。
【参考:Pytorch中torch.gather函数祥解 - 简书】
用二维tensor举例
#按照dim = 0, 取一个2*2tensor的对角线上的数值
import torch
a = torch.Tensor([[1, 2],
[3, 4]])
b = torch.gather(a, dim = 0, index=torch.LongTensor([[0, 1]]))
print('a = ', a)
print('b = ', b)
"""
a = tensor([[1., 2.],
[3., 4.]])
b = tensor([[1., 4.]])
"""
index的shape是[1,2],所以i=0,j=0,1
i,j=0,0
b[0][0]=a[ index[0][0] ][0] =a[0][0]=1
i,j=0,1
b[0][0]=a[ index[0][1] ][1] =a[1][1]=4
案例二
#按照dim = 1, 取一个2*2 tensor的对角线上的数值
import torch
a = torch.Tensor([[1, 2],
[3, 4]])
c = torch.gather(a, dim = 1, index=torch.LongTensor([[0],
[1]]))
print('a = ', a)
print('c = ', c)
"""
a = tensor([[1., 2.],
[3., 4.]])
c = tensor([[1.],
[4.]])
"""
index的shape是[2,1],所以i=0,1,j=0
i,j=0,0
c[0][0]=a[ index[0][0] ][0] =a[0][0]=1
i,j=1,0
c[0][0]=a[ index[1][0] ][0] =a[1][1]=4
案例三
import torch
a = torch.Tensor([[1, 2],
[3, 4]])
d = torch.gather(a, dim= 0, index=torch.LongTensor([[0, 0],
[1, 0]]))
print('a = ', a)
print('d = ', d)
"""
a = tensor([[1., 2.],
[3., 4.]])
d = tensor([[1., 2.],
[3., 2.]])
"""
index的shape是[2,2],所以i=0,1,j=0,1
i,j=0,0
d[0][0]=a[ index[0][0] ][0] =a[0][0]=1
i,j=0,1
d[0][1]=a[ index[0][1] ][1] =a[0][1]=2
i,j=1,0
d[1][0]=a[ index[1][0] ][0] =a[1][0]=3
i,j=1,1
d[1][1]=a[ index[1][1] ][1] =a[0][1]=2
实际中的一个例子
引言:在多分类中,torch.gather常用来取出标签所对应的概率
有三个标签[0, 1, 2],即三个类别。现在知道两个样本(A 和 B)所得到的三个标签的概率分别为[0.1, 0.3, 0.6]和[0.3, 0.2, 0.5], 用my_pred表示, 这两个样本的真实标签分别为0和2, 那么我们很容易知道A所预测的真实标签的概率为0.1, B所预测的真实标签的概率为0.5,A分类错误,B正确分类。那么用程序这么获得标签对应的概率呢,这里就可以用gather函数。
import torch
my_pred = torch.tensor([[0.1, 0.3, 0.6],
[0.3, 0.2, 0.5]])
my = torch.LongTensor([[0],
[2]])
print(torch.gather(input=my_pred, dim=1, index=torch.LongTensor([[0],
[2]])))
"""
tensor([[0.1000],
[0.5000]])
"""
index的shape是[2,1],所以i=0,1,j=0
i,j=0,0
d[0][0]=a[0][ index[0][0] ] =a[0][0]=0.1
i,j=1,0
d[1][0]=a[1][ index[1][0] ] =a[1][2]=0.5