torch.gather
沿给定轴 dim ,将输入索引张量 index 指定位置的值进行聚合.
1. 二维情况下
(1)case1: dim=0
import torch
tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(0, index)
print("tensor_0:", tensor_0)
print("tensor_1", tensor_1)
输出
tensor_0: tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
tensor_1 tensor([[9, 7, 5]])
# note: dim=0从列里面选,【9】是第一列中第2个数,【7】是第二列第1个数,【5】是第三列第0个数
(2)case2: dim=1
tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2],[1],[0]])
tensor_1 = tensor_0.gather(1, index)
print("tensor_0:", tensor_0)
print("tensor_1", tensor_1)
输出
tensor_0: tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
tensor_1 tensor([[5],
[7],
[9]])
# note: dim=1, 从行里取【5】是第一行第二个数,【7】是第二行第1个数,【9】是第三行第0个数
case3: 一行中取多个数
tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2,1,0],[1,1,1],[0,1,0]])
tensor_1 = tensor_0.gather(1, index)
print("tensor_0:", tensor_0)
print("tensor_1", tensor_1)
输出
tensor_0: tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
tensor_1 tensor([[ 5, 4, 3],
[ 7, 7, 7],
[ 9, 10, 9]])
Case4: 一列中取多个数
tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2,1,0],[1,1,1],[0,1,0]])
tensor_1 = tensor_0.gather(0, index)
print("tensor_0:", tensor_0)
print("tensor_1", tensor_1)
输出
tensor_0: tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
tensor_1 tensor([[9, 7, 5],
[6, 7, 8],
[3, 7, 5]])
2. 三维情况
case1: dim=1
a = torch.randint(0, 30, (2, 3, 5))
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
b = torch.gather(a, 1, index)
print("a:", a)
print("b:", b)
输出:
a: tensor([[[13, 1, 25, 18, 28],
[24, 19, 5, 25, 11],
[13, 13, 20, 9, 22]],
[[22, 18, 12, 9, 1],
[ 6, 11, 23, 11, 29],
[15, 9, 8, 29, 6]]])
b: tensor([[[13, 19, 20, 18, 22],
[13, 1, 25, 18, 28],
[24, 19, 5, 25, 11]],
[[ 6, 9, 8, 29, 6],
[22, 18, 12, 9, 1],
[15, 9, 8, 29, 6]]])
# note: dim=1从列中取
case2: dim=2
a = torch.randint(0, 30, (2, 3, 5))
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
b = torch.gather(a, 2, index)
print("a:", a)
print("b:", b)
输出
a: tensor([[[ 0, 19, 3, 20, 29],
[ 4, 2, 1, 8, 13],
[16, 15, 13, 29, 10]],
[[25, 18, 16, 0, 6],
[ 3, 4, 13, 23, 19],
[ 7, 21, 28, 17, 11]]])
b: tensor([[[ 0, 19, 3, 0, 3],
[ 4, 4, 4, 4, 4],
[15, 15, 15, 15, 15]],
[[18, 16, 16, 16, 16],
[ 3, 3, 3, 3, 3],
[28, 28, 28, 28, 28]]])
# dim=2 从行里取数
case3: dim=0
a = torch.randint(0, 30, (2, 3, 5))
index = torch.LongTensor([[[0,1,1,0,1],
[0,1,1,1,1],
[1,1,1,1,1]],
[[1,0,0,0,0],
[0,0,0,0,0],
[1,1,0,0,0]]])
b = torch.gather(a, 0, index)
print("a:", a)
print("b:", b)
输出
a: tensor([[[ 9, 3, 10, 19, 4],
[26, 19, 20, 9, 28],
[ 5, 21, 29, 26, 24]],
[[10, 2, 11, 29, 26],
[20, 25, 17, 11, 16],
[ 4, 17, 27, 17, 29]]])
b: tensor([[[ 9, 2, 11, 19, 26],
[26, 25, 17, 11, 16],
[ 4, 17, 27, 17, 29]],
[[10, 3, 10, 19, 4],
[26, 19, 20, 9, 28],
[ 4, 17, 29, 26, 24]]])
# dim = 0时,索引代表在第几页取数,取数的位置为索引i所在的坐标,如上:index[0][0][0]=0,表示取a中第0页(0,0)的数9,index[0][0][1]=1表示取第1页的(0,1)坐标的数3
总结
- index的维数必须与输入维数相同,输入为2维矩阵,index也必须为2维矩阵
- 在二维矩阵中dim=0 表示列,dim=1表示行,三维矩阵中,dim=0表示页,dim=1表示列,dim=2表示行