SSD 算法 detection_evaluate_layer 解读

代码位置
caffe/include/caffe/layers/detection_evaluate_layer.hpp

#ifndef CAFFE_DETECTION_EVALUATE_LAYER_HPP_
#define CAFFE_DETECTION_EVALUATE_LAYER_HPP_

#include <utility>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/**
 * @brief Generate the detection evaluation based on DetectionOutputLayer and
 * ground truth bounding box labels.
 *
 * Intended for use with MultiBox detection method.
 *
 * For every positive (non-background) class the output first carries one
 * ground-truth-count row, followed by one row per valid detection that marks
 * it as a true or a false positive (see Forward_cpu).
 *
 * NOTE: does not implement Backwards operation.
 */
template <typename Dtype>
class DetectionEvaluateLayer : public Layer<Dtype> {
 public:
  explicit DetectionEvaluateLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "DetectionEvaluate"; }
  // Caffe's Layer base class checks ExactNumBottomBlobs() when setting up the
  // layer; this is the override that actually enforces two bottoms.
  virtual inline int ExactNumBottomBlobs() const { return 2; }
  // Legacy spelling: not a Layer hook, so it enforces nothing. Kept only for
  // source compatibility with any code that calls it directly.
  virtual inline int ExactBottomBlobs() const { return 2; }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  /**
   * @brief Evaluate the detection output.
   *
   * @param bottom input Blob vector (exact 2)
   *   -# @f$ (1 \times 1 \times N \times 7) @f$
   *      N detection results, each row is
   *      [image_id, label, confidence, xmin, ymin, xmax, ymax].
   *   -# @f$ (1 \times 1 \times M \times 8) @f$
   *      M ground truth boxes, each row is
   *      [image_id, label, instance_id, xmin, ymin, xmax, ymax, difficult].
   * @param top Blob vector (length 1)
   *   -# @f$ (1 \times 1 \times (C + D) \times 5) @f$
   *      C is the number of positive (non-background) classes and D the
   *      number of detections whose label is not -1. Each row is
   *      [image_id, label, confidence, true_pos, false_pos].
   *      The first C rows are [-1, class, num_ground_truth, -1, -1];
   *      each remaining row is one detection with (true_pos, false_pos)
   *      equal to (1, 0), (0, 1), or (0, 0) when it matched a "difficult"
   *      ground truth box that is not being evaluated.
   */
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  /// @brief Not implemented
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
    NOT_IMPLEMENTED;
  }

  int num_classes_;               // number of classes, including background
  int background_label_id_;       // -1 when there is no background class
  float overlap_threshold_;       // min Jaccard overlap for a true positive
  bool evaluate_difficult_gt_;    // whether "difficult" ground truth counts
  vector<pair<int, int> > sizes_; // (height, width) per image, from
                                  // name_size_file; empty if not provided
  int count_;                     // index into sizes_ of the current image
  bool use_normalized_bbox_;      // true when sizes_ is empty
  bool has_resize_;               // whether resize_param_ was provided
  ResizeParameter resize_param_;  // used when mapping boxes to image size
};

}  // namespace caffe

#endif  // CAFFE_DETECTION_EVALUATE_LAYER_HPP_

caffe/src/caffe/layers/detection_evaluate_layer.cpp

#include <algorithm>
#include <fstream>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "caffe/layers/detection_evaluate_layer.hpp"
#include "caffe/util/bbox_util.hpp"

namespace caffe {

/**
 * @brief Read the DetectionEvaluateParameter and, optionally, the per-image
 * (height, width) table from name_size_file.
 *
 * The name_size_file has one "name height width" entry per line. When it is
 * absent (or yields no entries), boxes are evaluated in normalized
 * coordinates.
 */
template <typename Dtype>
void DetectionEvaluateLayer<Dtype>::LayerSetUp(
      const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  const DetectionEvaluateParameter& detection_evaluate_param =
      this->layer_param_.detection_evaluate_param();
  CHECK(detection_evaluate_param.has_num_classes())
      << "Must provide num_classes.";
  num_classes_ = detection_evaluate_param.num_classes();
  background_label_id_ = detection_evaluate_param.background_label_id();
  // Reshape reserves one header row per class except the background class,
  // and Forward_cpu writes one header row per class c != background_label_id_.
  // The two counts only agree when the id is -1 or a valid class index;
  // any other value would make Forward_cpu write past the output blob.
  if (background_label_id_ != -1) {
    CHECK_GE(background_label_id_, 0)
        << "background_label_id must be -1 or in [0, num_classes).";
    CHECK_LT(background_label_id_, num_classes_)
        << "background_label_id must be -1 or in [0, num_classes).";
  }
  overlap_threshold_ = detection_evaluate_param.overlap_threshold();
  // The check is strict (> 0), so the message says "positive".
  CHECK_GT(overlap_threshold_, 0.) << "overlap_threshold must be positive.";
  evaluate_difficult_gt_ = detection_evaluate_param.evaluate_difficult_gt();
  sizes_.clear();
  if (detection_evaluate_param.has_name_size_file()) {
    const string name_size_file = detection_evaluate_param.name_size_file();
    std::ifstream infile(name_size_file.c_str());
    CHECK(infile.good())
        << "Failed to open name size file: " << name_size_file;
    // The file is in the following format:
    //    name height width
    //    ...
    // Reading stops at end of file or at the first malformed entry.
    string name;
    int height, width;
    while (infile >> name >> height >> width) {
      sizes_.push_back(std::make_pair(height, width));
    }
    infile.close();
  }
  count_ = 0;
  // If there is no name_size_file provided, use normalized bbox to evaluate.
  use_normalized_bbox_ = sizes_.empty();

  // Retrieve resize parameter if there is any provided.
  has_resize_ = detection_evaluate_param.has_resize_param();
  if (has_resize_) {
    resize_param_ = detection_evaluate_param.resize_param();
  }
}

/**
 * @brief Validate the input blob shapes and size the output blob.
 *
 * Output shape is 1 x 1 x (num_pos_classes + num_valid_det) x 5:
 * one ground-truth-count row per positive class plus one row per detection
 * whose label (column 1) is not -1.
 */
template <typename Dtype>
void DetectionEvaluateLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  // count_ indexes sizes_ and is wrapped back to 0 once it reaches its size.
  CHECK_LE(count_, static_cast<int>(sizes_.size()));
  // bottom[0]: detections, 1 x 1 x N x 7.
  CHECK_EQ(bottom[0]->num(), 1);
  CHECK_EQ(bottom[0]->channels(), 1);
  CHECK_EQ(bottom[0]->width(), 7);
  // bottom[1]: ground truth, 1 x 1 x M x 8.
  CHECK_EQ(bottom[1]->num(), 1);
  CHECK_EQ(bottom[1]->channels(), 1);
  CHECK_EQ(bottom[1]->width(), 8);

  // One header row per class, minus the background class if there is one.
  const int num_pos_classes =
      (background_label_id_ == -1) ? num_classes_ : num_classes_ - 1;

  // One row per detection; Forward_cpu emits nothing for label == -1.
  const Dtype* det_data = bottom[0]->cpu_data();
  const int num_det_rows = bottom[0]->height();
  int num_valid_det = 0;
  for (int row = 0; row < num_det_rows; ++row) {
    if (det_data[row * 7 + 1] != -1) {
      ++num_valid_det;
    }
  }

  // num() and channels() are 1; each row is
  // [image_id, label, confidence, true_pos, false_pos].
  vector<int> top_shape(2, 1);
  top_shape.push_back(num_pos_classes + num_valid_det);
  top_shape.push_back(5);
  top[0]->Reshape(top_shape);
}

/// Writes one 5-value output row
/// [image_id, label, confidence, true_pos, false_pos] starting at `row`.
template <typename Dtype>
static void SetDetectionEvalRow(Dtype* row, const int image_id,
    const int label, const Dtype confidence, const int true_pos,
    const int false_pos) {
  row[0] = image_id;
  row[1] = label;
  row[2] = confidence;
  row[3] = true_pos;
  row[4] = false_pos;
}

/**
 * @brief Mark every detection as a true or false positive.
 *
 * Output layout (see Reshape):
 *  - one header row [-1, class, num_ground_truth, -1, -1] per positive class;
 *  - then one row [image_id, label, score, true_pos, false_pos] per detection.
 *
 * Within one (image, label) pair, detections are processed in descending
 * score order. Each is matched to the ground-truth box of highest Jaccard
 * overlap. A match at or above overlap_threshold_ on a not-yet-used box is a
 * true positive. A repeat match, or overlap below the threshold, is a false
 * positive. Matching a "difficult" box while evaluate_difficult_gt_ is false
 * leaves the row as (0, 0), i.e. ignored.
 */
template <typename Dtype>
void DetectionEvaluateLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  const Dtype* det_data = bottom[0]->cpu_data();  // e.g. "detection_out"
  const Dtype* gt_data = bottom[1]->cpu_data();   // e.g. "label"

  // Retrieve all detection results, keyed by image id and then by label.
  //
  // GetDetectionResults (src/caffe/util/bbox_util.cpp) reads
  // bottom[0]->height() rows of 7 values
  //   [image_id, label, confidence, xmin, ymin, xmax, ymax].
  // It skips rows whose image_id is -1 and CHECK-fails on the background
  // label. It stores score and corners in a NormalizedBBox together with its
  // area (BBoxSize), and appends it to (*all_detections)[image_id][label].
  map<int, LabelBBox> all_detections;
  GetDetectionResults(det_data, bottom[0]->height(), background_label_id_,
                      &all_detections);

  // Retrieve all ground truth (including difficult ones, hence `true`).
  //
  // GetGroundTruth (src/caffe/util/bbox_util.cpp) reads bottom[1]->height()
  // rows of 8 values
  //   [image_id, label, instance_id, xmin, ymin, xmax, ymax, difficult].
  // It skips rows whose image_id is -1, CHECK-fails on the background label
  // and records label, corners, difficult flag and area. This call uses the
  // map<int, LabelBBox> overload, so boxes are grouped by image id and then
  // by label.
  map<int, LabelBBox> all_gt_bboxes;
  GetGroundTruth(gt_data, bottom[1]->height(), background_label_id_,
                 true, &all_gt_bboxes);

  Dtype* top_data = top[0]->mutable_cpu_data();
  // Zero everything up front: "ignored" detections keep (0, 0).
  caffe_set(top[0]->count(), Dtype(0.), top_data);
  int num_rows = 0;  // output rows written so far

  // Count ground truth per label over the whole batch. Difficult boxes count
  // only when evaluate_difficult_gt_ is set.
  map<int, int> num_pos;
  for (map<int, LabelBBox>::const_iterator it = all_gt_bboxes.begin();
       it != all_gt_bboxes.end(); ++it) {
    for (LabelBBox::const_iterator iit = it->second.begin();
         iit != it->second.end(); ++iit) {
      const vector<NormalizedBBox>& gt = iit->second;
      int count = 0;
      if (evaluate_difficult_gt_) {
        count = static_cast<int>(gt.size());
      } else {
        for (size_t i = 0; i < gt.size(); ++i) {
          if (!gt[i].difficult()) {
            ++count;
          }
        }
      }
      // operator[] value-initializes a missing entry to 0.
      num_pos[iit->first] += count;
    }
  }

  // Header rows: [-1, class, num_ground_truth, -1, -1].
  for (int c = 0; c < num_classes_; ++c) {
    if (c == background_label_id_) {
      continue;
    }
    const map<int, int>::const_iterator pos = num_pos.find(c);
    const int num_gt = (pos == num_pos.end()) ? 0 : pos->second;
    SetDetectionEvalRow(top_data + num_rows * 5, -1, c, Dtype(num_gt), -1, -1);
    ++num_rows;
  }

  // Detection rows: [image_id, label, score, true_pos, false_pos].
  for (map<int, LabelBBox>::iterator it = all_detections.begin();
       it != all_detections.end(); ++it) {
    const int image_id = it->first;
    LabelBBox& detections = it->second;
    const map<int, LabelBBox>::iterator image_gt =
        all_gt_bboxes.find(image_id);

    for (LabelBBox::iterator iit = detections.begin();
         iit != detections.end(); ++iit) {
      const int label = iit->first;
      if (label == -1) {
        // Reshape reserves no output row for label -1.
        continue;
      }
      vector<NormalizedBBox>& bboxes = iit->second;

      // Ground truth of this label in this image, if any.
      vector<NormalizedBBox>* gt_bboxes = NULL;
      if (image_gt != all_gt_bboxes.end()) {
        const LabelBBox::iterator label_gt = image_gt->second.find(label);
        if (label_gt != image_gt->second.end()) {
          gt_bboxes = &label_gt->second;
        }
      }

      if (gt_bboxes == NULL) {
        // No ground truth for this image or label: all false positives.
        for (size_t i = 0; i < bboxes.size(); ++i) {
          SetDetectionEvalRow(top_data + num_rows * 5, image_id, label,
                              Dtype(bboxes[i].score()), 0, 1);
          ++num_rows;
        }
        continue;
      }

      // Scale ground truth to image coordinates if sizes are known.
      if (!use_normalized_bbox_) {
        CHECK_LT(count_, static_cast<int>(sizes_.size()));
        for (size_t j = 0; j < gt_bboxes->size(); ++j) {
          OutputBBox((*gt_bboxes)[j], sizes_[count_], has_resize_,
                     resize_param_, &((*gt_bboxes)[j]));
        }
      }

      // visited[j] is set once gt box j has produced a true positive.
      vector<bool> visited(gt_bboxes->size(), false);
      // Sort detections in descending order of score.
      std::sort(bboxes.begin(), bboxes.end(), SortBBoxDescend);
      for (size_t i = 0; i < bboxes.size(); ++i) {
        // Read the score before the box may be rewritten in place below.
        const Dtype score = bboxes[i].score();
        if (!use_normalized_bbox_) {
          OutputBBox(bboxes[i], sizes_[count_], has_resize_,
                     resize_param_, &(bboxes[i]));
        }

        // Find the ground-truth box with the largest overlap.
        float overlap_max = -1;
        int jmax = -1;
        for (size_t j = 0; j < gt_bboxes->size(); ++j) {
          const float overlap = JaccardOverlap(bboxes[i], (*gt_bboxes)[j],
                                               use_normalized_bbox_);
          if (overlap > overlap_max) {
            overlap_max = overlap;
            jmax = static_cast<int>(j);
          }
        }

        // Default: false positive (no sufficient overlap, or duplicate).
        int true_pos = 0;
        int false_pos = 1;
        if (overlap_max >= overlap_threshold_) {
          if (evaluate_difficult_gt_ || !(*gt_bboxes)[jmax].difficult()) {
            if (!visited[jmax]) {
              true_pos = 1;
              false_pos = 0;
              visited[jmax] = true;
            }
          } else {
            // Matched a difficult box that is not evaluated: ignore.
            false_pos = 0;
          }
        }
        SetDetectionEvalRow(top_data + num_rows * 5, image_id, label, score,
                            true_pos, false_pos);
        ++num_rows;
      }
    }

    // Advance to the next image's size, wrapping after a full pass.
    // NOTE(review): count_ advances once per image present in all_detections;
    // an image with no row at all in bottom[0] would desynchronize count_ from
    // sizes_ -- presumably the detection producer emits a row per image.
    if (!sizes_.empty()) {
      ++count_;
      if (count_ == static_cast<int>(sizes_.size())) {
        // reset count after a full iterations through the DB.
        count_ = 0;
      }
    }
  }
}

// Caffe registration boilerplate; both macros are defined outside this file.
// Presumably INSTANTIATE_CLASS emits the explicit template instantiations and
// REGISTER_LAYER_CLASS registers the layer under the type string
// "DetectionEvaluate" (matching type()). NOTE(review): confirm in
// caffe/common.hpp and caffe/layer_factory.hpp.
INSTANTIATE_CLASS(DetectionEvaluateLayer);
REGISTER_LAYER_CLASS(DetectionEvaluate);

}  // namespace caffe
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值