画 FROC 曲线的代码,供参考。

import pandas as pd
import numpy as np 
import json
import matplotlib.pyplot as plt
from decimal import Decimal
from matplotlib.ticker import FixedFormatter
import os

# ---- Input / output locations (placeholders: replace before running) ----
# Detection results CSV. The __main__ flow joins it to image file names via
# get_file_name (a merge on the columns shared with the JSON-derived table,
# presumably image_id) and compute_froc later reads its 'score' and
# bbox_x1/bbox_y1/bbox_x2/bbox_y2 columns.
result_csv_file_path = 'xxx.csv'
# Crop table consumed by map_back (File_name, bbox_id, bbox_x1..bbox_y2).
crop_csv_file_path = 'yyy.csv'
# Ground-truth CSV with File_name and Bounding_boxes ("x1,y1,x2,y2") columns;
# presumably DeepLesion's DL_info.csv -- confirm.
GT_csv_file_path = 'xxx/DL_info.csv'
# Directory into which compute_froc saves froc.png.
output_froc_path = 'output_csv_dir'
# COCO-style annotation JSON; get_file_name reads images[*]['id'] and ['file_name'].
json_file_path = 'xxx/cocostyledataset/annotation/annotation.json'

FROC_minX = 0.5  # Minimum value of x-axis of FROC curve (also start of the plotted FP grid)
FROC_maxX = 16  # Maximum value of x-axis of FROC curve (also end of the plotted FP grid)
bLogPlot = True  # If True, draw the x-axis on a base-2 log scale

# NOTE(review): these module-level reads run at import time, so both files must
# exist merely to import this module. They are also redundant: the __main__
# block re-reads both files, and compute_froc receives GT_df/RS_df as parameters.
GT_df = pd.read_csv(GT_csv_file_path)
RS_df = pd.read_csv(result_csv_file_path)

def iou(dt, gt):
    """Return the intersection-over-union of a detection box and a GT box.

    ``dt`` must expose ``bbox_x1``, ``bbox_y1``, ``bbox_x2`` and ``bbox_y2``
    (for example a DataFrame row). ``gt`` must expose ``Bounding_boxes`` as a
    comma-separated string whose first four values are ``x1,y1,x2,y2``. The
    corner order inside either box does not matter.

    Returns 0 when the boxes do not overlap with a positive area; otherwise
    the overlap area divided by the union area.
    """
    def _ordered(xa, ya, xb, yb):
        # Put the corners in (left, top, right, bottom) order.
        return min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb)

    d_left, d_top, d_right, d_bottom = _ordered(
        dt['bbox_x1'], dt['bbox_y1'], dt['bbox_x2'], dt['bbox_y2'])
    coords = [float(token) for token in gt['Bounding_boxes'].split(',')]
    g_left, g_top, g_right, g_bottom = _ordered(
        coords[0], coords[1], coords[2], coords[3])

    overlap_w = min(d_right, g_right) - max(d_left, g_left)
    overlap_h = min(d_bottom, g_bottom) - max(d_top, g_top)
    # Boxes that only touch (or are disjoint) share no area.
    if overlap_w <= 0 or overlap_h <= 0:
        return 0

    det_area = (d_right - d_left) * (d_bottom - d_top)
    gt_area = (g_right - g_left) * (g_bottom - g_top)
    overlap = overlap_h * overlap_w
    return overlap / (det_area + gt_area - overlap)



def compute_froc(GT_df, RS_df, iou_thresh=0.5, fp_points=(0.5, 1, 2, 4, 8, 16)):
    """Compute the FROC curve of detections against ground truth and plot it.

    Detections are processed in descending ``score`` order. Each one is
    compared (with ``iou``) to every GT row that has the same ``File_name``:

    - If its best IoU reaches ``iou_thresh`` and that GT is still unclaimed, it
      is a true positive.
    - If its best IoU is below ``iou_thresh``, it is a false positive.
    - If its best GT was already claimed by a higher-scoring detection, it is
      skipped (neither TP nor FP).

    Sensitivity is TPs divided by the number of GT rows found in images that
    received at least one detection. The FP rate is FPs divided by the number
    of images that received at least one detection.
    NOTE(review): images with GT but no detections are not counted in either
    denominator -- confirm this matches the intended protocol.

    Args:
        GT_df: ground-truth table with ``File_name`` and ``Bounding_boxes``
            (``"x1,y1,x2,y2"``) columns.
        RS_df: detection table with ``File_name``, ``score`` and
            ``bbox_x1``/``bbox_y1``/``bbox_x2``/``bbox_y2`` columns.
        iou_thresh: IoU needed for a hit; must be in (0, 1].
        fp_points: false-positive-per-image rates at which sensitivity is
            reported and averaged; also used as the x-axis ticks.

    Returns:
        The mean sensitivity over ``fp_points``.

    Side effects:
        Prints per-point sensitivities and their average. Saves ``froc.png``
        into ``output_froc_path`` (creating the directory if needed).

    Raises:
        ValueError: if ``iou_thresh`` is out of range, ``RS_df`` is empty, or
            no GT row shares an image with any detection.
    """
    if not 0 < iou_thresh <= 1:
        raise ValueError("iou_thresh must be in (0, 1], got %r" % (iou_thresh,))
    fp_axis, sens_axis = _froc_operating_points(GT_df, RS_df, iou_thresh)

    # Evaluate exactly at the requested FP rates. The previous approach
    # sampled a 10001-point linspace and matched rounded values, which never
    # hit 1.000, so only five sensitivities were summed yet divided by six.
    fp_points = np.asarray(fp_points, dtype=float)
    sens_at_points = _sensitivity_at(fp_axis, sens_axis, fp_points)
    for fp_rate, sens in zip(fp_points, sens_at_points):
        print("fps_itp %.2f ,sens_itp %.2f\n" % (fp_rate, sens))
    score = float(np.mean(sens_at_points))
    print("    Average sensitivity over %d fps : %.9f\n" % (len(fp_points), score))

    fps_itp = np.linspace(FROC_minX, FROC_maxX, num=10001)
    sens_itp = _sensitivity_at(fp_axis, sens_axis, fps_itp)
    _plot_froc(fps_itp, sens_itp, fp_points, score)
    return score


def _froc_operating_points(GT_df, RS_df, iou_thresh):
    """Greedily match detections to GT by score; return (FPs/image, sensitivity) arrays."""
    if len(RS_df) == 0:
        raise ValueError("RS_df contains no detections")
    # Stable sort so tied scores keep their input order (quicksort does not).
    ranked = RS_df.sort_values('score', ascending=False, kind='mergesort')
    # Group GT rows by image once instead of rescanning all of GT_df per
    # detection. The original index labels are kept, so they still identify GTs.
    gt_by_file = {name: rows for name, rows in GT_df.groupby('File_name')}

    matched_gt = set()  # GT index labels already claimed by a detection
    img_set = set()     # images that received at least one detection
    gt_set = set()      # GT rows living in those images
    tp_list, fp_list = [], []
    tps = fps = 0
    for _, dt in ranked.iterrows():
        img_set.add(dt['File_name'])
        best_iou, best_id = 0, -1
        gts = gt_by_file.get(dt['File_name'])
        if gts is not None:
            for gt_id, gt in gts.iterrows():
                gt_set.add(gt_id)
                cur = iou(dt, gt)
                if cur > best_iou:
                    best_iou, best_id = cur, gt_id

        if best_iou >= iou_thresh and best_id not in matched_gt:
            tps += 1
            matched_gt.add(best_id)
        elif best_iou < iou_thresh:
            fps += 1
        else:
            # Duplicate hit on an already-matched GT: ignored, not an FP.
            continue
        tp_list.append(tps)
        fp_list.append(fps)

    num_of_img = len(img_set)
    num_of_gt = len(gt_set)
    print(num_of_gt)
    if num_of_gt == 0:
        raise ValueError("no ground-truth box shares an image with any detection")
    sens_axis = np.array(tp_list, dtype=float) / num_of_gt
    fp_axis = np.array(fp_list, dtype=float) / num_of_img
    return fp_axis, sens_axis


def _sensitivity_at(fp_axis, sens_axis, fp_rates):
    """Return the sensitivity reached within each FP/image budget in ``fp_rates``."""
    # fp_axis is non-decreasing, so the last operating point whose FP rate is
    # <= r gives the sensitivity for budget r. Budgets below the first
    # operating point mean no detections were accepted, i.e. sensitivity 0.
    idx = np.searchsorted(fp_axis, fp_rates, side='right') - 1
    return np.where(idx >= 0, sens_axis[np.clip(idx, 0, None)], 0.0)


def _plot_froc(fps_itp, sens_itp, tick_fps, score):
    """Draw the FROC curve and save it as froc.png under output_froc_path."""
    fig, ax = plt.subplots()
    try:
        ax.plot(fps_itp, sens_itp, color='b', lw=2,
                label='FROC (avg sensitivity %.3f)' % score)
        if bLogPlot:
            try:
                ax.set_xscale('log', base=2)
            except (TypeError, ValueError):
                # Matplotlib < 3.3 only understands the old `basex` keyword.
                ax.set_xscale('log', basex=2)
        ax.set_xlim(FROC_minX, FROC_maxX)
        ax.set_ylim(0, 1)
        ax.set_xticks(tick_fps)
        if bLogPlot:
            # Log axes label ticks as powers of two by default; show plain numbers.
            ax.xaxis.set_major_formatter(FixedFormatter(['%g' % t for t in tick_fps]))
        ax.set_yticks(np.arange(0, 1.1, 0.1))
        ax.set_xlabel('Average number of false positives per scan')
        ax.set_ylabel('Sensitivity')
        ax.set_title('FROC performance ')
        ax.legend(loc='lower right')
        ax.grid(True, which='both')
        fig.tight_layout()
        if output_froc_path:
            os.makedirs(output_froc_path, exist_ok=True)
        fig.savefig(os.path.join(output_froc_path, "froc.png"), dpi=300)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def map_back(RS_df, crop_df):
    """Translate crop-relative detections back to original-image coordinates.

    Each ``crop_df`` row describes one box in an original image (``File_name``,
    ``bbox_id``, ``bbox_x1..bbox_y2``). The matching crop is named
    ``<File_name minus its last 4 chars>_<bbox_id>.png``. Its top-left corner is
    taken as (centre_x - 2*width, centre_y - 2*height), truncated with ``int``
    and clamped at 0.

    Detections in ``RS_df`` are inner-merged with these crops on the shared
    columns (normally ``File_name``). They are then shifted by the crop origin
    and renamed to the original image: everything before the last ``_`` plus
    ``.png``.

    NOTE(review): the crop-origin formula must mirror whatever code produced
    the crops; that cannot be checked from this file.

    Returns:
        The merged DataFrame with shifted ``bbox_*`` columns, original-image
        ``File_name`` values and the helper ``original_cooridinate_x/y`` columns.
    """
    crop_names, origin_xs, origin_ys = [], [], []
    columns = zip(crop_df['File_name'], crop_df['bbox_id'],
                  crop_df['bbox_x1'], crop_df['bbox_y1'],
                  crop_df['bbox_x2'], crop_df['bbox_y2'])
    for name, box_id, bx1, by1, bx2, by2 in columns:
        crop_names.append(f"{name[:-4]}_{box_id!s}.png")
        box_w = abs(bx1 - bx2)
        box_h = abs(by2 - by1)
        origin_xs.append(max(int((bx1 + bx2) / 2 - 2 * box_w), 0))
        origin_ys.append(max(int((by1 + by2) / 2 - 2 * box_h), 0))

    origin_df = pd.DataFrame({
        'File_name': crop_names,
        'original_cooridinate_x': origin_xs,
        'original_cooridinate_y': origin_ys,
    })

    # Debug output: a sample detection name next to a sample crop name.
    print(RS_df[0:1]['File_name'])
    print(crop_names[0])
    merged = pd.merge(origin_df, RS_df)
    print(merged.shape)

    shift_x = merged['original_cooridinate_x']
    shift_y = merged['original_cooridinate_y']
    for col, shift in (('bbox_x1', shift_x), ('bbox_y1', shift_y),
                       ('bbox_x2', shift_x), ('bbox_y2', shift_y)):
        merged[col] = [s + v for s, v in zip(shift, merged[col])]
    # Drop the trailing "_<bbox_id>.png" to recover the original image name.
    merged['File_name'] = [fn.rpartition('_')[0] + '.png' for fn in merged['File_name']]
    return merged

def get_file_name(RS_df, json_file_path):
    """Attach a ``File_name`` column to detections using a COCO-style JSON.

    Every entry in the JSON's ``images`` list contributes its ``id`` (as
    ``image_id``) and its ``file_name`` with the extension replaced by
    ``.png``. The result is an inner merge with ``RS_df`` on the columns the
    two tables share. That is normally just ``image_id``; if ``RS_df`` already
    has ``File_name``, it is matched on too. Detections whose image is not
    listed in the JSON are dropped.

    Args:
        RS_df: detection table; expected to contain an ``image_id`` column.
        json_file_path: path to the COCO-style annotation JSON.

    Returns:
        A new DataFrame with ``File_name`` and ``image_id`` first, followed by
        the remaining ``RS_df`` columns.
    """
    with open(json_file_path, 'r', encoding='utf-8') as json_file:
        input_dataset = json.load(json_file)
    images = input_dataset['images']
    json_df = pd.DataFrame({
        # splitext copes with any extension length; the former `[0:-3]` slice
        # assumed a 3-letter extension and turned "x.jpeg" into "x.jpng".
        'File_name': [os.path.splitext(image['file_name'])[0] + '.png' for image in images],
        'image_id': [image['id'] for image in images],
    })
    return pd.merge(json_df, RS_df)

if __name__ == "__main__":
    RS_df = pd.read_csv(result_csv_file_path)
    # Attach File_name to each detection via the COCO-style JSON's image ids.
    # If the result CSV already has File_name, comment out the next line.
    RS_df = get_file_name(RS_df, json_file_path)

    # map_back treats detections as crop-relative and shifts them back into
    # original-image coordinates. If the results are already in original-image
    # coordinates, comment out the next two lines.
    crop_df = pd.read_csv(crop_csv_file_path)
    RS_df = map_back(RS_df, crop_df)

    # Keep a copy of the transformed detections for inspection. index=False
    # avoids a spurious 'Unnamed: 0' column when the file is read back.
    RS_df.to_csv('transformed_result.csv', index=False)
    GT_df = pd.read_csv(GT_csv_file_path)

    compute_froc(GT_df, RS_df)
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值