import torch
# Demo: PyTorch integer-array ("advanced") indexing.

# 1) Index a 1-D tensor with a 2-D integer tensor. The result has the shape of
#    the index tensor (here (2, 2)), and each entry is a1[b1[i, j]].
a1 = torch.tensor([1, 2, 3, 4])
b1 = torch.tensor([[1, 3], [2, 2]])
print(a1[b1])

# 2) Index a 2-D tensor with two integer tensors, one per dimension.
#    idx_0 has shape (2, 2) and idx_1 has shape (2,). They are broadcast
#    together to shape (2, 2), so result[i, j] == b[idx_0[i, j], idx_1[j]].
b = torch.tensor([[0, 1], [2, 3], [4, 5]])
idx_0 = torch.tensor([[1, 0], [2, 1]])
idx_1 = torch.tensor([0, 1])
print(b[idx_0, idx_1])

# Output:
# tensor([[2, 4],
#         [3, 3]])
# tensor([[2, 1],
#         [4, 3]])
import torch

# Demo: PyTorch integer-array ("advanced") indexing.

# 1) Index a 1-D tensor with a 2-D integer tensor. The result has the shape of
#    the index tensor (here (2, 2)), and each entry is a1[b1[i, j]].
a1 = torch.tensor([1, 2, 3, 4])
b1 = torch.tensor([[1, 3], [2, 2]])
print(a1[b1])

# 2) Index a 2-D tensor with two integer tensors, one per dimension.
#    idx_0 has shape (2, 2) and idx_1 has shape (2,). They are broadcast
#    together to shape (2, 2), so result[i, j] == b[idx_0[i, j], idx_1[j]].
b = torch.tensor([[0, 1], [2, 3], [4, 5]])
idx_0 = torch.tensor([[1, 0], [2, 1]])
idx_1 = torch.tensor([0, 1])
print(b[idx_0, idx_1])

# Output:
# tensor([[2, 4],
#         [3, 3]])
# tensor([[2, 1],
#         [4, 3]])