Caffe Code Walkthrough 3: The Class Test Script

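In the previous article, Caffe Code Walkthrough 2: The Detection Script, we looked at how to run object detection. In this article we look at how well the trained network actually performs. Evaluating an object detector splits into two parts, class evaluation and localization evaluation. This article covers class evaluation first.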

Before getting to the class metrics, here are a few definitions:

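TP (True Positive): correctly marked positive. The label says it is this object, and the prediction agrees.
FP (False Positive): incorrectly marked positive. The label says it is not this object, but the prediction says it is.
FN (False Negative): incorrectly marked negative. The label says it is this object, but the prediction says it is not.
TN (True Negative): correctly marked negative. The label says it is not this object, and the prediction also says it is not.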

Next, the metrics that are used:

Precision        P = TP / (TP+FP)

Recall           R = TP / (TP+FN)

Accuracy         A = (TP+TN) / (TP+FP+FN+TN)

F1 score         F1 = 2*P*R / (P+R)
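To make the four formulas concrete, here is a minimal sketch that evaluates them from raw counts. The function name `classification_metrics` and the choice to return `float("nan")` for an undefined ratio are assumptions made for this illustration, not something taken from the script in this article.

```python
def classification_metrics(tp, fp, fn, tn=0):
    """Return (precision, recall, accuracy, f1) from raw counts.

    Any ratio whose denominator is zero is returned as float("nan").
    """
    nan = float("nan")
    precision = tp / float(tp + fp) if tp + fp else nan
    recall = tp / float(tp + fn) if tp + fn else nan
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / float(total) if total else nan
    # A comparison with nan is always False, so nan inputs fall through to nan.
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = nan
    return precision, recall, accuracy, f1


# Hypothetical counts for one class: 8 hits, 2 false alarms, 2 misses.
print([round(m, 3) for m in classification_metrics(8, 2, 2)])
# prints [0.8, 0.8, 0.667, 0.8]
```

With TN left at 0, accuracy reduces to TP / (TP+FP+FN), which is why it comes out lower than precision and recall in this example.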

The program I walk through next computes these four metrics.

'''
Author: Mr.K

You can use this script to test a network performance.
Before running this test script, you need to run the detect script to get a detection result.
The format of a detection result file is as follows:

file_name: the/image/path/***.jpg
param_name: xmin ymin xmax ymax label_id confidence label_name
object:    **** **** **** **** ******** ********** **********
   .        .    .    .    .       .         .          .
   .        .    .    .    .       .         .          .
   .        .    .    .    .       .         .          .
object:    **** **** **** **** ******** ********** **********

This is the format of one image. If you have detected many images (a dataset), just repeat it in your detection result file.
The parser below also expects a line starting with "---" after the last object of each image.
After running this script, you will get a test result file containing the precision, recall, accuracy and F1-score of every class.
The format of a test result file is as follows:

Class:     *** *** *** ... *** *** ***
Precision: *** *** *** ... *** *** ***
Recall:    *** *** *** ... *** *** ***
Accuracy:  *** *** *** ... *** *** ***
F1-score:  *** *** *** ... *** *** ***

'''

import xml.etree.ElementTree as ET
import os
import caffe
import argparse
from google.protobuf import text_format
from caffe.proto import caffe_pb2
from caffe.model_libs import *
####### you are supposed to run this script at the CAFFE ROOT ########################

network_name = "ssd"  #your neural net name, need to config
dataset_name = "VOC0712"#your dataset name, need to config
model_name = "{}_{}".format(network_name,dataset_name)
job_dir = "examples/{}/{}".format(network_name, model_name)
prototxt_dir = "{}/prototxt".format(job_dir)
trainLog_dir = "{}/log".format(job_dir)
trainData_dir = "{}/data".format(job_dir)
trainModel_dir = "{}/model".format(job_dir)

source_file = "{}/result/detect/result.txt".format(job_dir)
annotation_dir = "/home/kangyi/data/VOCdevkit/VOC2007/Annotations"
labelmap_file = "{}/labelmap_voc.prototxt".format(trainData_dir)
result_save_dir = "{}/result/test".format(job_dir)
save_result = True

count=0
gt_bboxs=[]
pre_bbox=[]
label_pairs=[]

def get_class_num():
    lm_f = open(labelmap_file, 'r')
    labelmap = caffe_pb2.LabelMap()
    text_format.Merge(str(lm_f.read()), labelmap)
    class_num = len(labelmap.item)
    lm_f.close()
    return class_num

def get_label_id(labelname):
    for i in range(class_num):
        if labelname == label_pairs[i][1]:
            return label_pairs[i][0]
    return -1

def get_label_name(label_id):
    for i in range(class_num):
        if label_id == int(label_pairs[i][0]):
            return label_pairs[i][1]
    return "None"

def computeIOU(box1, box2):
    # box:[xmin,ymin,xmax,ymax]
    in_h = min(box1[3], box2[3]) - max(box1[1], box2[1]) 
    in_w = min(box1[2], box2[2]) - max(box1[0], box2[0]) 
    inter = 0 if in_h<0 or in_w<0 else in_h*in_w 
    union = (box1[3]-box1[1])*(box1[2]-box1[0])+(box2[3]-box2[1])*(box2[2]-box2[0])-inter 
    iou = float(inter) / float(union) 
    return iou
        

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test a group of images and get precision, recall, accuracy and F1-score")
    parser.add_argument("--iou", default=0.5, type=float)
    parser.add_argument("--confidence", default=0.5, type=float)
    args = parser.parse_args()
    
    iou_thresh = args.iou
    conf_thresh = args.confidence

    lm_f = open(labelmap_file, 'r')
    labelmap = caffe_pb2.LabelMap()
    text_format.Merge(str(lm_f.read()), labelmap)
    class_num = len(labelmap.item)
    for i in range(class_num):
        label_pairs.append([labelmap.item[i].label,labelmap.item[i].name])
    lm_f.close()

    TP=[0 for i in range(class_num)]
    FP=[0 for i in range(class_num)]
    FN=[0 for i in range(class_num)]
    TN=[0 for i in range(class_num)]
    precision=[0 for i in range(class_num)]
    recall=[0 for i in range(class_num)]
    accuracy=[0 for i in range(class_num)]
    F1=[0 for i in range(class_num)]

    source_f = open(source_file,'r')
    while True:
        line = source_f.readline()
        if line != '':
            item = line.split()
            if item[0] == "file_name:":
                object_cnt = 0
                del gt_bboxs[:]
                basename = os.path.basename(item[1])
                updateTree = ET.parse("{}/{}.xml".format(annotation_dir,os.path.splitext(basename)[0]))
                root = updateTree.getroot()
                for object in root.findall("object"):
                # bbox list format:
                # xmin ymin xmax ymax label_id label_name matched_flag
                    bb_temp = []
                    bb_temp.append(int(object.find("bndbox/xmin").text))
                    bb_temp.append(int(object.find("bndbox/ymin").text))
                    bb_temp.append(int(object.find("bndbox/xmax").text))
                    bb_temp.append(int(object.find("bndbox/ymax").text))
                    bb_temp.append(get_label_id(object.find("name").text))
                    bb_temp.append(object.find("name").text)
                    bb_temp.append(0) # a flag, if it is matched
                    gt_bboxs.append(bb_temp)
                    object_cnt=object_cnt+1
                    del bb_temp
                #print(gt_bboxs)
            elif item[0] == "param_name:":
                pass
                #print("parse param")
            elif item[0] == "object:":
                # result.txt file format:
                # xmin ymin xmax ymax label_id confidence label_name
                del pre_bbox[:]
                pre_bbox.append(int(item[1]))
                pre_bbox.append(int(item[2]))
                pre_bbox.append(int(item[3]))
                pre_bbox.append(int(item[4]))
                pre_bbox.append(int(item[5]))#label id
                pre_bbox.append(item[7])#label name
                pre_bbox.append(0)
                if float(item[6]) > conf_thresh:
                    for gt_bbox in gt_bboxs:
                        if gt_bbox[5] == pre_bbox[5]:
                            iou = computeIOU(gt_bbox,pre_bbox)
                            #print("{} iou is : {}".format(gt_bbox[5],iou))
                            if iou > iou_thresh:
                                TP[pre_bbox[4]]=TP[pre_bbox[4]]+1
                                #for i in range(class_num):
                                #    TN[i]=TN[i]+1
                                #TN[pre_bbox[4]]=TN[pre_bbox[4]]-1
                                gt_bbox[6] = 1
                                pre_bbox[6] = 1
                                break
                    if pre_bbox[6] == 0:
                        FP[pre_bbox[4]]=FP[pre_bbox[4]]+1
            elif item[0] == "---":
                count=count+1
                if count%1000 == 0:
                    print("Processed {} images...".format(count))
                for gt_bbox in gt_bboxs:
                    if gt_bbox[6] == 0:
                        FN[gt_bbox[4]]=FN[gt_bbox[4]]+1
                        #for i in range(class_num):
                        #    TN[i]=TN[i]+1
                        #TN[gt_bbox[4]]=TN[gt_bbox[4]]-1
                #print("gt_bboxs : {}".format(gt_bboxs))
                #print("pre_bbox : {}".format(pre_bbox))
                #print("TP : {}".format(TP))
                #print("FP : {}".format(FP))
                #print("FN : {}".format(FN))
                #print("TN : {}".format(TN))
            else:
                print("Incorrect format! item[0] is: {}".format(item[0]))
        else:
            if count%1000 != 0:
                print("Processed {} images...".format(count))
            print("finished!")
            break
    source_f.close()

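    for i in range(1,class_num):
        # Undefined ratios are stored as float("nan") rather than the string "NaN",
        # so the F1 arithmetic and the '%.3f' formatting below do not raise TypeError.
        if TP[i]+FP[i] != 0:
            precision[i]=float(TP[i])/float(TP[i]+FP[i])
        else:
            precision[i]=float("nan")
        if TP[i]+FN[i] != 0:
            recall[i]=float(TP[i])/float(TP[i]+FN[i])
        else:
            recall[i]=float("nan")
        # TN is never accumulated above (it stays 0), so this is TP/(TP+FP+FN).
        if TP[i]+FP[i]+FN[i]+TN[i] != 0:
            accuracy[i] = float(TP[i]+TN[i])/float(TP[i]+FP[i]+FN[i]+TN[i])
        else:
            accuracy[i] = float("nan")
        # If either operand is nan the test is True and F1 becomes nan as well.
        if precision[i]+recall[i] != 0:
            F1[i]=float(2*precision[i]*recall[i])/float(precision[i]+recall[i])
        else:
            F1[i]=float("nan")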

        print("{} P is: {}".format(get_label_name(i),precision[i]))
        print("{} R is: {}".format(get_label_name(i),recall[i]))
        print("{} A is: {}".format(get_label_name(i),accuracy[i]))
        print("{} F is: {}".format(get_label_name(i),F1[i]))

    if save_result:
        make_if_not_exist(result_save_dir)
        result_save_file = "{}/result.txt".format(result_save_dir)
        result_fp = open(result_save_file, 'w')

        result_fp.write("Class: ")
        for i in range(1,class_num):
            result_fp.write(label_pairs[i][1])
            result_fp.write(" ")
        result_fp.write("\n")

        result_fp.write("Precision: ")
        for i in range(1,class_num):
            result_fp.write("{}".format(float('%.3f'%precision[i])))
            result_fp.write(" ")
        result_fp.write("\n")

        result_fp.write("Recall: ")
        for i in range(1,class_num):
            result_fp.write("{}".format(float('%.3f'%recall[i])))
            result_fp.write(" ")
        result_fp.write("\n")

        result_fp.write("Accuracy: ")
        for i in range(1,class_num):
            result_fp.write("{}".format(float('%.3f'%accuracy[i])))
            result_fp.write(" ")
        result_fp.write("\n")

        result_fp.write("F1-score: ")
        for i in range(1,class_num):
            result_fp.write("{}".format(float('%.3f'%F1[i])))
            result_fp.write(" ")
        result_fp.write("\n")

        result_fp.close()

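This script evaluates a set of images, i.e. the test set, and outputs the four metrics described above. All of this, including the format of the output file, is documented in the comment at the top of the script.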

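At startup the confidence threshold defaults to 0.5, and the IoU threshold also defaults to 0.5. We normally leave these two parameters alone, so why make them adjustable? Because later tests use the P-R curve, and plotting a P-R curve requires varying the confidence threshold.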
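As a rough illustration of why the confidence threshold has to be adjustable, the sketch below sweeps it and records one (precision, recall) point per value. It assumes every detection of a class has already been flagged as correct or incorrect by IoU matching. The function `pr_points` and the sample data are hypothetical. With the script from this article, the equivalent would be to rerun it with different `--confidence` values.

```python
def pr_points(detections, num_gt, thresholds):
    """Compute P-R points for one class.

    detections: list of (confidence, is_correct) pairs.
    num_gt:     number of ground-truth objects of that class.
    Returns a list of (threshold, precision, recall) tuples.
    """
    nan = float("nan")
    points = []
    for t in thresholds:
        # Keep only detections above the current confidence threshold.
        kept = [ok for conf, ok in detections if conf > t]
        tp = sum(1 for ok in kept if ok)
        precision = tp / float(len(kept)) if kept else nan
        recall = tp / float(num_gt) if num_gt else nan
        points.append((t, precision, recall))
    return points


# Hypothetical detections for one class, with 4 ground-truth objects.
dets = [(0.9, True), (0.8, True), (0.6, False), (0.4, True), (0.2, False)]
for t, p, r in pr_points(dets, 4, [0.1, 0.5, 0.85]):
    print("conf > {:.2f}: P = {:.3f}, R = {:.3f}".format(t, p, r))
```

Raising the threshold discards low-confidence detections, which in this sample pushes precision up while recall drops. Plotting those pairs gives the P-R curve.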

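Next, the script reads the detection results. In other words, the detection script must be run before the test script so that the detection results exist. The test script then compares the detection results against the annotation files to produce the test results.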
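The comparison step can be sketched for a single image as follows. The helper names `iou` and `count_image` and the sample boxes are made up for this illustration. One deliberate difference from the script above: this sketch lets each ground-truth box be claimed by at most one detection, a common convention, whereas the script does not check whether a ground-truth box was already matched.

```python
def iou(a, b):
    """IoU of two boxes given as [xmin, ymin, xmax, ymax]."""
    iw = min(a[2], b[2]) - max(a[0], b[0])
    ih = min(a[3], b[3]) - max(a[1], b[1])
    inter = iw * ih if iw > 0 and ih > 0 else 0
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / float(union) if union > 0 else 0.0


def count_image(preds, gts, iou_thresh=0.5):
    """Count (tp, fp, fn) for one image, summed over all classes.

    preds and gts are lists of (box, label_name) pairs.
    """
    matched = [False] * len(gts)
    tp = fp = 0
    for p_box, p_label in preds:
        hit = False
        for k, (g_box, g_label) in enumerate(gts):
            # Same class, enough overlap, and the ground truth is still free.
            if (not matched[k] and g_label == p_label
                    and iou(p_box, g_box) > iou_thresh):
                matched[k] = True
                hit = True
                break
        if hit:
            tp += 1
        else:
            fp += 1
    fn = matched.count(False)  # ground-truth boxes nobody matched
    return tp, fp, fn


gts = [([0, 0, 10, 10], "cat"), ([20, 20, 40, 40], "dog")]
preds = [([1, 1, 10, 10], "cat"), ([50, 50, 60, 60], "dog")]
print(count_image(preds, gts))  # prints (1, 1, 1)
```

Here the "cat" detection overlaps its label with an IoU of 0.81 and counts as a TP, the "dog" detection misses its label entirely and counts as an FP, and the unmatched "dog" label counts as an FN.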

The program is simple, so I won't go into more detail.
