1、官方示例代码
import torch
t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
2、运行结果
3、图解
4、验证
4.1、第一行的第一个元素和第二行的第二个元素
4.2、3个第一行的第一个元素和3个第二行的第个元素
参考:
import torch
t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
参考: