https://mp.weixin.qq.com/s/iOAICJege2b0pCVxPkvNiA
综述:解决目标检测中的样本不均衡问题
该综述主要介绍了OHEM,Focal loss,GHM loss;由于我这的二分类数据集不存在正负样本不均衡的问题,所以着重看了处理难易样本不均衡(正常情况下,容易的样本较多,困难的样本较少);由于我只是分类问题,所以写了各种分类的loss,且网络的最后一层为softmax,所以网络输出的pred是softmax层前的logits经过softmax后的结果,普通的交叉熵损失即为sum(-gt*log(pred)),但torch.nn.CrossEntropyLoss()中会对于输入的pred再进行一次softmax,所以这里使用torch.nn.NLLLoss代替,当然经测试,即使网络最后一层使用softmax损失函数还是使用torch.nn.CrossEntropyLoss(),效果和使用torch.nn.NLLLoss差不多。。。
OHEM:
代码参考:https://www.codeleading.com/article/7442852142/
def ohem_loss(pred, target, keep_num):
loss = torch.nn.NLLLoss(reduce=False)(torch.log(pred), target)
print(loss)
loss_sorted, idx = torch.sort(loss, descending=True)
loss_keep = loss_sorted[:keep_num]
return loss_keep.sum() / keep_num
Focal loss:
详解:原论文Focal Loss for Dense Object Detection
代码参考:https://zhuanlan.zhihu.com/p/80594704
def focal_loss(pred,target,gamma=0.5):
pred_temp=pred.