Caffe Code Walkthrough 4: Plotting the P-R Curve

The P-R (precision-recall) curve gives an intuitive picture of how well a network classifies, so in this post we look at how to plot one.
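
As a quick reminder of what the two axes mean: precision is the fraction of detections that are correct, and recall is the fraction of ground-truth objects that are found. A tiny sketch with made-up counts:

# Precision and recall at one confidence threshold.
# tp, fp, fn are made-up example counts, not values from the script below.
tp, fp, fn = 80, 20, 40
precision = tp / float(tp + fp)   # correct detections / all detections    -> 0.8
recall    = tp / float(tp + fn)   # correct detections / all ground truth  -> 0.666...
print("precision = {}, recall = {}".format(precision, recall))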

'''
Author: Mr.K

You can use this script to plot the P-R curve.
Before running this script, you need to run the test script to get a test result file. 
The format of the test result file is as follows:

Class:     *** *** *** ... *** *** ***
Precision: *** *** *** ... *** *** ***
Recall:    *** *** *** ... *** *** ***
Accuracy:  *** *** *** ... *** *** ***
F1-score:  *** *** *** ... *** *** ***

This script only reads the precision and recall values from the test result file.
After running it, you will get one P-R curve image per class, plus one for the class average.
You will also get a plot result file, which contains the points that were used to plot the curves.
The format of the plot result file is as follows:

average Precision: *** *** *** ... *** *** ***
average Recall:    *** *** *** ... *** *** ***
average AP:        ***

(class name) Precision: *** *** *** ... *** *** ***
(class name) Recall:    *** *** *** ... *** *** ***
(class name) AP:        ***
            .
            .
            .
(class name) Precision: *** *** *** ... *** *** ***
(class name) Recall:    *** *** *** ... *** *** ***
(class name) AP:        ***

At the beginning of the file there are "average Precision", "average Recall" and "average AP" (mAP).
"average Precision" is the precision averaged over all classes, and "average Recall" is defined the same way.

In this file, AP means average precision, i.e. the area under the P-R curve,
so "average AP" is what is usually called mean AP (mAP).

'''

import numpy as np
import os
import shutil
import matplotlib.pyplot as plt
import ssd_VOC0712_test as test
from caffe.model_libs import *

####### you are supposed to run this script from the Caffe root directory (CAFFE_ROOT) #######

network_name = "ssd"      # your network name, edit to match your setup
dataset_name = "VOC0712"  # your dataset name, edit to match your setup
model_name = "{}_{}".format(network_name,dataset_name)
job_dir = "examples/{}/{}".format(network_name, model_name)
prototxt_dir = "{}/prototxt".format(job_dir)
trainLog_dir = "{}/log".format(job_dir)
trainData_dir = "{}/data".format(job_dir)
trainModel_dir = "{}/model".format(job_dir)
result_dir = "{}/result".format(job_dir)

img_save_dir = "{}/PRcurve".format(result_dir)

confidence_list = np.arange(0,1,0.1)  # confidence thresholds 0.0, 0.1, ..., 0.9
count=0

plt_data = []   # plt_data[i] = [recall values, precision values] for class i
labelname = []  # class names; index 0 is reserved for the average over all classes
class_num = test.get_class_num()

precision = []
recall = []
accuracy = []
F1 = []

make_if_not_exist(img_save_dir)

# Pre-fill every recall/precision slot with 0.0, one entry per confidence threshold.
for i in range(class_num):
    temp0_data = []
    temp1_data = []
    temp2_data = []
    for j in range(len(confidence_list)):
        temp1_data.append(0.0)#recall
        temp2_data.append(0.0)#precision
    temp0_data.append(temp1_data)
    temp0_data.append(temp2_data)
    plt_data.append(temp0_data)
    labelname.append(0)
    del temp0_data
    del temp1_data
    del temp2_data
labelname[0]="average"

#os.system("python examples/{}/{}/{}_detect.py --confidence=0.01".format(network_name,model_name,model_name))

# Sweep the confidence threshold: run the test script once per threshold and
# collect the per-class precision and recall from its result file.
for conf in confidence_list:
    print("conf is {}".format(conf))
    os.system("python examples/{}/{}/{}_test.py --confidence={}".format(network_name,model_name,model_name,conf))
    # keep a copy of each per-threshold result file
    shutil.copyfile("{}/test/result.txt".format(result_dir),"{}/test/result_{}.txt".format(result_dir,count))
    temp_fp = open("{}/test/result.txt".format(result_dir),'r')
    while True:
        line = temp_fp.readline()
        if line != '':
            item = line.split()
            if item[0] == "Class:":
                for i in range(1,class_num):
                    labelname[i]=item[i]
            elif item[0] == "Precision:":
                for i in range(1,class_num):
                    plt_data[i][1][count]=float(item[i])
            elif item[0] == "Recall:":
                for i in range(1,class_num):
                    plt_data[i][0][count]=float(item[i])
            else:
                pass
        else:
            print("conf {} has finished !".format(conf))
            break
    temp_fp.close()
    count += 1


# Average recall and precision over all classes at each threshold; slot 0 of
# plt_data holds this average curve.
for i in range(count):
    sum_r=0.0
    sum_p=0.0
    for j in range(1,class_num):
        sum_r += plt_data[j][0][i]
        sum_p += plt_data[j][1][i]
    plt_data[0][0][i]=sum_r/float(class_num-1)
    plt_data[0][1][i]=sum_p/float(class_num-1)

# Approximate AP for each class as the area under its P-R curve: trapezoids
# between consecutive sweep points, plus a closing rectangle from recall 0 up
# to the smallest measured recall.
area = [0.0 for i in range(class_num)]
for i in range(class_num):
    for j in range(count-1):
        r1 = plt_data[i][0][j]
        r2 = plt_data[i][0][j+1]
        p1 = plt_data[i][1][j]
        p2 = plt_data[i][1][j+1]
        area[i] += (r1-r2)*((p2-p1)/2.0+p1)
    area[i] += plt_data[i][0][count-1]*plt_data[i][1][count-1]

#************************ save plot data ***********************#
result_fp = open("{}/result.txt".format(img_save_dir),'w')

for i in range(class_num):
    result_fp.write("{} Precision: ".format(labelname[i]))
    for j in range(count):
        result_fp.write(str(float('%.3f'%plt_data[i][1][j])))
        result_fp.write(" ")
    result_fp.write("\n")

    result_fp.write("{} Recall: ".format(labelname[i]))
    for j in range(count):
        result_fp.write(str(float('%.3f'%plt_data[i][0][j])))
        result_fp.write(" ")
    result_fp.write("\n")

    result_fp.write("{} AP: ".format(labelname[i]))
    result_fp.write(str(float('%.3f'%area[i])))
    result_fp.write("\n")

result_fp.close()

#*********************** save P-R curve images ***************#
for i in range(class_num):
    plt.cla()
    plt.plot(plt_data[i][0],plt_data[i][1])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("P-R curve for {}".format(labelname[i]))
    #plt.show()
    plt.savefig("{}/pr_curve_{}.png".format(img_save_dir, labelname[i]))

The script starts with some routine initialization that needs no special comment; the interesting part begins at the line for conf in confidence_list:.

Inside this for loop the test script is called over and over while the confidence threshold is varied. The intent is to step the threshold from 0 to 1 in increments of 0.1; note that np.arange(0, 1, 0.1) actually stops at 0.9, so the sweep produces ten thresholds rather than eleven. Each threshold yields one precision/recall pair per class, so the P-R curve is simply the polyline drawn through those sweep points.
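
For reference, here is a minimal, self-contained sketch of the same AP approximation the script performs (trapezoids between the swept points plus a closing rectangle); the recall and precision lists are made-up example values, not output from the script:

# AP as the approximate area under the P-R curve, matching the script's formula.
# The lists are ordered from the lowest to the highest confidence threshold,
# so recall decreases along them.
def approx_ap(recall, precision):
    ap = 0.0
    for j in range(len(recall) - 1):
        # trapezoid between two neighbouring sweep points
        ap += (recall[j] - recall[j+1]) * (precision[j] + precision[j+1]) / 2.0
    # closing rectangle from recall 0 up to the smallest measured recall
    ap += recall[-1] * precision[-1]
    return ap

recall    = [0.95, 0.90, 0.80, 0.60, 0.30]   # made-up example values
precision = [0.50, 0.60, 0.75, 0.85, 0.95]
print(approx_ap(recall, precision))          # prints roughly 0.81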
