OHEM,Focal loss,GHM loss二分类pytorch代码实现(减轻难易样本不均衡问题)

https://mp.weixin.qq.com/s/iOAICJege2b0pCVxPkvNiA
综述:解决目标检测中的样本不均衡问题
该综述主要介绍了OHEM,Focal loss,GHM loss;由于我这的二分类数据集不存在正负样本不均衡的问题,所以着重看了处理难易样本不均衡(正常情况下,容易的样本较多,困难的样本较少);由于我只是分类问题,所以写了各种分类的loss,且网络的最后一层为softmax,所以网络输出的pred是softmax层前的logits经过softmax后的结果,普通的交叉熵损失即为sum(-gt*log(pred)),但torch.nn.CrossEntropyLoss()中会对于输入的pred再进行一次softmax,所以这里使用torch.nn.NLLLoss代替,当然经测试,即使网络最后一层使用softmax损失函数还是使用torch.nn.CrossEntropyLoss(),效果和使用torch.nn.NLLLoss差不多。。。

OHEM:
代码参考:https://www.codeleading.com/article/7442852142/

def ohem_loss(pred, target, keep_num):
    loss = torch.nn.NLLLoss(reduce=False)(torch.log(pred), target)
    print(loss)
    loss_sorted, idx = torch.sort(loss, descending=True)
    loss_keep = loss_sorted[:keep_num]
    return loss_keep.sum() / keep_num

Focal loss:
详解:原论文Focal Loss for Dense Object Detection
代码参考:https://zhuanlan.zhihu.com/p/80594704

def focal_loss(pred,target,gamma=0.5):
    pred_temp=pred.
  • 7
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
GHM loss是一种用于解决目标类别分布不均衡问题的损失函数。在PyTorch中可以通过以下代码实现GHM loss: ```python import torch class GHMLoss(torch.nn.Module): def __init__(self, bins=10, alpha=0.75): super(GHMLoss, self).__init__() self.bins = bins self.alpha = alpha self.edges = [x / bins for x in range(bins + 1)] self.edges[-1] += 1e-6 def forward(self, input, target): N, C = input.size() grad_input = input.clone().detach() grad_input.zero_() target = target.view(-1, 1) edges = self.edges inds = (torch.arange(1, self.bins + 1).float() / self.bins).to(input.device) weights = torch.zeros((self.bins,)).to(input.device) weights[0] = inds[0] weights[1:] = inds[1:] - inds[:-1] inds = (target * self.bins).long().clamp(0, self.bins - 1) weights = weights[inds.view(-1)] Ns = torch.zeros((self.bins,)).to(input.device) for i in range(self.bins): Ns[i] = ((inds == i).sum()).float() Ns[Ns == 0] = float('inf') weights = (weights * Ns).sqrt() weights = (weights / weights.sum()) * self.bins inds = torch.bucketize(input.softmax(dim=1)[:, 0], edges) g = -(target - input.softmax(dim=1)[:, 0]).detach().abs() grad_input[:, 0] = g / (2 * g.abs().mean() + 1e-8) grad_input[:, 0] *= weights[inds.view(-1)].view(N, 1) return torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=None, reduction='mean', pos_weight=None, label_smoothing=None) + grad_input.sum() * self.alpha / N ``` 其中,`bins`表示将概率分布分成的区间数量,`alpha`为平衡交叉熵损失和GHM损失的权重。在`forward`函数中,首先计算每个样本的概率分布落在哪个区间,并根据该区间的样本数量和梯度权重计算出每个样本的权重。然后,根据权重计算GHM损失,并计算交叉熵损失和GHM损失的加权和。最后,将GHM损失的梯度乘以`alpha`并加入到交叉熵损失的梯度中,返回总的损失值。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值