结论:使用方法
# gather,沿dim指定的轴收集值。
y_hat.gather(1, y.view(-1, 1))# y.view(-1, 1)会变成一列,y_hat的取y作为的索引的值
分步理解:先创建一个2*3的tensor
>>y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
tensor([[0.1000, 0.3000, 0.6000],
[0.3000, 0.2000, 0.5000]])
为了使用gather函数,我们得创建一个tensor作为gather得参数
>>y = torch.LongTensor([0, 2])
tensor([0, 2])
我们需要把y变个形状
>>y.view(-1, 1)
tensor([[0],
[2]])
先来看看使用得结果
>>y_hat.gather(1, y.view(-1, 1))
tensor([[0.1000],
[0.5000]])
图解: