/*
 * validate_detector_map - evaluate a detector on a validation image list and
 * return its mean average precision (mAP).
 *
 * datacfg             : data file; keys read: "valid", "difficult" (optional),
 *                       "names", and "train" (used as the fallback for "valid"
 *                       when existing_net is given).
 * cfgfile, weightfile : network cfg / weights; used only when existing_net is
 *                       NULL. That network is built with batch=1 and freed here.
 * thresh_calc_avg_iou : confidence threshold for the summary TP / FP / FN,
 *                       precision, recall, F1-score and average-IoU figures.
 * iou_thresh          : a detection matches a ground-truth box of the same
 *                       class only when their IoU is strictly greater than this.
 * map_points          : 0 -> area under the PR curve at every unique recall;
 *                       N -> mean of the best precision at N evenly spaced recall
 *                       points (101 = MS COCO style, 11 = PascalVOC 2007 style).
 * letter_box          : non-zero to letterbox input images.
 * existing_net        : optional network owned by the caller (e.g. during
 *                       training). Its recurrent state is remembered before and
 *                       restored after the evaluation; it is not freed here.
 *
 * Returns the mean over classes of the per-class average precision.
 * Per-class AP and summary statistics are printed as a side effect.
 */
float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, int letter_box, network *existing_net)
{
    // Read the dataset description.
    list *options = read_data_cfg(datacfg);
    char *valid_images = option_find_str(options, "valid", "data/train.txt");
    char *difficult_valid_images = option_find_str(options, "difficult", NULL);
    char *name_list = option_find_str(options, "names", "data/names.list");
    int names_size = 0;
    char **names = get_labels_custom(name_list, &names_size); //get_labels(name_list);
    network net;

    // Either reuse the caller's network or build a fresh one with batch=1.
    if (existing_net) {
        char *train_images = option_find_str(options, "train", "data/train.txt");
        valid_images = option_find_str(options, "valid", train_images);
        net = *existing_net;
        remember_network_recurrent_state(*existing_net);
        free_network_recurrent_state(*existing_net);
    }
    else {
        net = parse_network_cfg_custom(cfgfile, 1, 1); // set batch=1
        if (weightfile) {
            load_weights(&net, weightfile);
        }
        fuse_conv_batchnorm(net);
        calculate_binary_weights(net);
    }
    if (net.layers[net.n - 1].classes != names_size) {
        printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n",
            name_list, names_size, net.layers[net.n - 1].classes, cfgfile);
        getchar();
    }
    srand(time(0));
    printf("\n calculation mAP (mean average precision)...\n");

    list *plist = get_paths(valid_images);
    char **paths = (char **)list_to_array(plist);
    list *plist_dif = NULL;
    char **paths_dif = NULL;
    int dif_count = 0;
    if (difficult_valid_images) {
        plist_dif = get_paths(difficult_valid_images);
        paths_dif = (char **)list_to_array(plist_dif);
        dif_count = plist_dif->size;
    }

    layer l = net.layers[net.n - 1];
    int classes = l.classes;
    int m = plist->size;
    int i = 0;
    int k;
    int t;
    const float thresh = .005;
    const float nms = .45;

    // Up to 4 loader threads prefetch the next images while the current ones are evaluated.
    int nthreads = 4;
    if (m < 4) nthreads = m;
    image* val = (image*)calloc(nthreads, sizeof(image));
    image* val_resized = (image*)calloc(nthreads, sizeof(image));
    image* buf = (image*)calloc(nthreads, sizeof(image));
    image* buf_resized = (image*)calloc(nthreads, sizeof(image));
    pthread_t* thr = (pthread_t*)calloc(nthreads, sizeof(pthread_t));
    load_args args = { 0 };
    args.w = net.w;
    args.h = net.h;
    args.c = net.c;
    if (letter_box) args.type = LETTERBOX_DATA;
    else args.type = IMAGE_DATA;

    float avg_iou = 0;
    int tp_for_thresh = 0;
    int fp_for_thresh = 0;

    // Every (box, class) pair with prob > 0. Capacity doubles to avoid one realloc per entry.
    int detections_capacity = 1024;
    box_prob* detections = (box_prob*)calloc(detections_capacity, sizeof(box_prob));
    int detections_count = 0;
    int unique_truth_count = 0;
    int* truth_classes_count = (int*)calloc(classes, sizeof(int));

    // For multi-class precision and recall computation
    float *avg_iou_per_class = (float*)calloc(classes, sizeof(float));
    int *tp_for_thresh_per_class = (int*)calloc(classes, sizeof(int));
    int *fp_for_thresh_per_class = (int*)calloc(classes, sizeof(int));

    for (t = 0; t < nthreads; ++t) {
        args.path = paths[i + t];
        args.im = &buf[t];
        args.resized = &buf_resized[t];
        thr[t] = load_data_in_thread(args);
    }
    time_t start = time(0);
    for (i = nthreads; i < m + nthreads; i += nthreads) {
        fprintf(stderr, "\r%d", i);
        // Collect the images prefetched in the previous round.
        for (t = 0; t < nthreads && i + t - nthreads < m; ++t) {
            pthread_join(thr[t], 0);
            val[t] = buf[t];
            val_resized[t] = buf_resized[t];
        }
        // Start prefetching the next round.
        for (t = 0; t < nthreads && i + t < m; ++t) {
            args.path = paths[i + t];
            args.im = &buf[t];
            args.resized = &buf_resized[t];
            thr[t] = load_data_in_thread(args);
        }
        for (t = 0; t < nthreads && i + t - nthreads < m; ++t) {
            const int image_index = i + t - nthreads;
            char *path = paths[image_index];
            float *X = val_resized[t].data;
            network_predict(net, X);
            int nboxes = 0;
            float hier_thresh = 0;
            detection *dets;
            // Candidate boxes scoring above `thresh` (0.005). Per the original author's notes,
            // per-class probs below thresh are left at 0 (prob = objectness * class score);
            // TODO confirm against get_network_boxes.
            if (args.type == LETTERBOX_DATA) {
                dets = get_network_boxes(&net, val[t].w, val[t].h, thresh, hier_thresh, 0, 1, &nboxes, letter_box);
            }
            else {
                dets = get_network_boxes(&net, 1, 1, thresh, hier_thresh, 0, 0, &nboxes, letter_box);
            }
            // Per the original notes, NMS zeroes the lower per-class prob of overlapping boxes
            // instead of removing boxes. nboxes is unchanged, hence the prob > 0 test below.
            if (nms) do_nms_sort(dets, nboxes, l.classes, nms);

            // Ground-truth labels for this image.
            char labelpath[4096];
            replace_image_to_label(path, labelpath);
            int num_labels = 0;
            box_label *truth = read_boxes(labelpath, &num_labels);
            int box_idx, lbl;
            for (lbl = 0; lbl < num_labels; ++lbl) {
                const int label_id = truth[lbl].id;
                if (label_id >= 0 && label_id < classes) {
                    truth_classes_count[label_id]++;
                }
                else {
                    // Guard against out-of-bounds writes from malformed label files.
                    fprintf(stderr, "\n Warning: class id %d in %s is outside [0, %d) and is ignored \n",
                        label_id, labelpath, classes);
                }
            }

            // Optional "difficult" objects: detections matching them are ignored entirely.
            box_label *truth_dif = NULL;
            int num_labels_dif = 0;
            if (paths_dif && image_index < dif_count)
            {
                char labelpath_dif[4096];
                replace_image_to_label(paths_dif[image_index], labelpath_dif);
                truth_dif = read_boxes(labelpath_dif, &num_labels_dif);
            }

            const int checkpoint_detections_count = detections_count;
            for (box_idx = 0; box_idx < nboxes; ++box_idx) {
                int class_id;
                // A single box may carry prob > 0 for several classes; each is a separate detection.
                for (class_id = 0; class_id < classes; ++class_id) {
                    const float prob = dets[box_idx].prob[class_id];
                    if (!(prob > 0)) continue;

                    if (detections_count == detections_capacity) {
                        const int new_capacity = detections_capacity * 2;
                        box_prob *grown = (box_prob*)realloc(detections, (size_t)new_capacity * sizeof(box_prob));
                        if (!grown) {
                            fprintf(stderr, " Error: out of memory while collecting %d detections \n", detections_count);
                            exit(EXIT_FAILURE);
                        }
                        detections = grown;
                        detections_capacity = new_capacity;
                    }
                    box_prob *det = &detections[detections_count++];
                    det->b = dets[box_idx].bbox;
                    det->p = prob;
                    det->image_index = image_index;
                    det->class_id = class_id;
                    det->truth_flag = 0;
                    det->unique_truth_index = -1;

                    // Best-IoU ground-truth box of the same class with IoU > iou_thresh.
                    // Its index is global across images (offset by unique_truth_count).
                    int truth_index = -1;
                    float max_iou = 0;
                    for (lbl = 0; lbl < num_labels; ++lbl)
                    {
                        box truth_box = { truth[lbl].x, truth[lbl].y, truth[lbl].w, truth[lbl].h };
                        float current_iou = box_iou(dets[box_idx].bbox, truth_box);
                        if (current_iou > iou_thresh && class_id == truth[lbl].id && current_iou > max_iou) {
                            max_iou = current_iou;
                            truth_index = unique_truth_count + lbl;
                        }
                    }

                    if (truth_index > -1) {
                        det->truth_flag = 1;
                        det->unique_truth_index = truth_index;
                    }
                    else {
                        // If the object is difficult, drop the detection and skip the stats below,
                        // so it counts as neither TP nor FP.
                        int is_difficult = 0;
                        for (lbl = 0; lbl < num_labels_dif; ++lbl) {
                            box dif_box = { truth_dif[lbl].x, truth_dif[lbl].y, truth_dif[lbl].w, truth_dif[lbl].h };
                            float current_iou = box_iou(dets[box_idx].bbox, dif_box);
                            if (current_iou > iou_thresh && class_id == truth_dif[lbl].id) {
                                is_difficult = 1;
                                break;
                            }
                        }
                        if (is_difficult) {
                            --detections_count;
                            continue;
                        }
                    }

                    // Threshold statistics (avg IoU, TP, FP) only for prob > thresh_calc_avg_iou.
                    // A truth box already claimed by an earlier detection of this image makes this one an FP.
                    if (prob > thresh_calc_avg_iou) {
                        int z, found = 0;
                        for (z = checkpoint_detections_count; z < detections_count - 1; ++z) {
                            if (detections[z].unique_truth_index == truth_index) {
                                found = 1; break;
                            }
                        }
                        if (truth_index > -1 && found == 0) {
                            avg_iou += max_iou;
                            ++tp_for_thresh;
                            avg_iou_per_class[class_id] += max_iou;
                            tp_for_thresh_per_class[class_id]++;
                        }
                        else {
                            fp_for_thresh++;
                            fp_for_thresh_per_class[class_id]++;
                        }
                    }
                }
            }
            unique_truth_count += num_labels;
            free_detections(dets, nboxes);
            free(truth);
            free(truth_dif);
            free_image(val[t]);
            free_image(val_resized[t]);
        }
    }
    // Every loader thread has been joined and its images released above.
    free(val);
    free(val_resized);
    free(buf);
    free(buf_resized);
    free(thr);

    if ((tp_for_thresh + fp_for_thresh) > 0)
        avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh);
    int c;
    for (c = 0; c < classes; c++) {
        if ((tp_for_thresh_per_class[c] + fp_for_thresh_per_class[c]) > 0)
            avg_iou_per_class[c] = avg_iou_per_class[c] / (tp_for_thresh_per_class[c] + fp_for_thresh_per_class[c]);
    }

    // Rank all detections globally, in the order defined by detections_comparator.
    qsort(detections, detections_count, sizeof(box_prob), detections_comparator);

    typedef struct {
        double precision;
        double recall;
        int tp, fp, fn;
    } pr_t;

    // Per-class precision/recall after each global rank (for the PR curve).
    pr_t** pr = (pr_t**)calloc(classes, sizeof(pr_t*));
    for (c = 0; c < classes; ++c) {
        pr[c] = (pr_t*)calloc(detections_count, sizeof(pr_t));
    }
    printf("\n detections_count = %d, unique_truth_count = %d \n", detections_count, unique_truth_count);

    int* detection_per_class_count = (int*)calloc(classes, sizeof(int));
    for (k = 0; k < detections_count; ++k) {
        detection_per_class_count[detections[k].class_id]++;
    }

    // truth_flags[u] becomes 1 once ground-truth box u is matched; later matches count as FP.
    int* truth_flags = (int*)calloc(unique_truth_count, sizeof(int));
    int rank;
    for (rank = 0; rank < detections_count; ++rank) {
        if (rank % 100 == 0)
            printf(" rank = %d of ranks = %d \r", rank, detections_count);

        if (rank > 0) {
            for (c = 0; c < classes; ++c) {
                pr[c][rank].tp = pr[c][rank - 1].tp;
                pr[c][rank].fp = pr[c][rank - 1].fp;
            }
        }

        box_prob d = detections[rank];
        // if (detected && isn't detected before)
        if (d.truth_flag == 1 && truth_flags[d.unique_truth_index] == 0) {
            truth_flags[d.unique_truth_index] = 1;
            pr[d.class_id][rank].tp++; // true-positive
        }
        else {
            pr[d.class_id][rank].fp++; // false-positive (unmatched or duplicate)
        }

        for (c = 0; c < classes; ++c)
        {
            const int tp = pr[c][rank].tp;
            const int fp = pr[c][rank].fp;
            const int fn = truth_classes_count[c] - tp; // false-negative = objects - true-positive
            pr[c][rank].fn = fn;

            if ((tp + fp) > 0) pr[c][rank].precision = (double)tp / (double)(tp + fp);
            else pr[c][rank].precision = 0;

            if ((tp + fn) > 0) pr[c][rank].recall = (double)tp / (double)(tp + fn);
            else pr[c][rank].recall = 0;

            if (rank == (detections_count - 1) && detection_per_class_count[c] != (tp + fp)) { // check for last rank
                printf(" class_id: %d - detections = %d, tp+fp = %d, tp = %d, fp = %d \n", c, detection_per_class_count[c], tp+fp, tp, fp);
            }
        }
    }
    free(truth_flags);

    double mean_average_precision = 0;
    for (c = 0; c < classes; ++c) {
        double avg_precision = 0;

        // MS COCO - uses 101-Recall-points on PR-chart.
        // PascalVOC2007 - uses 11-Recall-points on PR-chart.
        // PascalVOC2010-2012 - uses Area-Under-Curve on PR-chart.
        // ImageNet - uses Area-Under-Curve on PR-chart.

        // correct mAP calculation: ImageNet, PascalVOC 2010-2012
        if (map_points == 0)
        {
            // With no detections at all the AP is 0 (and pr[c] is empty).
            if (detections_count > 0) {
                double last_recall = pr[c][detections_count - 1].recall;
                double last_precision = pr[c][detections_count - 1].precision;
                for (rank = detections_count - 2; rank >= 0; --rank)
                {
                    double delta_recall = last_recall - pr[c][rank].recall;
                    last_recall = pr[c][rank].recall;

                    // Interpolated precision: best precision at this or any later rank.
                    if (pr[c][rank].precision > last_precision) {
                        last_precision = pr[c][rank].precision;
                    }
                    avg_precision += delta_recall * last_precision;
                }
                // Remaining area between recall 0 and the recall at rank 0. Without it, a class
                // whose only TP is the top-ranked detection would get AP 0 instead of 1.
                avg_precision += last_recall * last_precision;
            }
        }
        // MSCOCO - 101 Recall-points, PascalVOC - 11 Recall-points
        else
        {
            int point;
            for (point = 0; point < map_points; ++point) {
                double cur_recall = point * 1.0 / (map_points-1);
                double cur_precision = 0;
                for (rank = 0; rank < detections_count; ++rank)
                {
                    if (pr[c][rank].recall >= cur_recall) { // > or >=
                        if (pr[c][rank].precision > cur_precision) {
                            cur_precision = pr[c][rank].precision;
                        }
                    }
                }
                avg_precision += cur_precision;
            }
            avg_precision = avg_precision / map_points;
        }

        // names may have fewer entries than classes when the cfg and names file disagree.
        printf("class_id = %d, name = %s, ap = %2.2f%% \t (TP = %d, FP = %d) \n",
            c, (c < names_size) ? names[c] : "(unnamed)", avg_precision * 100, tp_for_thresh_per_class[c], fp_for_thresh_per_class[c]);

        mean_average_precision += avg_precision;
    }

    // Summary at thresh_calc_avg_iou; report 0 instead of NaN when a denominator is empty.
    const float cur_precision = (tp_for_thresh + fp_for_thresh) > 0
        ? (float)tp_for_thresh / ((float)tp_for_thresh + (float)fp_for_thresh) : 0.F;
    const float cur_recall = unique_truth_count > 0
        ? (float)tp_for_thresh / ((float)tp_for_thresh + (float)(unique_truth_count - tp_for_thresh)) : 0.F;
    const float f1_score = (cur_precision + cur_recall) > 0
        ? 2.F * cur_precision * cur_recall / (cur_precision + cur_recall) : 0.F;
    printf("\n for conf_thresh = %1.2f, precision = %1.2f, recall = %1.2f, F1-score = %1.2f \n",
        thresh_calc_avg_iou, cur_precision, cur_recall, f1_score);

    printf(" for conf_thresh = %0.2f, TP = %d, FP = %d, FN = %d, average IoU = %2.2f %% \n",
        thresh_calc_avg_iou, tp_for_thresh, fp_for_thresh, unique_truth_count - tp_for_thresh, avg_iou * 100);

    if (classes > 0) mean_average_precision = mean_average_precision / classes;
    printf("\n IoU threshold = %2.0f %%, ", iou_thresh * 100);
    if (map_points) printf("used %d Recall-points \n", map_points);
    else printf("used Area-Under-Curve for each unique Recall \n");

    printf(" mean average precision (mAP@%0.2f) = %f, or %2.2f %% \n", iou_thresh, mean_average_precision, mean_average_precision * 100);

    for (c = 0; c < classes; ++c) {
        free(pr[c]);
    }
    free(pr);
    free(detections);
    free(truth_classes_count);
    free(detection_per_class_count);
    free(avg_iou_per_class);
    free(tp_for_thresh_per_class);
    free(fp_for_thresh_per_class);

    // Release the image path lists: the strings, the arrays and the list nodes.
    for (k = 0; k < m; ++k) free(paths[k]);
    free(paths);
    free_list(plist);
    if (plist_dif) {
        for (k = 0; k < dif_count; ++k) free(paths_dif[k]);
        free(paths_dif);
        free_list(plist_dif);
    }

    fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
    printf("\nSet -points flag:\n");
    printf(" `-points 101` for MS COCO \n");
    printf(" `-points 11` for PascalVOC 2007 (uncomment `difficult` in voc.data) \n");
    printf(" `-points 0` (AUC) for ImageNet, PascalVOC 2010-2012, your custom dataset\n");

    // free memory; names holds exactly names_size entries.
    free_ptrs((void**)names, names_size);
    free_list_contents_kvp(options);
    free_list(options);

    if (existing_net) {
        restore_network_recurrent_state(*existing_net);
    }
    else {
        free_network(net);
    }
    return mean_average_precision;
}
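/* Source article: "darknet计算mAP的代码详解" (a detailed walkthrough of darknet's mAP calculation code).
 * Page footer: latest recommended article published 2021-12-12 21:27:10. */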