1 代码示例:
import torch
a = torch.arange(15).view(3, 5)
b = torch.zeros_like(a)
b[1][2] = 1
b[0][0] = 2
a、b矩阵分别为:
a矩阵
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
b矩阵
tensor([[2, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 0, 0]]
dim=0的结果为:
a.gather(0, b) # dim=0
tensor([[10, 1, 2, 3, 4],
[ 0, 1, 7, 3, 4],
[ 0, 1, 2, 3, 4]])
dim=1的结果为
a.gather(1, b) # dim=1
tensor([[ 2, 0, 0, 0, 0],
[ 5, 5, 6, 5, 5],
[10, 10, 10, 10, 10]])
2 思路理解: