看了一堆花里胡哨的,还是通过官方的定义看的明白
输入一个矩阵
x = torch.arange(1,7).reshape(2,3)
Out[4]:
tensor([[1, 2, 3],
[4, 5, 6]])
x是二维的,两行三列
输入一个索引矩阵index(LongTensor不是Tensor),维度(dim)要和x一致
index = torch.LongTensor([[0,1,1]])
index
Out[6]: tensor([[0, 1, 1]])
index.shape
Out[7]: torch.Size([1, 3])
torch.gather的输出size和index是一致的,这里是1×3一行三列的一个矩阵
我们先写一个一行三列的矩阵三个值分别为[out00,out01,out02]
torch.gather(input,dim,index) input为输入矩阵,dim为维数,index为索引矩阵
这里根据维数和索引值去更换我需要的值,例如我dim=0,我就把这里的00,01,02的第一个0换成index输入的值
索引什么值呢,就是input的矩阵中对应的值,对应到x就是1,5,6,即gather的输出
同理可扩展至更高维度