作用:取出特定维度索引的值。
torch.gather(input,dim,index,out=None,sparse_grad=False)
举个例子
import torch
a=torch.tensor([
[1,2,3,4,5],
[6,7,8,9,10],
[11,12,13,14,15],
])
index=torch.tensor([
[0,2],
[3,4],
[1,4]
])
print(torch.gather(a,1,index))
输出
tensor([[ 1, 3],
[ 9, 10],
[12, 15]])
1 在指定的维度上,index维度要和tensor可以不一致。
2 除了指定的维度上,index维度要和tensor必须一致。
上例中a的维度为(3,5),指定维度为1,那么在维度0上,a和index的0维度都为3。在指定维度1上,a的1维度为5,index的1维度为2。1
意思是:
在维度0上(也就是0行)取维度1的(也就是列)第0和第2个数值:第0个是1,第2个是3。
在维度0上(也就是1行)取维度1的(也就是列)第3和第4个数值:第3个是9,第4个是10。
在维度0上(也就是2行)取维度1的(也就是列)第1和第4个数值:第1个是12,第4个是15。
例2
import torch
a=torch.tensor([[[1,2,3],
[4,5,6],
[7,8,9]
],
[[11,22,33],
[44,55,66],
[77,88,99]
]])
hi=torch.tensor([[[1],[1],[1]],[[1],[1],[1]]])
b=torch.gather(a,-1,hi+1)
print(b)
输出:
torch.Size([2, 3, 1])
tensor([[[ 3],
[ 6],
[ 9]],
[[33],
[66],
[99]]])
这个例子讲个反向思路:
我们知道要提取3,6,9,3,33,66,99这6个数,
a的维度为 index的维度
2 2 #不是指定维度,保持不动
3 3 #不是指定维度,保持不动
3 1 #是指定维度,写入需要的索引值
假如我们要取 2,3,5,6,8,9,22,33,55,66,88,99这几个值
首先明确维度为2,那么index维度为2,3,[-2,-1]
index=[
[],[]
]
index=[
[[],[],[]],
[[],[],[]]
]
index=[
[[1,2],[1,2],[1,2]],
[[1,2],[1,2],[1,2]],
]
import torch
a=torch.tensor([[[1,2,3],
[4,5,6],
[7,8,9]
],
[[11,22,33],
[44,55,66],
[77,88,99]
]])
index=torch.tensor([
[[1,2],[1,2],[1,2]],
[[1,2],[1,2],[1,2]],
])
print(index.shape)
b=torch.gather(a,-1,index)
print(b)
输出
torch.Size([2, 3, 2])
tensor([[[ 2, 3],
[ 5, 6],
[ 8, 9]],
[[22, 33],
[55, 66],
[88, 99]]])
error:
gather_out_cuda():expected dtype int64 for index
解决办法:把index的数据类型转换为long。
例如:
temp=torch.ones(1,63,268,1)
temp2=temp.type(torch.long).to(device)
a=torch.gather(input,-1,temp2)