关于Mask R-CNN 画PR曲线

最近太多人问我如何绘制PR曲线了,我又很少及时看到你们的消息,在这里跟大家道个歉,我直接把代码贴出来,你们看着改参数就好。

################ 导入相关包 #####################
import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
ROOT_DIR = os.path.abspath("../../")
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn import utils
import mrcnn.model as modellib
from samples.hpv import hpv    # 这里是我自己写的脚本  继承的参数 一般是nucleus继承过来的
##############  配置参数  ####
LOGS_DIR = os.path.join(ROOT_DIR, "logs")
DATASET_DIR = os.path.join(ROOT_DIR, "datasets/hpv")   #  数据集
config = hpv.NucleusInferenceConfig()
DEVICE = "/cpu:0"
TEST_MODE = "inference"
def get_ax(rows=1, cols=1, size=16):
    fig, ax = plt.subplots(rows, cols, figsize=(size * cols, size * rows))
    fig.tight_layout()
    return ax
def text_save(filename, data):#filename为写入CSV文件的路径,data为要写入数据列表.
    file = open(filename, 'a')
    for i in range(len(data)):
        s = str(data[i]).replace('[','').replace(']','')#去除[],这两行按数据不同,可以选择
        s = s.replace("'",'').replace(',','') +'\n'   #去除单引号,逗号,每行末尾追加换行符
        file.write(s)
    file.close()
    print("保存txt文件成功")

#####  加载测试集数据  #####
dataset = hpv.NucleusDataset()
dataset.load_nucleus(DATASET_DIR, "stage1_test")
dataset.prepare()
print("Images: {}\nClasses: {}".format(len(dataset.image_ids), dataset.class_names))

#####  导入模型  ####
with tf.device(DEVICE):
    model = modellib.MaskRCNN(mode="inference", model_dir=LOGS_DIR, config=config)
weights_path = "/Mask_RCNN/logs/model1-120211011T1528/mask_rcnn_model1-1_0300.h5"
model.load_weights(weights_path, by_name=True)
image_ids = dataset.image_ids

APs = []
count1 = 0
for image_id in image_ids:
    info = dataset.image_info[image_id]
    print("image_id: ", image_id)
    # ####重要步骤:获得测试图片的信息
    image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(dataset, config, image_id, use_mini_mask=False)
    # ###保存实际结果
    if count1 == 0:
        save_box, save_class, save_mask = gt_bbox, gt_class_id, gt_mask
    else:
        save_box = np.concatenate((save_box, gt_bbox), axis=0)
        save_class = np.concatenate((save_class, gt_class_id), axis=0)
        save_mask = np.concatenate((save_mask, gt_mask), axis=2)
    molded_images = np.expand_dims(modellib.mold_image(image, config), 0)
    # # 显示检测结果
    # results = model.detect_molded(np.expand_dims(image, 0), np.expand_dims(image_meta, 0), verbose=1)
    results = model.detect_molded(np.expand_dims(image, 0), np.expand_dims(image_meta, 0), verbose=1)
    r = results[0]
    # 保存预测结果
    if count1 == 0:
        save_roi, save_id, save_score, save_m = r["rois"], r["class_ids"], r["scores"], r['masks']
    else:
        save_roi = np.concatenate((save_roi, r["rois"]), axis=0)
        save_id = np.concatenate((save_id, r["class_ids"]), axis=0)
        save_score = np.concatenate((save_score, r["scores"]), axis=0)
        save_m = np.concatenate((save_m, r['masks']), axis=2)
    count1 += 1
# AP, precisions, recalls, overlaps = utils.compute_ap(gt_bbox, gt_class_id, gt_mask, r['rois'], r['class_ids'], r['scores'], r['masks'])
# APs.append(AP)


# # 在阈值0.5到0.95之间每隔0.1显示AP值
# utils.compute_ap_range(gt_bbox_all, gt_class_id_all, gt_mask_all, pre_rois_all, pre_class_ids_all, pre_scores_all, pre_masks_all, verbose=1)
## 在图片中显示真实与预测之间的差异
# visualize.display_differences(image, gt_bbox, gt_class_id, gt_mask, r['rois'], r['class_ids'], r['scores'], r['masks'],
#                               dataset.class_names, ax=get_ax(), show_box=False, show_mask=False, iou_threshold=0.5, score_threshold=0.5)
# plt.show()

# ######绘制PR曲线######

AP, precisions, recalls, overlaps = \
        utils.compute_ap(save_box, save_class, save_mask,
                         save_roi, save_id, save_score, save_m)
print("AP: ", AP)
# print("mAP: ", np.mean(APs))

plt.plot(recalls, precisions, 'b', label='PR')
plt.title('Precision-Recall Curve. AP@50 = {:.3f}'.format(AP))
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend()
plt.show()
text_save('preci-model1.txt', precisions)
text_save('recall-model1.txt', recalls)

这个方法十分吃虚拟内存,就是这个脚本在哪个盘运行,就要设置大量虚拟内存,一般20图片挺快的,后来尝试修改代码,但是识别的结果有点不一样,后来也就没有继续研究了。
第一次写这个,不足之处请见谅。
在这里插入图片描述

  • 2
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值