pytorch中二分类的lossf函数使用

多分类不需要满足y_pred和y_gt的维度相同
在这里插入图片描述
二分类中,torch.nn.BECLosstorch,nn.BCEWithLogitsLoss需要预测输出y_pred和y_gt需要一样的维度
故需要使用一下两种转变y_gt的方法。
假设y_pred如下:

[[1.2 2.3 ],
[2.1 2.2 ]]

其中,y_gt如下:

[[0],[1]]

目标是是转化成下面的维度:

[[1,0],
[0,1]]

代码如下:

  1. 直接使用list的index属性。
gt_y_temp = torch.zeros(gt_y.shape[0], 2)
gt_y_temp[range(gt_y.shape[0]), list(gt_y.squeeze(1).int())] = 1
gt_y_temp=gt_y_temp.cuda()
  1. 使用scatter来使用
gt_y_temp = torch.zeros(gt_y.shape[0], 2, device='cuda').scatter_(1, gt_y.long(), torch.tensor(1, dtype=torch.float)).cuda()

注意这个地方gt_y的特征维度最后一个维度是1.比如:[[0],[1],[0]]

更多案例

>>>class_num = 10
>>>batch_size = 4
>>>label = torch.LongTensor(batch_size, 1).random_() % class_num
 3
 0
 0
 8

>>>one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
    0     0     0     1     0     0     0     0     0     0
    1     0     0     0     0     0     0     0     0     0
    1     0     0     0     0     0     0     0     0     0
    0     0     0     0     0     0     0     0     1     0

参考博客:
pytorch中scatter的使用原理
三种one_hot的办法
scatter使用方法案例
How is Pytorch’s binary_cross_entropy_with_logits function related to sigmoid and binary_cross_entropy

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值