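1. Copyright notice
Copyright notice: this is an original article by weixin_44291388, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.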
2. Official documentation
The official PyTorch documentation for torch.gather (PyTorch 1.11.0 documentation): https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
3. Explanation
What follows is the original author's own explanation:
import torch
# Example 1: start with the case where both out and index are 2-D
# out[i][j] = tensor[index[i][j]][j]  # dim=0
# out[i][j] = tensor[i][index[i][j]]  # dim=1
t = torch.Tensor([[1,2],[3,4]])
# t = 1 2
#     3 4
index = torch.LongTensor([[0,0],[1,0]])
# index = 0 0
#         1 0
print(torch.gather(t, 1, index))  # dim = 1 here
# Output: 1 1
#         4 3
# The output has the same size as index: out.size() == index.size()
# Step by step:
# index[0][0] = 0
# index[0][1] = 0
# index[1][0] = 1
# index[1][1] = 0
# dim = 1, so index replaces the column subscript
# out[0][0] = tensor[0][index[0][0]] == tensor[0][0] == 1
# out[0][1] = tensor[0][index[0][1]] == tensor[0][0] == 1
# out[1][0] = tensor[1][index[1][0]] == tensor[1][1] == 4
# out[1][1] = tensor[1][index[1][1]] == tensor[1][0] == 3
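The dim = 1 rule walked through above can be checked with an explicit double loop. The following is a minimal sketch, not how PyTorch implements the operation; the helper name gather_dim1 is hypothetical and only handles 2-D tensors.

```python
import torch

def gather_dim1(src, index):
    """Loop-based reference for torch.gather(src, 1, index) on 2-D tensors.

    Hypothetical helper for illustration only.
    """
    out = torch.empty(index.shape, dtype=src.dtype)
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            # dim = 1: the row subscript stays i, the column comes from index
            out[i, j] = src[i, index[i, j].item()]
    return out

t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
index = torch.tensor([[0, 0], [1, 0]])

manual = gather_dim1(t, index)
builtin = torch.gather(t, 1, index)
print(manual)                        # rows: [1, 1] and [4, 3]
print(torch.equal(manual, builtin))  # True
```

Writing the rule as a loop makes it clear that only the subscript along dim is looked up through index; every other subscript is copied from the output position.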
# Example 2: dim = 1
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 1])
print(y.view(-1,1))  # 2 rows, 1 column
print(y_hat.gather(1, y.view(-1, 1)))
# Output: 0.1
#         0.2
# Let idx = y.view(-1, 1); then, with dim = 1:
# out[0][0] = y_hat[0][idx[0][0]] == y_hat[0][0] == 0.1
# out[1][0] = y_hat[1][idx[1][0]] == y_hat[1][1] == 0.2
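Example 2 picks one entry per row, so it can also be written with integer-array indexing. The sketch below compares the two; the variable names picked_gather and picked_index are introduced here for illustration and do not come from the original article.

```python
import torch

y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.tensor([0, 1])

# gather keeps the shape of the index, so the result is a (2, 1) column
picked_gather = y_hat.gather(1, y.view(-1, 1))

# pairing row numbers 0..n-1 with the columns in y gives a flat (2,) result
picked_index = y_hat[torch.arange(len(y)), y]

print(picked_gather.shape, picked_index.shape)
print(torch.equal(picked_gather.squeeze(1), picked_index))  # True
```

The two forms select the same elements; the only difference is that gather returns a result shaped like the index, while the indexing form drops the trailing dimension.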
# Example 3: dim = 0
t = torch.Tensor([[1,2],[3,4]])
# t = 1 2
#     3 4
index = torch.LongTensor([[0,0],[1,0]])
# index = 0 0
#         1 0
print(torch.gather(t, 0, index))  # dim = 0 here
# Output: 1 2
#         3 2
# index[0][0] = 0
# index[0][1] = 0
# index[1][0] = 1
# index[1][1] = 0
# dim = 0, so index replaces the row subscript
# out[0][0] = tensor[index[0][0]][0] == tensor[0][0] == 1
# out[0][1] = tensor[index[0][1]][1] == tensor[0][1] == 2
# out[1][0] = tensor[index[1][0]][0] == tensor[1][0] == 3
# out[1][1] = tensor[index[1][1]][1] == tensor[0][1] == 2
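Because the output always takes the size of index, index does not have to match the input's size along dim. The sketch below uses a made-up 3-row index (not from the original article) with dim = 0 to illustrate this; along the other dimensions, index must not be larger than the input.

```python
import torch

t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

# 3 rows of indices although t has only 2 rows; each value picks a row of t
index = torch.tensor([[0, 1], [1, 0], [1, 1]])

out = torch.gather(t, 0, index)
print(out.shape)  # torch.Size([3, 2]), same as index
print(out)
# out[i][j] = t[index[i][j]][j], so the rows are [1, 4], [3, 2], [3, 4]
```

Each column of out is built only from the same column of t, which is why the column count of index is the part that has to fit inside the input.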