pytorch.gather()函数
官方文档
参考评论
直接举例子
import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
index矩阵为(维度需要与原tensor一致)
index=[ [x1,x2,x2],
[y1,y2,y2],
[z1,z2,z3] ]
1. 如果dim=0,新tensor为
[ [(x1,0),(x2,1),(x3,2)]
[(y1,0),(y2,1),(y3,2)]
[(z1,0),(z2,1),(z3,2)] ]
eg:
index = torch.tensor([[2, 1, 0],
[1, 2, 0],
[0, 1, 2] ])
tensor_1 = tensor_0.gather(0, index)
print(tensor_1)
tensor([[ 9, 7, 5],
[ 6, 10, 5],
[ 3, 7, 11]])
2. 如果dim=1,新tensor为
[ [(0,x1),(0,x2),(0,x3)]
[(1,y1),(1,y2),(1,y3)]
[(2,z1),(2,z2),(2,z3)] ]
eg:
index = torch.tensor([[2, 1, 0],
[1, 2, 0],
[0, 1, 2] ])
tensor_2 = tensor_0.gather(1, index)
print(tensor_2)
tensor([[ 5, 4, 3],
[ 7, 8, 6],
[ 9, 10, 11]])