# PR、ROC曲线的绘制及AP计算 (plotting PR/ROC curves and computing AP)

import numpy as np
import matplotlib.pyplot as plt

# Classifier scores and ground-truth labels, one entry per sample
# (presumably 1-D arrays of equal length — TODO confirm).
predicts = np.load('/home/lixuan/workspace/project/recognition/cigarette_shelve_or_not/predict.npy')
labels = np.load('/home/lixuan/workspace/project/recognition/cigarette_shelve_or_not/label.npy')

# ROC coordinates (FPR, TPR); only filled by the commented-out ROC code below.
x = []
y = []
for thr in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    # Binarize scores and labels at the same threshold. The original dtypes are
    # kept so the resulting 0/1 arrays can be reused by the ROC code below.
    output = (predicts >= thr).astype(predicts.dtype)
    label = (labels >= thr).astype(labels.dtype)

    # Accuracy: fraction of samples whose binarized prediction equals the label.
    # np.count_nonzero runs in C instead of iterating the array in Python.
    accuracy = np.count_nonzero(output == label) / len(output)
    print(accuracy)  # 0.8898191560616209(0.8716437459070072),0.9152712659075687(0.8948919449901768)

#     TP = 0
#     FP = 0
#     TN = 0
#     FN = 0
#     label = np.abs(label - 1)
#     output = np.abs(output - 1)
#     for idx,real in enumerate(label):
#         predict = output[idx]
#         if predict == 1 and real == 1:
#             TP += 1
#         if predict == 1 and real == 0:
#             FP += 1
#         if predict == 0 and real == 1:
#             FN += 1
#         if predict == 0 and real == 0:
#             TN += 1
#     x.append(FP / (FP + TN))
#     y.append(TP / (TP + FN))
#     # print(thr,TP / len(labels),FP / 880)
#     # print(thr,sum(output == label) / len(predicts),(len(output) - sum(output)) / (len(label) - sum(label)))
#
# plt.plot(x,y)
# plt.plot([x[0],x[-1]],[y[0],y[-1]],linestyle='dashed')
# plt.title('ROC curve')
# plt.xlabel('False Positive Rate')
# plt.ylabel('True Positive Rate')
# plt.legend(['ROC curve'],bbox_to_anchor=(1,0.1))
# plt.show()
"""
计算mAP (compute precision/recall, ROC rates and VOC-style AP)
"""
import numpy as np
import matplotlib.pyplot as plt

import os

# Line i of NAMES_FILE names the image whose detections are predicts[i].
# Each image's ground truth is <LABEL_DIR>/<stem>.txt, one box per line whose
# first space-separated field is an integer (presumably the class id).
NAMES_FILE = 'names.txt'
LABEL_DIR = '/home/lixuan/workspace/dataset/zhejiangtxt'
# Confidence thresholds; one prediction file per (rounded) threshold.
THRESHOLDS = np.linspace(0.05, 0.8, 16)


def _safe_div(num, den):
    """Return num / den, or 0.0 when den is 0 (e.g. no predictions at a threshold)."""
    return num / den if den else 0.0


def _count_gt_boxes(names_file=NAMES_FILE, label_dir=LABEL_DIR):
    """Return the number of ground-truth boxes for every line of *names_file*.

    The label file is ``<label_dir>/<stem>.txt``, where only the image's
    extension is replaced. A missing label file means the image has no boxes.
    """
    counts = []
    with open(names_file) as f:
        for line in f:
            stem = os.path.splitext(line.strip())[0]
            boxes = []
            try:
                with open('{}/{}.txt'.format(label_dir, stem)) as fr:
                    for row in fr:
                        row = row.strip()
                        if row:  # tolerate blank lines, e.g. a trailing empty line
                            boxes.append(int(row.split(' ')[0]))
            except FileNotFoundError:
                pass  # no label file: the image has no ground-truth boxes
            counts.append(len(boxes))
    return counts


def _confusion(predicts, gt_counts):
    """Count (TP, FP, TN, FN) by comparing per-image box *counts* only.

    No IoU matching is done. min(n_gt, n_pred) boxes count as TP, surplus
    predictions as FP and missing ones as FN. An image with neither ground
    truth nor predictions counts as one TN.
    """
    tp = fp = tn = fn = 0
    for index, n_gt in enumerate(gt_counts):
        n_pred = len(predicts[index])
        tp += min(n_gt, n_pred)
        fp += max(n_pred - n_gt, 0)
        fn += max(n_gt - n_pred, 0)
        if n_gt == 0 and n_pred == 0:
            tn += 1
    return tp, fp, tn, fn


def _sweep(pred_dir, gt_counts, thresholds=THRESHOLDS):
    """Return (precision, recall, fpr, tpr) lists with one entry per threshold.

    ``<pred_dir>/<round(thr, 2)>.npy`` presumably holds an object array with
    one sequence of predicted boxes per image, in names.txt order.
    """
    precisions, recalls, fprs, tprs = [], [], [], []
    for thr in thresholds:
        # NOTE: allow_pickle=True unpickles the file; only load trusted results.
        predicts = np.load('{}/{}.npy'.format(pred_dir, round(thr, 2)), allow_pickle=True)
        tp, fp, tn, fn = _confusion(predicts, gt_counts)
        recall = _safe_div(tp, tp + fn)
        precisions.append(_safe_div(tp, tp + fp))
        recalls.append(recall)
        fprs.append(_safe_div(fp, fp + tn))
        tprs.append(recall)  # TPR is the same quantity as recall
    return precisions, recalls, fprs, tprs


# Ground truth does not depend on the threshold or the model: read it once.
_gt_counts = _count_gt_boxes()
# p* = precision, r* = recall, matching the voc_ap(rec, prec) argument order.
p1, r1, fpr1, tpr1 = _sweep('/home/lixuan/SunloginFiles/zjzy_scenes_detection/best3', _gt_counts)
p2, r2, fpr2, tpr2 = _sweep('/home/lixuan/SunloginFiles/zjzy_scenes_detection/best2', _gt_counts)

def voc_ap(rec, prec, use_07_metric=False):
    """ap = voc_ap(rec, prec, [use_07_metric])

    Compute VOC average precision (AP) from paired recall/precision points.

    Args:
        rec: recall values (list or 1-D array). The points need not be sorted;
            they are ordered by ascending recall before integration. Already
            ascending input is left untouched by the stable sort.
        prec: precision values, same length as ``rec``.
        use_07_metric: if True, use the VOC2007 11-point interpolation.
            Default False: exact area under the interpolated PR curve.

    Returns:
        The AP value.
    """
    # Accept plain lists too (the 11-point path compares ``rec >= t``).
    rec = np.asarray(rec, dtype=float)
    prec = np.asarray(prec, dtype=float)
    # A threshold sweep yields recall in descending order as the threshold
    # rises. The area computation below requires ascending recall.
    order = np.argsort(rec, kind='stable')
    rec = rec[order]
    prec = prec[order]

    if use_07_metric:
        # VOC2007 11-point metric (no longer the default): average, over
        # t in {0, 0.1, ..., 1.0}, of the max precision at recall >= t.
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            mask = rec >= t
            p = np.max(prec[mask]) if mask.any() else 0
            ap = ap + p / 11.
    else:
        # Sentinel points at recall 0 and 1.
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # Precision envelope: each point takes the max precision at any
        # recall to its right (a running max from the end).
        mpre = np.maximum.accumulate(mpre[::-1])[::-1]

        # Integrate only where recall (the X axis) actually changes value.
        i = np.where(mrec[1:] != mrec[:-1])[0]
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

# ap1 = voc_ap(r1, p1)#0.9435914828900585 0.9800751903020692(0.4) 0.9678572170490943 2、建议阈值0.3
# ap2 = voc_ap(r2,p2)
# print(ap1,ap2)
# plt.plot(p1,r1)
# plt.plot(p2,r2)
# plt.title('PR curve')
# plt.xlabel('')
# plt.ylabel('True Positive Rate')
# plt.legend(['ROC curve'],bbox_to_anchor=(1,0.1))
# One row per threshold: FPR of the best3 sweep, FPR of the best2 sweep,
# and TPR of the best3 sweep.
for row, rate in enumerate(fpr1):
    other_rate = fpr2[row]
    hit_rate = tpr1[row]
    print(rate, other_rate, hit_rate)
# plt.plot(fpr1,tpr1)
# plt.plot(fpr2,tpr2)
# plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值