原文链接:https://blog.csdn.net/wudi_X/article/details/80417120
我们可以从SSD的caffe源码中得到test的mAP。mAP是各个类别AP的均值,而每个类别的AP是由不同score阈值下的precision-recall曲线求平均得到的;但如果我们想看某一个阈值下的recall和precision,就需要对solver.cpp源码做一定修改。
此处阈值指confidence_threshold,并非原文中所指overlap_threshold。
关于mAP、recall和precision的解释这里不赘述。
修改caffe.proto
首先在src/caffe/proto/caffe.proto中的SolverParameter这个message下加上一个参数rec_prec_thr。该参数是统计recall和precision时使用的检测得分(score)阈值:只有score高于该阈值的检测结果才会被计入true positive (tp) 和false positive (fp)。我们给它一个默认值0.6,代码如下(注意字段编号要在自己的SolverParameter现有的最大字段编号上加1):
// Score threshold used when reporting per-class recall and precision during
// detection testing: only detections scoring above it are counted.
// NOTE(review): field number 45 assumes the highest existing SolverParameter
// field number is 44 — confirm against your own caffe.proto.
optional float rec_prec_thr = 45 [default = 0.6];
修改solver.cpp
接下来在src/caffe/solver.cpp的void Solver<Dtype>::TestDetection(const int test_net_id)函数中加入计算recall和precision的代码:
// Runs the detection test net `test_net_id` for its configured number of test
// iterations, accumulates per-label true/false positives from every output
// blob, and logs the mAP of each output blob. In addition, logs per-class
// recall and precision computed only from detections whose score is strictly
// greater than the solver parameter rec_prec_thr.
//
// Returns early (after logging "Test interrupted.") if a STOP action is
// requested while testing.
template <typename Dtype>
void Solver<Dtype>::TestDetection(const int test_net_id) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Iteration " << iter_
            << ", Testing net (#" << test_net_id << ")";
  CHECK_NOTNULL(test_nets_[test_net_id].get())->
      ShareTrainedLayersWith(net_.get());
  // All three maps are keyed by output blob index, then by label.
  // The vectors hold (score, flag) pairs, one per evaluated detection.
  map<int, map<int, vector<pair<float, int> > > > all_true_pos;
  map<int, map<int, vector<pair<float, int> > > > all_false_pos;
  map<int, map<int, int> > all_num_pos;
  const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
  Dtype loss = 0;
  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
    SolverAction::Enum request = GetRequestedAction();
    // Check to see if stoppage of testing/training has been requested.
    while (request != SolverAction::NONE) {
      if (SolverAction::SNAPSHOT == request) {
        Snapshot();
      } else if (SolverAction::STOP == request) {
        requested_early_exit_ = true;
      }
      request = GetRequestedAction();
    }
    if (requested_early_exit_) {
      // break out of test loop.
      break;
    }

    Dtype iter_loss;
    const vector<Blob<Dtype>*>& result = test_net->Forward(&iter_loss);
    if (param_.test_compute_loss()) {
      loss += iter_loss;
    }
    for (int j = 0; j < static_cast<int>(result.size()); ++j) {
      // Each row has 5 values: item_id, label, then either
      // (num_pos, -, -) for item_id == -1 or (score, tp, fp) otherwise.
      CHECK_EQ(result[j]->width(), 5);
      const Dtype* result_vec = result[j]->cpu_data();
      const int num_det = result[j]->height();
      for (int k = 0; k < num_det; ++k) {
        const int item_id = static_cast<int>(result_vec[k * 5]);
        const int label = static_cast<int>(result_vec[k * 5 + 1]);
        if (item_id == -1) {
          // Special row of storing number of positives for a label.
          // operator[] value-initializes a missing entry to 0, so a plain +=
          // covers both the first and the subsequent occurrences.
          all_num_pos[j][label] += static_cast<int>(result_vec[k * 5 + 2]);
        } else {
          // Normal row storing detection status.
          const float score = result_vec[k * 5 + 2];
          const int tp = static_cast<int>(result_vec[k * 5 + 3]);
          const int fp = static_cast<int>(result_vec[k * 5 + 4]);
          if (tp == 0 && fp == 0) {
            // Ignore such case. It happens when a detection bbox is matched to
            // a difficult gt bbox and we don't evaluate on difficult gt bbox.
            continue;
          }
          all_true_pos[j][label].push_back(std::make_pair(score, tp));
          all_false_pos[j][label].push_back(std::make_pair(score, fp));
        }
      }
    }
  }
  if (requested_early_exit_) {
    LOG(INFO) << "Test interrupted.";
    return;
  }
  if (param_.test_compute_loss()) {
    loss /= param_.test_iter(test_net_id);
    LOG(INFO) << "Test loss: " << loss;
  }

  // Score threshold for the recall/precision report; configurable in the
  // solver definition. It is the same for every label, so read it once.
  const float thr = param_.rec_prec_thr();

  for (int i = 0; i < static_cast<int>(all_true_pos.size()); ++i) {
    if (all_true_pos.find(i) == all_true_pos.end()) {
      LOG(FATAL) << "Missing output_blob true_pos: " << i;
    }
    const map<int, vector<pair<float, int> > >& true_pos =
        all_true_pos.find(i)->second;
    if (all_false_pos.find(i) == all_false_pos.end()) {
      LOG(FATAL) << "Missing output_blob false_pos: " << i;
    }
    const map<int, vector<pair<float, int> > >& false_pos =
        all_false_pos.find(i)->second;
    if (all_num_pos.find(i) == all_num_pos.end()) {
      LOG(FATAL) << "Missing output_blob num_pos: " << i;
    }
    const map<int, int>& num_pos = all_num_pos.find(i)->second;

    map<int, float> APs;
    map<int, float> recalls;     // Recall of each class at `thr`.
    map<int, float> precisions;  // Precision of each class at `thr`.
    float mAP = 0.;
    // Compute AP, recall and precision for every label that has positives
    // recorded.
    for (map<int, int>::const_iterator it = num_pos.begin();
         it != num_pos.end(); ++it) {
      const int label = it->first;
      const int label_num_pos = it->second;
      if (true_pos.find(label) == true_pos.end()) {
        LOG(WARNING) << "Missing true_pos for label: " << label;
        continue;
      }
      const vector<pair<float, int> >& label_true_pos =
          true_pos.find(label)->second;
      if (false_pos.find(label) == false_pos.end()) {
        LOG(WARNING) << "Missing false_pos for label: " << label;
        continue;
      }
      const vector<pair<float, int> >& label_false_pos =
          false_pos.find(label)->second;
      vector<float> prec, rec;
      ComputeAP(label_true_pos, label_num_pos, label_false_pos,
                param_.ap_version(), &prec, &rec, &(APs[label]));
      mAP += APs[label];

      // Count true/false positives among detections scoring above `thr`.
      // The loop indices are deliberately distinct from the enclosing `i`
      // (output blob index), which is still needed below.
      int tp_sum = 0;  // Number of true positives above the threshold.
      int fp_sum = 0;  // Number of false positives above the threshold.
      for (size_t d = 0; d < label_true_pos.size(); ++d) {
        if (label_true_pos[d].first > thr) {
          tp_sum += label_true_pos[d].second;
        }
      }
      for (size_t d = 0; d < label_false_pos.size(); ++d) {
        if (label_false_pos[d].first > thr) {
          fp_sum += label_false_pos[d].second;
        }
      }
      // Guard both ratios against a zero denominator: without this, a class
      // with no detection above the threshold (or no positives) yields 0/0
      // and is reported as NaN. Report 0 in those cases instead.
      recalls[label] = label_num_pos > 0
          ? static_cast<float>(tp_sum) / label_num_pos : 0.f;
      const int num_kept = tp_sum + fp_sum;
      precisions[label] = num_kept > 0
          ? static_cast<float>(tp_sum) / num_kept : 0.f;

      if (param_.show_per_class_result()) {
        LOG(INFO) << "class" << label << ": " << APs[label];
      }
    }
    mAP /= num_pos.size();
    const int output_blob_index = test_net->output_blob_indices()[i];
    const string& output_name = test_net->blob_names()[output_blob_index];
    LOG(INFO) << " Test net output #" << i << ": " << output_name << " = "
              << mAP;

    // Log the recall of every class at the configured threshold.
    LOG(INFO) << "-------------recalls-----------";
    for (map<int, float>::const_iterator it = recalls.begin();
         it != recalls.end(); ++it) {
      LOG(INFO) << "class" << it->first << ": " << it->second;
    }
    LOG(INFO) << "-------------recalls-----------";

    // Log the precision of every class at the configured threshold.
    LOG(INFO) << "-------------precisions-----------";
    for (map<int, float>::const_iterator it = precisions.begin();
         it != precisions.end(); ++it) {
      LOG(INFO) << "class" << it->first << ": " << it->second;
    }
    LOG(INFO) << "-------------precisions-----------";
  }
}
重新编译caffe即可
make clean
make all
在使用时,通过在solver.prototxt中设定rec_prec_thr的值(例如rec_prec_thr: 0.5)来调整阈值。