0.先搞清楚几种分类问题
图片分类问题中通常一张图片中同时有多个目标,所以会有多个标签,通常分类问题可以如下划分
样本单标签 | 样本多标签 | |
类别数量=2 | 简单二分类 | 当作多个二分类 |
类别数量>2 | 多分类 | 当作多个二分类 |
1.pytorch常用loss总结
pytorch中常用的几种loss:BCELoss、BCEWithLogitsLoss、NLLLoss、CrossEntropyLoss
1.1 NLLLoss和CrossEntropyLoss
总结一下:CrossEntropyLoss(x, label) = softmax(x) + log(x) + NLLLoss(x, label)
应用在表格1中的多分类问题中
1.2 BCELoss和BCEWithLogitsLoss
BCEWithLogitsLoss(x, label)= sigmoid(x) + BCELoss(x, label)
应用在表格1中的所有二分类问题
2. softlabel的loss实现
2.1 二分类问题
通过BCEWithLogitsLoss可以直接实现softlabel的loss
2.2 多分类问题
因为CrossEntropyLoss和NLLLoss都是默认hardlabel实现的,所以:
根据交叉熵的公式 target * log(p)
import torch
import torch.nn.functional as F
def SoftCrossEntropy(inputs, target, reduction='sum'):
log_likelihood = -F.log_softmax(inputs, dim=1)
batch = inputs.shape[0]
if reduction == 'average':
loss = torch.sum(torch.mul(log_likelihood, target)) / batch
else:
loss = torch.sum(torch.mul(log_likelihood, target))
return loss