在线难例挖掘:Online Hard Example Mining (OHEM)

详细链接:https://erogol.com/online-hard-example-mining-pytorch/

OHEM通过减少计算成本来选择难例,提高网络性能。它主要用于目标检测。假设你想训练一个汽车检测器,并且你有正样本图像(图像中有汽车)和负样本图像(图像中没有汽车)。现在你想训练你的网络。实际上,你会发现负样本数量会远远大于正样本数量,而且大量的负样本是比较简单的。因此,比较明智的做法是选择一部分对网络最有帮助的负样本(比较有难度的,容易被识别为正样本的)参与训练。难例挖掘就是用于选择对网络最有帮助的负样本的。

通常来说,通过对网络训练进行一定的迭代后得到临时模型,使用临时模型对所有的负样本进行测试,便可以发现那些loss很大的负样本实例,这些实例就是所谓的难例。但是这种查找难例的方法,需要很大的计算量,因为负样本图像可能会很多;另外这一方法可能是次优的,当你进行难例挖掘的时候,模型的权重是固定的,当前权重下的难例未必适用于接下了的迭代(这句话不太理解)。也就是说,这里假设你选择的所有难例负样本对下一次迭代都是有用的,直到下一次难例选择。这是一个不完美的假设,尤其是对于大型数据集而言。

OHEM通过批量难例选择选择来解决上述两个问题。给定batch-size K,前向传播保持不变,计算损失。然后,选择M(M<K)个高损失值的实例,仅使用这M个实例的损失进行反向传播。

OHEM的具体pytorch实现代码如下:

import torch as th                                                                 
                                                                                   
                                                                                   
class NLL_OHEM(th.nn.NLLLoss):                                                     
    """ Online hard example mining. 
    Needs input from nn.LogSotmax() """                                             
                                                                                   
    def __init__(self, ratio):      
        super(NLL_OHEM, self).__init__(None, True)                                 
        self.ratio = ratio                                                         
                                                                                   
    def forward(self, x, y, ratio=None):                                           
        if ratio is not None:                                                      
            self.ratio = ratio                                                     
        num_inst = x.size(0)                                                       
        num_hns = int(self.ratio * num_inst)                                       
        x_ = x.clone()                                                             
        inst_losses = th.autograd.Variable(th.zeros(num_inst)).cuda()              
        for idx, label in enumerate(y.data):                                       
            inst_losses[idx] = -x_.data[idx, label]                                 
        #loss_incs = -x_.sum(1)                                                    
        _, idxs = inst_losses.topk(num_hns)                                        
        x_hn = x.index_select(0, idxs)                                             
        y_hn = y.index_select(0, idxs)                                             
        return th.nn.functional.nll_loss(x_hn, y_hn)     

  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值