gather
torch.gather(input, dim, index, out=None) → Tensor
沿给定轴dim, 将输入索引张量index
指定位置的值进行聚合,即,指定维度的大小
会发生变化。
对于一个3维张量,输出可以定义为:
out[i][j][k] = tensor[index[i][j][k]][j][k] # dim=0
out[i][j][k] = tensor[i][index[i][j][k]][k] # dim=1
out[i][j][k] = tensor[i][j][index[i][j][k]] # dim=3
举个例子:
import torch
t = torch.Tensor([[1,2],[3,4]])
#t_gather在dim=1的维度为2,不变
t_gather = torch.gather(t, <