分类问题常用的几种损失,记录下来备忘,后续不断完善。
nn.CrossEntropyLoss()交叉熵损失
常用于多分类问题
CE = nn.CrossEntropyLoss()
loss = CE(input,target)
Input: (N, C) , dtype: float, N是样本数量,在批次计算时通常就是batch_size
target: (N), dtype: long,是类别号,0 ≤ targets[i] ≤ C−1
pytorch中的交叉熵损失就是softmax和NLL损失的组合,即
nn.CrossEntropyLoss()(input,target) == nn.NLLLoss()(torch.log(nn.Softmax()(input)),target)
nn.NLLLoss()
NLL = nn.NLLLoss()
loss = NLL(input,target)
Input: (N, C) , dtype: float, N是样本数量,在批次计算时通常就是batch_size
target: (N), dtype: long,是类别号,0 ≤ targets[i] ≤ C−1
nn.BCELoss() 二元交叉熵损失
常用于二分类或多标签分类
BCE = nn.BCELoss()
loss = BCE(input,target)
Input: (N, x) , dtype: float, N是样本数量,在批次计算时通常就是batch_size,x是标签数
target: (N, x), dtype: float,通常是标签的独热码形式,注意需改成float格式
nn.BCEWithLogitsLoss()
相当于BCE加上sigmoid
nn.BCEWithLogitsLoss()(input,target) == nn.BCELoss()(torch.sigmoid(input),target)
focal_loss
focal loss在pytorch中没有,它常用在目标检测问题中,公式和曲线见论文中的图:
带平衡参数的focal loss公式如下:
代码:(待后补)
heatmap_loss
heatmap_loss出现在anchor-free的目标检测网络centernet和conernet中,它在focal loss的基础上进一步改进,加入了对热点区域的损失减小的措施,以使模型输出可以较容易的收敛到检测点附件区域。(否则,必须收敛到检测点的话,难度太大,收敛速度慢)
注意,它只是在otherwise情况下多加了一个
(
1
−
Y
x
y
c
)
β
(1-Y_{xyc})^\beta
(1−Yxyc)β 除此之外,就是focal loss