深入理解PyTorch中的gather函数

gather函数

今天在用PyTorch复现softmax的时候,参考的书籍为《Dive into DL Pytorch》。在书里面,关于gather函数原文是这样叙述的:

上一节中,我们介绍了softmax回归使用的交叉熵损失函数。为了得到标签的预测概率,我们可以使用gather函数。在下面的例子中,变量y_hat是2个样本在3个类别的预测概率,变量y是这2个样本的标签类别。通过使用gather函数,我们得到了2个样本的标签的预测概率。在代码中,标签类别的离散值是从0开始逐一递增的。

y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))

输出为:

tensor([[0.1000],
        [0.5000]])

看到这里的时候,不太能理解两点:

  1. torch.LongTensor是什么样的数据格式;
  2. gather函数到底在中间有什么样的作用。
1. torch.LongTensor
torch.LongTensor(2, 3) # 构建一个2 * 3 Long类型的张量

所以torch.LongTensor指示的数据类型为张量,但里面的元素为Long类型

2. gather函数的作用

很显然在这段代码里面:gather函数的第一个参数’1‘,指定的是dim,即维度,也就是对哪个维度进行操作,此处为对the first dim(或者说第1轴)进行操作。第二个参数y.view(-1, 1), 其展开应该为:

torch.LongTensor([[0],
 				 [2]])

这里 y.view(-1, 1) 里面的元素应该是作为index也即索引的意思。所以这句代码的理解就是对 y_hat 在第一轴上根据y.view(-1, 1)提供的索引进行取值,最后得到的输出就是:

tensor([[0.1000],
        [0.5000]])
  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值