gather torch_pytorch中的Torch.gather函数的含义

在动手学习深度学习中学到了一个函数gather,原文是说可以通过gather得到标签的预测概率。

y_hat = torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])

y = torch.LongTensor([0,2])

y_hat.gather(1,y.view(-1,1))

tensor([[0.1000],

[0.5000]])

开始我看到这个输出一头雾水 不知道怎么回事

查了查 gather的时候我才知道

torch.gather(input,dim,index,out=None)

example:

t = torch.Tensor([1,2],[3,4])

torch.gather(t,1,torchLongTensor([[0,0],[1,0]]))

1,1

4,3

可以看出gather的作用是根据索引返回该项元素,首先先输入一个Tensor 然后根据dim进行判断是是行的还是列的,当dim=0 时候竖行查找,当dim=1的时候是横向查找

上题中,dim=1,那么索引就是列号。index的大小就是输出的大小,比如index是[1,0;0,0]其实就是第一行的第二个元素和第一个元素,第二行的第一个元素也就是返回的是2,1 3,3

所以例子中是[0,0],[1,0] 返回的就是[1,1],[4,3]

在例题中的他是通过view函数来返回index的,开始不知道view的意思,查过后知道了,他实际上和resize的意思差不多。

a = torch.Tensor([[1,2,3],[4,5,6]])

b = torch.Tensor([1,2,3,4,5,6])

print(a.view(1,6))

print(b.view(1,6))

得到的都是

tensor([[1,2,3,4,5,6]])

再看一个例子

a = torch.Tensor([[1,2,3],[4,5,6]])

print(a.view(3,2))

将会得到

tensor([[1,2],

[3,4],

[5,6]

])

相当于就是从1,2,3,4,5,6 顺序的拿数组来填充需要的形状。

参数中的-1就代表这个位置由其他位置的数字来进行推断,只要不在歧义的情况下,view参数就可以推断出来,也就是人可以推断出形状的情况下,view也是可以推断出来的,比如a tensor的数据个数是6个,如果view(1,-1)我们就可以推断出来-1代表6。而如果view(-1,-1,2)的话,人也不知道的话,机器也不会知道的,所以就会报错

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值