论文:Training Region-based Object Detectors with Online Hard Example Mining (CVPR 2016)
OHEM算法思想
目标检测中的目标框和背景框之间存在严重的不平衡
Fast RCNN 中 Rois Sampling 操作时就是一种平衡正负样本的手段,将正负样本比例调整到 1:3 左右训练效果较好。
OHEM其实就是一种挑选Rois的方法,它的挑选依据是Rois的Loss,挑选出Loss大的hard examples使得网络的训练更有针对性
网络结构
Fast RCNN
Fast RCNN + OHEM
最容易想到的方法直接在损失层挑选出hard example,将各个Rois的Loss由大到小排序,选择Loss大的Rois,其余的Roi Loss置为0,然后再进行反向传播,但这样每个Roi仍需要反向传播,训练效率低
所以文中用的另一种方法如下图:相当于加了一个只向前传播的副RoI Net,副RoI Net对所有的RoIs计算loss,然后挑选出困难样本送入常规的RoI Net中正常训练,副RoI Net 共享常规RoI Net 的参数
注意:挑选hard example时需要用NMS去除重合率较大的ROI,避免同一区域重复计算损失
OHEM适合于batch size(images)较少,但每张 image 的 examples 很多的情况
ohem_loss
def ohem_loss(
batch_size, cls_pred, cls_target, loc_pred, loc_target, smooth_l1_sigma=1.0
):
"""
Arguments:
batch_size (int): number of sampled rois for bbox head training
loc_pred (FloatTensor): [R, 4], location of positive rois
loc_target (FloatTensor): [R, 4], location of positive rois
pos_mask (FloatTensor): [R], binary mask for sampled positive rois
cls_pred (FloatTensor): [R, C]
cls_target (LongTensor): [R]
Returns:
cls_loss, loc_loss (FloatTensor)
"""
ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)
ohem_loc_loss = smooth_l1_loss(loc_pred, loc_target, sigma=smooth_l1_sigma, reduce=False)
#这里先暂存下正常的分类loss和回归loss
loss = ohem_cls_loss + ohem_loc_loss
#然后对分类和回归loss求和
sorted_ohem_loss, idx = torch.sort(loss, descending=True)
#再对loss进行降序排列
keep_num = min(sorted_ohem_loss.size()[0], batch_size)
#得到需要保留的loss数量
if keep_num < sorted_ohem_loss.size()[0]:
#这句的作用是如果保留数目小于现有loss总数,则进行筛选保留,否则全部保留
keep_idx_cuda = idx[:keep_num]
#保留到需要keep的数目
ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
ohem_loc_loss = ohem_loc_loss[keep_idx_cuda]
#分类和回归保留相同的数目
cls_loss = ohem_cls_loss.sum() / keep_num
loc_loss = ohem_loc_loss.sum() / keep_num
#然后分别对分类和回归loss求均值
return cls_loss, loc_loss
实验
参考文献
【1】OHEM论文解读
【2】OHEM (CVPR, 2016)
【3】Hard Negative Mining/OHEM 你真的知道二者的区别吗?
【4】目标检测-Training with Online Hard Example Mining