从完整数据中按索引取值
例子一:
torch.gather(t, 0, torch.tensor([[1, 0, 2], [1,2, 1]]))
dim=0
index [1, 0, 2], [1,2, 1]
则取的值的索引是将index放在0维上
(1,_) (0, _) (2,_ );
(1, _ ) (2,_ ) (1, _)
然后_每行自动填充为0,1,2得到下标
(1,0) (0, 1) (2,2 );
(1, 0) (2,1) (1, 2)
从t中,取出下标值
(1,0) (0, 1) (2,2 );
3 2 -5
(1, 0) (2,1) (1, 2)
3 34 0
例子二:
torch.gather(t,1, torch.tensor([[1, 0, 2], [1,2, 1]]))
dim=1
index [1, 0, 2], [1,2, 1]
则取的值的索引是将index放在1维上
(_,1) (_,0) (_,2);
(_,1 ) (_,2) (_,1)
然后 _ 按列每列自动填充为0,1;或者按行第一行填0;第二行填1;第三行填2.。。。
从t中,取出下标
(0, 1) (0, 0) (0, 2);
2 1 -3
(1, 1 ) (1, 2) (1, 1)
4 0 4
的值。