多分类不需要满足y_pred和y_gt的维度相同
二分类中,torch.nn.BECLoss
和torch,nn.BCEWithLogitsLoss
需要预测输出y_pred和y_gt需要一样的维度。
故需要使用一下两种转变y_gt的方法。
假设y_pred如下:
[[1.2 2.3 ],
[2.1 2.2 ]]
其中,y_gt如下:
[[0],[1]]
目标是是转化成下面的维度:
[[1,0],
[0,1]]
代码如下:
- 直接使用list的index属性。
gt_y_temp = torch.zeros(gt_y.shape[0], 2)
gt_y_temp[range(gt_y.shape[0]), list(gt_y.squeeze(1).int())] = 1
gt_y_temp=gt_y_temp.cuda()
- 使用scatter来使用
gt_y_temp = torch.zeros(gt_y.shape[0], 2, device='cuda').scatter_(1, gt_y.long(), torch.tensor(1, dtype=torch.float)).cuda()
注意这个地方gt_y的特征维度最后一个维度是1.比如:[[0],[1],[0]]
更多案例
>>>class_num = 10
>>>batch_size = 4
>>>label = torch.LongTensor(batch_size, 1).random_() % class_num
3
0
0
8
>>>one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
0 0 0 1 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 1 0
参考博客:
pytorch中scatter的使用原理
三种one_hot的办法
scatter使用方法案例
How is Pytorch’s binary_cross_entropy_with_logits function related to sigmoid and binary_cross_entropy