// Code location:
// caffe/include/caffe/layers/detection_evaluate_layer.hpp
#ifndef CAFFE_DETECTION_EVALUATE_LAYER_HPP_
#define CAFFE_DETECTION_EVALUATE_LAYER_HPP_
#include <utility>
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
namespace caffe {
/**
* @brief Generate the detection evaluation based on DetectionOutputLayer and
* ground truth bounding box labels.
*
* Intended for use with MultiBox detection method.
*
* NOTE: does not implement Backwards operation.
*/
template <typename Dtype>
class DetectionEvaluateLayer : public Layer<Dtype> {
 public:
  explicit DetectionEvaluateLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}

  /// @brief Reads the DetectionEvaluateParameter settings into the members
  ///        declared below.
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  /// @brief Validates the bottom shapes and sizes the single top blob.
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "DetectionEvaluate"; }
  /// The hook is spelled ExactNum*Blobs (compare ExactNumTopBlobs below); the
  /// previous spelling "ExactBottomBlobs" declared an unrelated virtual, so
  /// the two-bottom requirement was never reported through the hook.
  virtual inline int ExactNumBottomBlobs() const { return 2; }
  /// Retained so existing callers of the old, misspelled name keep compiling.
  virtual inline int ExactBottomBlobs() const { return ExactNumBottomBlobs(); }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  /**
   * @brief Evaluate the detection output.
   *
   * @param bottom input Blob vector (exact 2)
   *   -# @f$ (1 \times 1 \times N \times 7) @f$
   *      N detection results, each row being
   *      [image_id, label, confidence, xmin, ymin, xmax, ymax].
   *   -# @f$ (1 \times 1 \times M \times 8) @f$
   *      M ground truth boxes, each row being
   *      [item_id, group_label, instance_id, xmin, ymin, xmax, ymax,
   *       difficult].
   * @param top Blob vector (length 1)
   *   -# @f$ (1 \times 1 \times (C + N') \times 5) @f$
   *      C is the number of non-background classes and N' the number of
   *      detection rows whose label is not -1.
   *      The first C rows are per-class headers:
   *        [-1, label, num_ground_truth, -1, -1]
   *      The remaining rows describe one detection each:
   *        [image_id, label, confidence, true_pos, false_pos]
   */
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  /// @brief Not implemented
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
    NOT_IMPLEMENTED;
  }

  /// Number of classes; class ids 0 .. num_classes_ - 1 get a header row.
  int num_classes_;
  /// Class id that is skipped when emitting header rows (-1 skips none).
  int background_label_id_;
  /// Minimum Jaccard overlap for a detection to match a ground truth box.
  float overlap_threshold_;
  /// Whether ground truth boxes flagged "difficult" take part in evaluation.
  bool evaluate_difficult_gt_;
  /// (height, width) pairs read from the optional name_size_file.
  vector<pair<int, int> > sizes_;
  /// Index into sizes_ for the image currently being evaluated.
  int count_;
  /// True when no sizes were loaded; boxes are then compared as-is.
  bool use_normalized_bbox_;
  /// Whether a resize parameter was supplied with the layer parameters.
  bool has_resize_;
  /// Resize parameter forwarded to OutputBBox when scaling boxes.
  ResizeParameter resize_param_;
};
} // namespace caffe
#endif // CAFFE_DETECTION_EVALUATE_LAYER_HPP_
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
// caffe/src/caffe/layers/detection_evaluate_layer.cpp
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include "caffe/layers/detection_evaluate_layer.hpp"
#include "caffe/util/bbox_util.hpp"
namespace caffe {
/**
 * @brief Copies the evaluation settings out of DetectionEvaluateParameter.
 *
 * Reads num_classes (mandatory), background_label_id, overlap_threshold and
 * evaluate_difficult_gt, optionally loads per-image sizes from name_size_file,
 * and stores the resize parameter when one is present.
 *
 * @param bottom unused here
 * @param top    unused here
 */
template <typename Dtype>
void DetectionEvaluateLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  const DetectionEvaluateParameter& detection_evaluate_param =
      this->layer_param_.detection_evaluate_param();
  CHECK(detection_evaluate_param.has_num_classes())
      << "Must provide num_classes.";
  num_classes_ = detection_evaluate_param.num_classes();
  background_label_id_ = detection_evaluate_param.background_label_id();
  overlap_threshold_ = detection_evaluate_param.overlap_threshold();
  // The check is strict (zero is rejected too), so the message says
  // "positive" rather than "non negative".
  CHECK_GT(overlap_threshold_, 0.) << "overlap_threshold must be positive.";
  evaluate_difficult_gt_ = detection_evaluate_param.evaluate_difficult_gt();

  // Start from an empty table so a repeated setup cannot append duplicates.
  sizes_.clear();
  if (detection_evaluate_param.has_name_size_file()) {
    const string name_size_file = detection_evaluate_param.name_size_file();
    std::ifstream infile(name_size_file.c_str());
    CHECK(infile.good())
        << "Failed to open name size file: " << name_size_file;
    // The file is in the following format:
    //    name height width
    //    ...
    string name;
    int height, width;
    while (infile >> name >> height >> width) {
      sizes_.push_back(std::make_pair(height, width));
    }
    // infile is closed when it goes out of scope.
  }
  count_ = 0;
  // If no sizes were loaded, evaluate with the boxes exactly as given.
  use_normalized_bbox_ = sizes_.empty();
  // Retrieve resize parameter if there is any provided.
  has_resize_ = detection_evaluate_param.has_resize_param();
  if (has_resize_) {
    resize_param_ = detection_evaluate_param.resize_param();
  }
}
/**
 * @brief Checks the bottom blob shapes and resizes top[0].
 *
 * bottom[0] must be 1 x 1 x N x 7 and bottom[1] must be 1 x 1 x M x 8.
 * top[0] becomes 1 x 1 x R x 5, where R is the number of non-background
 * classes plus the number of detection rows whose label field is not -1.
 */
template <typename Dtype>
void DetectionEvaluateLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  CHECK_LE(count_, sizes_.size());
  CHECK_EQ(bottom[0]->num(), 1);
  CHECK_EQ(bottom[0]->channels(), 1);
  CHECK_EQ(bottom[0]->width(), 7);
  CHECK_EQ(bottom[1]->num(), 1);
  CHECK_EQ(bottom[1]->channels(), 1);
  CHECK_EQ(bottom[1]->width(), 8);

  // One header row per class, minus the background class when one is set.
  const int header_rows =
      (background_label_id_ == -1) ? num_classes_ : (num_classes_ - 1);

  // Count detection rows that carry a real label (second field != -1).
  const Dtype* detections = bottom[0]->cpu_data();
  const int detection_rows = bottom[0]->height();
  int labelled_rows = 0;
  for (int row = 0; row < detection_rows; ++row) {
    if (detections[row * 7 + 1] != -1) {
      ++labelled_rows;
    }
  }

  vector<int> top_shape;
  top_shape.push_back(1);  // num
  top_shape.push_back(1);  // channels
  top_shape.push_back(header_rows + labelled_rows);
  // Each row is a 5 dimension vector, which stores
  // [image_id, label, confidence, true_pos, false_pos]
  top_shape.push_back(5);
  top[0]->Reshape(top_shape);
}
/**
 * @brief Marks every detection as true positive, false positive or ignored.
 *
 * top[0] is zero-filled and then written row by row (5 values per row):
 *   - one header row per non-background class:
 *       [-1, label, number_of_ground_truth_boxes, -1, -1]
 *   - one row per detection whose label is not -1:
 *       [image_id, label, confidence, true_pos, false_pos]
 *     where (1, 0) is a true positive, (0, 1) a false positive and (0, 0) a
 *     detection matched to a "difficult" ground truth that is not evaluated.
 *
 * @param bottom bottom[0] holds the detections ("detection_out"),
 *               bottom[1] the ground truth ("label").
 * @param top    top[0] receives the rows described above.
 */
template <typename Dtype>
void DetectionEvaluateLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  const Dtype* det_data = bottom[0]->cpu_data();  // bottom: "detection_out"
  const Dtype* gt_data = bottom[1]->cpu_data();   // bottom: "label"

  // Retrieve all detection results.
  // GetDetectionResults (src/caffe/util/bbox_util.cpp) walks num_det rows of
  // 7 values each:
  //   [image_id, label, confidence, xmin, ymin, xmax, ymax]
  // Rows whose image_id is -1 are skipped, a row carrying the background
  // label aborts via CHECK_NE, and every other row becomes a NormalizedBBox
  // (score, corners, size = BBoxSize(bbox)) stored under
  // all_detections[image_id][label].
  map<int, LabelBBox> all_detections;
  GetDetectionResults(det_data, bottom[0]->height(), background_label_id_,
                      &all_detections);

  // Retrieve all ground truth (including difficult ones).
  // GetGroundTruth walks num_gt rows of 8 values each:
  //   [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, difficult]
  // Rows whose item_id is -1 are skipped, the background label aborts via
  // CHECK_NE, and "difficult" rows are dropped only when use_difficult_gt is
  // false (true is passed here, so they are kept). Each kept row becomes a
  // NormalizedBBox carrying label, corners, difficult flag and size.
  // NOTE(review): the overload quoted in the original annotation fills a
  // map<int, vector<NormalizedBBox> >; this call passes map<int, LabelBBox>,
  // so presumably a per-label overload is the one selected -- confirm in
  // bbox_util.
  map<int, LabelBBox> all_gt_bboxes;
  GetGroundTruth(gt_data, bottom[1]->height(), background_label_id_,
                 true, &all_gt_bboxes);

  Dtype* top_data = top[0]->mutable_cpu_data();
  caffe_set(top[0]->count(), Dtype(0.), top_data);
  int num_det = 0;  // index of the next top row to write

  // Insert number of ground truth for each label.
  map<int, int> num_pos;
  for (map<int, LabelBBox>::iterator it = all_gt_bboxes.begin();
       it != all_gt_bboxes.end(); ++it) {
    for (LabelBBox::iterator iit = it->second.begin(); iit != it->second.end();
         ++iit) {
      const vector<NormalizedBBox>& label_gt = iit->second;
      int count = 0;
      if (evaluate_difficult_gt_) {
        count = label_gt.size();
      } else {
        // Get number of non difficult ground truth.
        for (size_t i = 0; i < label_gt.size(); ++i) {
          if (!label_gt[i].difficult()) {
            ++count;
          }
        }
      }
      // operator[] value-initialises a missing entry to 0 before adding.
      num_pos[iit->first] += count;
    }
  }
  for (int c = 0; c < num_classes_; ++c) {
    if (c == background_label_id_) {
      continue;
    }
    Dtype* row = top_data + num_det * 5;
    const map<int, int>::iterator pos_it = num_pos.find(c);
    row[0] = -1;
    row[1] = c;
    row[2] = (pos_it == num_pos.end()) ? 0 : pos_it->second;
    row[3] = -1;
    row[4] = -1;
    ++num_det;
  }

  // Insert detection evaluate status.
  for (map<int, LabelBBox>::iterator it = all_detections.begin();
       it != all_detections.end(); ++it) {
    const int image_id = it->first;
    // Detections of the current image, grouped by label.
    LabelBBox& detections = it->second;
    const map<int, LabelBBox>::iterator image_gt_it =
        all_gt_bboxes.find(image_id);
    if (image_gt_it == all_gt_bboxes.end()) {
      // No ground truth for current image. All detections become false_pos.
      for (LabelBBox::iterator iit = detections.begin();
           iit != detections.end(); ++iit) {
        const int label = iit->first;
        if (label == -1) {
          continue;
        }
        const vector<NormalizedBBox>& bboxes = iit->second;
        for (size_t i = 0; i < bboxes.size(); ++i) {
          Dtype* row = top_data + num_det * 5;
          row[0] = image_id;
          row[1] = label;
          row[2] = bboxes[i].score();
          row[3] = 0;  // true_pos
          row[4] = 1;  // false_pos
          ++num_det;
        }
      }
    } else {
      LabelBBox& label_bboxes = image_gt_it->second;
      for (LabelBBox::iterator iit = detections.begin();
           iit != detections.end(); ++iit) {
        const int label = iit->first;
        if (label == -1) {
          continue;
        }
        vector<NormalizedBBox>& bboxes = iit->second;
        const LabelBBox::iterator label_gt_it = label_bboxes.find(label);
        if (label_gt_it == label_bboxes.end()) {
          // No ground truth for current label. All detections become false_pos.
          for (size_t i = 0; i < bboxes.size(); ++i) {
            Dtype* row = top_data + num_det * 5;
            row[0] = image_id;
            row[1] = label;
            row[2] = bboxes[i].score();
            row[3] = 0;  // (0, 1): false positive
            row[4] = 1;  // (1, 0) would be a true positive
            ++num_det;
          }
        } else {
          vector<NormalizedBBox>& gt_bboxes = label_gt_it->second;
          // Scale ground truth if needed.
          if (!use_normalized_bbox_) {
            CHECK_LT(count_, sizes_.size());
            for (size_t i = 0; i < gt_bboxes.size(); ++i) {
              OutputBBox(gt_bboxes[i], sizes_[count_], has_resize_,
                         resize_param_, &(gt_bboxes[i]));
            }
          }
          // visited[j] is set once ground truth j has been claimed.
          vector<bool> visited(gt_bboxes.size(), false);
          // Sort detections in descend order based on scores.
          std::sort(bboxes.begin(), bboxes.end(), SortBBoxDescend);
          for (size_t i = 0; i < bboxes.size(); ++i) {
            Dtype* row = top_data + num_det * 5;
            row[0] = image_id;
            row[1] = label;
            row[2] = bboxes[i].score();
            if (!use_normalized_bbox_) {
              OutputBBox(bboxes[i], sizes_[count_], has_resize_,
                         resize_param_, &(bboxes[i]));
            }
            // Compare with each ground truth bbox of this image and label,
            // remembering the best overlap.
            float overlap_max = -1;
            int jmax = -1;
            for (size_t j = 0; j < gt_bboxes.size(); ++j) {
              // use_normalized_bbox_ is true when no name_size_file sizes
              // were loaded.
              const float overlap = JaccardOverlap(bboxes[i], gt_bboxes[j],
                                                   use_normalized_bbox_);
              if (overlap > overlap_max) {
                overlap_max = overlap;
                jmax = static_cast<int>(j);
              }
            }
            if (overlap_max >= overlap_threshold_) {
              // A detection whose best match is a "difficult" ground truth
              // that is not evaluated keeps the zero-filled (0, 0) status.
              if (evaluate_difficult_gt_ || !gt_bboxes[jmax].difficult()) {
                if (!visited[jmax]) {
                  // true positive.
                  row[3] = 1;
                  row[4] = 0;
                  visited[jmax] = true;
                } else {
                  // false positive (multiple detection): this ground truth
                  // was already claimed by a higher-scoring detection.
                  row[3] = 0;
                  row[4] = 1;
                }
              }
            } else {
              // false positive: no ground truth overlaps enough.
              row[3] = 0;
              row[4] = 1;
            }
            ++num_det;
          }
        }
      }
    }
    // Advance to the next image's size entry (once per evaluated image).
    if (sizes_.size() > 0) {
      ++count_;
      if (count_ == sizes_.size()) {
        // reset count after a full iterations through the DB.
        count_ = 0;
      }
    }
  }
}
// Caffe helper macros, defined outside this file.
// NOTE(review): presumably INSTANTIATE_CLASS emits the explicit template
// instantiations of the layer and REGISTER_LAYER_CLASS makes it creatable
// under the type name "DetectionEvaluate" -- confirm against the Caffe headers.
INSTANTIATE_CLASS(DetectionEvaluateLayer);
REGISTER_LAYER_CLASS(DetectionEvaluate);
} // namespace caffe
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
- 168
- 169
- 170
- 171
- 172
- 173
- 174
- 175
- 176
- 177
- 178
- 179
- 180
- 181
- 182
- 183
- 184
- 185
- 186
- 187
- 188
- 189
- 190
- 191
- 192
- 193
- 194
- 195
- 196
- 197
- 198
- 199
- 200
- 201
- 202
- 203
- 204
- 205
- 206
- 207
- 208
- 209
- 210
- 211
- 212
- 213
- 214
- 215
- 216
- 217
- 218
- 219
- 220
- 221
- 222
- 223
- 224
- 225
- 226
- 227
- 228
- 229
- 230
- 231
- 232
- 233
- 234
- 235
- 236
- 237
- 238
- 239
- 240
- 241
- 242
- 243
- 244
- 245
- 246
- 247
- 248
- 249
- 250
- 251
- 252
- 253
- 254
- 255
- 256
- 257
- 258
- 259
- 260
- 261
- 262
- 263
- 264
- 265
- 266
- 267
- 268
- 269
- 270
- 271
- 272
- 273
- 274
- 275
- 276
- 277
- 278
- 279
- 280
- 281
- 282
- 283
- 284
- 285
- 286
- 287
- 288
- 289
- 290
- 291
- 292
- 293
- 294
- 295
- 296
- 297
- 298
- 299
- 300
- 301
- 302
- 303
- 304
- 305
- 306
- 307
- 308
- 309
- 310
- 311
- 312
<link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/markdown_views-ea0013b516.css">
</div>