>>> import torch
>>> a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])  # float32 input
>>> a
tensor([[1., 2., 3.],
        [4., 5., 6.]])
>>> # dim=1: b[i][j] = a[i][index[i][j]]; index must be int64 and b takes index's shape
>>> b = torch.gather(a, 1, torch.tensor([[1, 1, 1], [0, 0, 0]]))
>>> b
tensor([[2., 2., 2.],
        [4., 4., 4.]])
>>> b = torch.gather(a, 1, torch.tensor([[0, 0, 2], [1, 1, 0]]))
>>> b
tensor([[1., 1., 3.],
        [5., 5., 4.]])
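The rule behind these outputs: for a 2-D input, `dim=1` gives `out[i][j] = a[i][index[i][j]]` (the index replaces the column), while `dim=0` gives `out[i][j] = a[index[i][j]][j]` (the index replaces the row). The result always has the shape of `index`. Below is a minimal plain-Python sketch of that rule for 2-D nested lists only; `gather2d` is a hypothetical helper written for illustration and is not part of PyTorch.

```python
def gather2d(a, dim, index):
    """Emulate torch.gather for 2-D nested lists (illustrative sketch only)."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D input")
    out = []
    for i, row in enumerate(index):
        out_row = []
        for j, k in enumerate(row):
            # dim=1: k picks the column within row i
            # dim=0: k picks the row within column j
            out_row.append(a[i][k] if dim == 1 else a[k][j])
        out.append(out_row)
    return out  # same shape as index


a = [[1., 2., 3.], [4., 5., 6.]]
print(gather2d(a, 1, [[1, 1, 1], [0, 0, 0]]))  # [[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]
print(gather2d(a, 1, [[0, 0, 2], [1, 1, 0]]))  # [[1.0, 1.0, 3.0], [5.0, 5.0, 4.0]]
print(gather2d(a, 0, [[1, 0, 1]]))             # [[4.0, 2.0, 6.0]]
```

The last call shows that `index` need not match the input's shape: a single index row with `dim=0` picks one element per column.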
PyTorch gather()