torch.gather 理解
查看如下代码
b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print torch.gather(b, dim=1, index=index_1)
print torch.gather(b, dim=0, index=index_2)
输出结果:
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
1 2
6 4
[torch.FloatTensor of size 2x2]
1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
有观察可知,输出的size和index的size是一样大小的。
就是说将index放在那,然后对index里面的值进行替换。
如果dim=0,则要替换的值从输入数据的对应列里面得到,索引就是index里面的值。如果dim=1,则要替换的值从输入数据的对应行里面得到,索引就是index里面的值。