1. The Paper
Title: Training Region-based Object Detectors with Online Hard Example Mining
Paper: https://arxiv.org/pdf/1604.03540.pdf
Although the paper is a few years old, it is still worth reading, and the idea carries over to other vision tasks.
2. Core Idea
Core idea: the paper proposes an online hard example mining method; in detection the hard examples are mostly hard negatives, i.e. negative samples that the model has trouble telling apart from positives. As training goes on, the model usually assigns high confidence to positive samples, but it inevitably stays uncertain about some negatives and gives them a confidence that is not all that close to 0. Hard negative mining means finding exactly these negatives and then training on them specifically. OHEM is an online solution for this kind of mining, and using this trick gives a measurable improvement in detection accuracy.
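To make "hard negative" concrete, here is a minimal sketch with made-up logits for a two-class case: an easy negative gets a near-zero cross-entropy loss, while a negative that the model scores close to the object class gets a large loss, and that is exactly the kind of sample hard example mining keeps.

import torch
import torch.nn.functional as F

# Two negative samples (true class 0) for a hypothetical 2-class classifier.
# The logits are made up purely for illustration.
logits = torch.tensor([[4.0, -4.0],    # easy negative: the model is almost sure it is background
                       [0.2,  0.4]])   # hard negative: the model leans towards "object"
targets = torch.tensor([0, 0])

per_sample_loss = F.cross_entropy(logits, targets, reduction='none')
print(per_sample_loss)  # roughly [0.0003, 0.80] -> the second sample is the hard negative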
Implementation: sort the region proposals (RoIs) by loss and mask out the ones with very low loss. An RoI that still has a high loss after many rounds of training is taken to be a hard example. The "online" part simply means the mining happens inside each training iteration: compute the per-RoI loss -> select the high-loss RoIs -> use them as the hard examples for the update.
Code implementation:
import torch
import torch.nn.functional as F


# The original snippet calls smooth_l1_loss without defining it; the Detectron-style
# variant with a sigma parameter is assumed here so that the example runs. With
# reduce=False it returns one loss value per RoI (summed over the 4 box coordinates).
def smooth_l1_loss(pred, target, sigma=1.0, reduce=True):
    sigma2 = sigma ** 2
    diff = pred - target
    abs_diff = diff.abs()
    flag = (abs_diff < 1.0 / sigma2).float()
    loss = flag * 0.5 * sigma2 * diff.pow(2) + (1 - flag) * (abs_diff - 0.5 / sigma2)
    loss = loss.sum(dim=1)
    return loss.mean() if reduce else loss


def ohem_loss(
    batch_size, cls_pred, cls_target, loc_pred, loc_target, smooth_l1_sigma=1.0
):
    """
    Arguments:
        batch_size (int): number of RoIs to keep for bbox head training
        cls_pred (FloatTensor): [R, C], classification logits
        cls_target (LongTensor): [R], classification labels
        loc_pred (FloatTensor): [R, 4], predicted box offsets
        loc_target (FloatTensor): [R, 4], target box offsets
    Returns:
        cls_loss, loc_loss (FloatTensor)
    """
    # Per-RoI classification and localization losses, with no reduction yet.
    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target, sigma=smooth_l1_sigma, reduce=False)
    # Sum the two terms to get a single loss value per RoI.
    loss = ohem_cls_loss + ohem_loc_loss
    # Sort the per-RoI losses in descending order: the hardest examples come first.
    sorted_ohem_loss, idx = torch.sort(loss, descending=True)
    # Keep at most `batch_size` RoIs.
    keep_num = min(sorted_ohem_loss.size(0), batch_size)
    if keep_num < sorted_ohem_loss.size(0):
        # There are more RoIs than the budget: keep only the `keep_num` hardest ones,
        # using the same indices for both the classification and localization losses.
        keep_idx_cuda = idx[:keep_num]
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
        ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]
    # Average the kept classification and localization losses separately.
    cls_loss = ohem_cls_loss.sum() / keep_num
    loc_loss = ohem_loc_loss.sum() / keep_num
    return cls_loss, loc_loss
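A minimal smoke test for the function above, just to show the expected call signature; the tensor shapes, the 21 classes, and the RoI budget of 128 are arbitrary assumptions for illustration:

# Hypothetical inputs: 512 candidate RoIs, 21 classes; keep only the 128 hardest.
R, C = 512, 21
cls_pred = torch.randn(R, C)
cls_target = torch.randint(0, C, (R,))
loc_pred = torch.randn(R, 4)
loc_target = torch.randn(R, 4)

cls_loss, loc_loss = ohem_loss(128, cls_pred, cls_target, loc_pred, loc_target)
total_loss = cls_loss + loc_loss  # backpropagate this total as usual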
Applying OHEM to a classification task:
def ohem_loss(rate, cls_pred, cls_target):
    # rate: fraction of the batch to keep as hard examples
    # cls_pred: [N, C] classification logits; cls_target: [N] labels
    batch_size = cls_pred.size(0)
    # Per-sample cross-entropy loss, with no reduction yet.
    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
    # Sort in descending order and keep only the hardest `rate` fraction of the batch.
    sorted_ohem_loss, idx = torch.sort(ohem_cls_loss, descending=True)
    keep_num = min(sorted_ohem_loss.size(0), int(batch_size * rate))
    if keep_num < sorted_ohem_loss.size(0):
        keep_idx_cuda = idx[:keep_num]
        ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
    cls_loss = ohem_cls_loss.sum() / keep_num
    return cls_loss
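A sketch of how the classification version could sit inside an ordinary training step; the linear model, the SGD optimizer, and the keep rate of 0.7 are placeholders rather than recommendations:

# Hypothetical training step: keep the hardest 70% of samples in each batch.
model = torch.nn.Linear(128, 10)       # placeholder classifier
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

features = torch.randn(32, 128)
labels = torch.randint(0, 10, (32,))

logits = model(features)
loss = ohem_loss(0.7, logits, labels)  # OHEM-filtered cross-entropy
optimizer.zero_grad()
loss.backward()
optimizer.step()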