tensor.scatter_(dim,index,src)理解

文章目录

前言

看到这样一行代码

label = torch.zeros_like(pred_label)
label.scatter_(1, batch_data.label.cuda().unsqueeze(dim=1), 1)

解释

假设这里的pred_label的大小为(32,18),32是batch_size,18代表此分类任务有18个类别。
我们创建的label大小也为(32,18),全0。
再假设这里的batch_data.label长下面的样子,它代表batch中,第一个样本的类别是0,第二个样本的类别是13…
在这里插入图片描述
然后我们可以看到,我们创建的label:
tensor.scatter_(dim,index,src)这三个参数在这里设置为

  • dim=1: 表示按照列进行填充
  • index=batch_data.label:表示把batch_data.label里面的元素值作为下标,去下标对应位置(这里的"对应位置"解释为列,如果dim=0,那就解释为行)进行填充
  • src=1:表示填充的元素值为1,src也可以是一个跟batch_data.label同样大小的tensor,具体可以看这篇文章

最后经过scatter_我们得到label长下面的样子:
在这里插入图片描述
所以应该知道: 当dim设置为1的时候,我们遍历label的每一行,然后去填充该行的一些列

  • 哪些列?——去batch_data.label的对应行找(所以这两个tensor的行数在dim=1时,需要相等;如果dim=0,则"需要变换的那个tensor"和“index”的列数需要相等)
  • 填充什么东西?——去src那个参数的地方找,如果src就是一个值,那填充的就是那个值,否则src就必须要是一个跟index参数大小一样的一个tensor。

这样我们就可以进一步把pred_label和label送入一个BCEloss function去计算loss了。

self.training_criterion = nn.BCELoss()
loss = self.training_criterion(pred_label, label)

参考

Pytorch scatter_ 理解轴的含义
BCE loss

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值