在线难例挖掘（OHEM，Online Hard Example Mining）
论文：Training Region-based Object Detectors with Online Hard Example Mining
源码链接：https://github.com/abhi2610/ohem
该实现基于 Fast R-CNN，使用的深度学习框架为 Caffe。
OHEM 训练样本的生成位于 https://github.com/abhi2610/ohem/blob/master/lib/roi_data_layer/minibatch.py 中的 get_ohem_minibatch 函数：
def get_ohem_minibatch(loss, rois, labels, bbox_targets=None,
                       bbox_inside_weights=None, bbox_outside_weights=None):
    """Given rois and their loss, construct a minibatch using OHEM.

    Online Hard Example Mining keeps only the highest-loss RoIs. When
    ``cfg.TRAIN.OHEM_USE_NMS`` is set, the RoIs are first de-duplicated by
    running NMS separately for every (source image, label) pair, using each
    RoI's loss as its NMS score, and the hard examples are then chosen among
    the survivors.

    Args:
        loss: per-RoI loss values (one per row of ``rois``); converted with
            ``np.array``.
        rois: 2-D array; column 0 is the source image id within the batch,
            columns ``1:`` are the box coordinates fed to NMS.
        labels: ndarray of per-RoI class labels, aligned with ``rois``.
        bbox_targets: optional per-RoI regression targets (2-D, row-aligned
            with ``rois``). When given, ``bbox_inside_weights`` and
            ``bbox_outside_weights`` must be given as well, and
            ``cfg.TRAIN.BBOX_REG`` must be enabled.
        bbox_inside_weights: optional 2-D per-RoI weights, see above.
        bbox_outside_weights: optional 2-D per-RoI weights, see above.

    Returns:
        dict with copies of the selected rows under ``'rois_hard'`` and
        ``'labels_hard'``, plus ``'bbox_targets_hard'``,
        ``'bbox_inside_weights_hard'`` and ``'bbox_outside_weights_hard'``
        when ``bbox_targets`` is not None.

    Raises:
        AssertionError: if ``bbox_targets`` is given while
            ``cfg.TRAIN.BBOX_REG`` is disabled.
    """
    loss = np.array(loss)

    if cfg.TRAIN.OHEM_USE_NMS:
        # Do NMS using loss for de-dup and diversity.
        keep_inds = []
        nms_thresh = cfg.TRAIN.OHEM_NMS_THRESH
        # Keep the image ids as an ndarray so `source_img_ids == img_id`
        # below is an explicit element-wise comparison (the original list
        # only worked via numpy's reflected __eq__ on the scalar img_id).
        source_img_ids = rois[:, 0]
        for img_id in np.unique(source_img_ids):
            for label in np.unique(labels):
                sel_indx = np.where(np.logical_and(labels == label,
                                                   source_img_ids == img_id))[0]
                if not len(sel_indx):
                    continue
                # NMS input rows: box coordinates followed by the loss as the
                # score column. NOTE(review): assumes the project's nms()
                # expects (coords..., score) rows and returns the indices of
                # the kept rows — confirm against its implementation.
                boxes = np.concatenate(
                    (rois[sel_indx, 1:], loss[sel_indx][:, np.newaxis]),
                    axis=1).astype(np.float32)
                keep_inds.extend(sel_indx[nms(boxes, nms_thresh)])
        # Force an integer dtype: np.array([]) would be float64, and indexing
        # with an empty float array raises IndexError.
        keep_inds = np.array(keep_inds, dtype=np.intp)
        # select_hard_examples returns positions into the loss vector it is
        # given, so map them back to indices into the full `rois` array.
        hard_keep_inds = select_hard_examples(loss[keep_inds])
        hard_inds = keep_inds[hard_keep_inds]
    else:
        hard_inds = select_hard_examples(loss)

    blobs = {'rois_hard': rois[hard_inds, :].copy(),
             'labels_hard': labels[hard_inds].copy()}

    if bbox_targets is not None:
        assert cfg.TRAIN.BBOX_REG, \
            'bbox_targets were given but cfg.TRAIN.BBOX_REG is disabled'
        blobs['bbox_targets_hard'] = bbox_targets[hard_inds, :].copy()
        blobs['bbox_inside_weights_hard'] = \
            bbox_inside_weights[hard_inds, :].copy()
        blobs['bbox_outside_weights_hard'] = \
            bbox_outside_weights[hard_inds, :].copy()

    return blobs
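简单来说：输入当前 batch 中每个 RoI 的 loss，挑选出 loss 最大的若干样本（数量由 select_hard_examples 决定，即指定的 batch size）作为难例。若开启 cfg.TRAIN.OHEM_USE_NMS，则先按“图像 + 类别”分组、以 loss 作为分数做 NMS 去重，再从保留下来的 RoI 中挑选难例。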