- 传送门:
- 相关OHEM的介绍:检测模型改进—OHEM与Focal-Loss算法总结
- 代码地址:OHEM
1. 前言
有关OHEM的介绍请参考上面给出的链接,这里主要就OHEM是怎么运行的做一些简单的分析,整个OHEM的代码也不是很多,这里将算法的步骤归纳为:
1)计算检测器的损失,这部分是使用和最后fc6、fc7预测头一样的共享参数,预测分类与边界框回归的结果,将预测的结果与GT进行比较得到分类和边界框回归的loss,这里的损失是将两种损失相加得到的;
2)使用阈值为0.7的NMS预先处理一遍检测框,去除一些无效的检测框;
3)NMS之后的检测框按照loss由大到小排列,选取一定数目(由两个数取最小决定)的边界框返回。
下面是OHEM在网络定义文件中的定义,方便后面查看相关代码的时候查找对应条目。
layer {
name: "hard_roi_mining"
type: "Python"
bottom: "cls_prob_readonly"
bottom: "bbox_pred_readonly"
bottom: "rois"
bottom: "labels"
bottom: "bbox_targets"
bottom: "bbox_inside_weights"
bottom: "bbox_outside_weights"
top: "rois_hard"
top: "labels_hard"
top: "bbox_targets_hard"
top: "bbox_inside_weights_hard"
top: "bbox_outside_weights_hard"
propagate_down: false
propagate_down: false
propagate_down: false
propagate_down: false
propagate_down: false
propagate_down: false
propagate_down: false
python_param {
module: "roi_data_layer.layer"
layer: "OHEMDataLayer"
param_str: "'num_classes': 6" #6
}
}
2. OHEM代码简单梳理
2.1 OHEMDataLayer
class OHEMDataLayer(caffe.Layer):
"""Online Hard-example Mining Layer."""
def setup(self, bottom, top):
"""Setup the OHEMDataLayer."""
# parse the layer parameter string, which must be valid YAML
layer_params = yaml.load(self.param_str_)
self._num_classes = layer_params['num_classes'] # 获取分类数目
self._name_to_bottom_map = {
# 将bottom的blob名称与index使用dict关联
'cls_prob_readonly': 0,
'bbox_pred_readonly': 1,
'rois': 2,
'labels': 3}
if cfg.TRAIN.BBOX_REG: # 有边界框回归
self._name_to_bottom_map['bbox_targets'] = 4
self._name_to_bottom_map['bbox_loss_weights'] = 5
self._name_to_top_map = {
} # 同理top的blob名称也要与index关联起来
……
# 前向传播函数
def forward(self