官方说明
简介
torch.gather (input , dim , index , * , sparse_grad=False , out=None) → Tensor
沿着指定的轴收集值(是对input进行一种映射,index必须是 LongTensor格式)。
对于一个3-D张量,输出由下面的公式指定
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
连接
torch.gather — PyTorch 1.13 documentation
使用
使用示例1
import torch
a_4_4 = torch.arange(0,16).view(4,4)
print(a_4_4)
"""
输出
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
"""
index1=torch.LongTensor([[3,2,1,0],
[0,1,2,3]])
print(index1)
"""
输出
tensor([[3, 2, 1, 0],
[0, 1, 2, 3]])
"""
# dim = 0
print(a_4_4.gather(0, index1))
"""
a_4_4.gather(0,index1) 解释
[[a_4_4[3][0],a_4_4[2][1],a_4_4[1][2],a_4_4[0][3],
[a_4_4[0][0],a_4_4[1][1],a_4_4[2][2],a_4_4[3][3]]
输出
tensor([[12, 9, 6, 3],
[ 0, 5, 10, 15]])
"""
# dim = 1
print(a_4_4.gather(1,index1))
"""
a_4_4.gather(1,index1) 解释
[[a_4_4[0][3],a_4_4[0][2],a_4_4[0][1],a_4_4[0][0],
[a_4_4[1][0],a_4_4[1][1],a_4_4[1][2],a_4_4[1][3]]
输出
tensor([[3, 2, 1, 0],
[4, 5, 6, 7]])
"""
使用示例2
import torch
import numpy as np
a_2_3_4 = torch.LongTensor(np.arange(24)).view(2,3,4)
print(a_2_3_4)
"""
输出
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
"""
index2 = torch.LongTensor([[[0 ,1 ,2 ,0],
[0, 0, 0 ,0],
[1, 1, 1, 1]],
[[2, 2, 2, 2],
[1, 1, 1, 1],
[0, 0, 0, 0]]])
print(torch.gather(a_2_3_4, 1, index2))
"""
输出
tensor([[[ 0, 5, 10, 3],
[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[20, 21, 22, 23],
[16, 17, 18, 19],
[12, 13, 14, 15]]])
"""
index3 = torch.LongTensor([[[0 ,1 ,1 ,0],
[0, 0, 0 ,0],
[1, 1, 1, 1]],
[[1, 1, 1, 1],
[1, 1, 1, 1],
[0, 0, 0, 0]]])
print(torch.gather(a_2_3_4, 0, index3))
"""
输出
tensor([[[ 0, 13, 14, 3],
[ 4, 5, 6, 7],
[20, 21, 22, 23]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[ 8, 9, 10, 11]]])
"""