目标跟踪PR,SR的python代码—PR,SR画图工具

1.p_norm脚本

#只需修改路径 line182,line183

import numpy as np
from matplotlib import pyplot as plt
import matplotlib
import os
from scipy.interpolate import make_interp_spline
from scipy.integrate import simps
import scipy
import math
from numpy import trapz
import glob

# 获取每个算法目录
def get_algorithm_dir(bbPath):
    algorithm_path = []
    for root, dirs, files in os.walk(bbPath):
        for dire in dirs:
            algorithm_path.append(os.path.join(root, dire)) # 获取每个跟踪算法的路径
            # print(dire)
    return algorithm_path

# 获取一个算法中的数据文件
def get_datafile(bbPath):
    for root,dirs,files in os.walk(bbPath):
        pass
    return files

#计算归一化距离
def getNormDistance(bbPath, gtPath):
    #dell=['7.txt','114.txt','115.txt','12.txt','129.txt','131.txt','142.txt','145.txt','176.txt','197.txt','320.txt','335.txt','394.txt']
    dell=[]
    algorithm = get_algorithm_dir(bbPath)
    # print(algorithm)
    distance = {
   } # 存放所有跟踪算法的norm distance,字典嵌套列表{[[],[],...] , ... ,[[],[],...]}
    for dire in algorithm:
        # print(dire)
        algo_name = dire.split('\\')[-1] # 获取算法名
        data_file = get_datafile(dire) # 获取该算法下面的数据文件path

        algo_distance = [] # 存放单个跟踪算法的norm distance 列表嵌套列表
        for fname in data_file: # 处理一个算法的数据
            if fname in dell:
                continue
            else:
                bb_file_path = os.path.join(dire, fname)# 单个数据文件路径,例如1.txt
                gt_file_path = os.path.join(gtPath, fname)
            # bb_file_path = os.path.join(dire, fname)# 单个数据文件路径,例如1.txt
            # gt_file_path = os.path.join(gtPath, fname) 
            
            try:
                bb_data = np.loadtxt(bb_file_path, dtype=np.float)
            except ValueError:
                bb_data = np.loadtxt(bb_file_path, dtype=np.float, delimiter=',')

            try:
                gt_data = np.loadtxt(gt_file_path, dtype=np.float)
            except ValueError:
                gt_data = np.loadtxt(gt_file_path, dtype=np.float, delimiter=',')
            #gt_data = np.loadtxt(gt_file_path, dtype=np.float, delimiter='\t')
            #gt_data = np.loadtxt(gt_file_path, dtype=np.float, delimiter=',')
            seq_distance = [] # seq_distance中存放的是一个文件中所有bbox之间的norm distance
            for i in range(len(bb_data)): # 处理一个序列的数据

                gt_x,gt_y,gt_w,gt_h = gt_data[i]
                bb_x,bb_y,bb_w,bb_h = bb_data[i]
                if gt_w==0:
                    gt_w=1
                if gt_h==0:
                    gt_h=1
                # gt中心点位置 and bbox中心点位置
                gt_center = np.array([gt_x+gt_w/2,gt_y+gt_h/2]) #groundtruth bbox center point position
                bb_center = np.array([bb_x+bb_w/2,bb_y+bb_h/2]) #trace algorithm bbox center poitn position
                
                dx = (gt_center[0]-bb_center[0])/gt_w
                dy = (gt_center[1]-bb_center[1])/gt_h
                ndistance = math.sqrt(dx**2 + dy**2) # compute the norm distance
                seq_distance.append(ndistance) # 存放了一个数据文件中所有bbox的normalized distance

            algo_distance.append(seq_distance) # algo_distance中存放的是单个算法中的所有norm distance
            #print(algo_distance)
        # print(algo_name, algo_distance)
        distance[algo_name] = algo_distance # distance是一个字典,key对应算法名,value对应normalized distance
    return distance

# 计算精确度
def calculate_accuracy(threshold, bbPath, gtPath):
    norm_distance = getNormDistance(bbPath, gtPath)
    algo_accuracy = {
   }
    key_list = []
    for algo_name, algo_distance in norm_distance.items():
        #print(algo_name)
        accuracy_list = []
        for thre in threshold:
            accuracy = 0
            for ndistance in algo_distance: # len(algo_distance)相当于一个跟踪算法中的序列数
                cnt = 0
                for dist in ndistance: # 计算单个算法的accuracy
                    if dist < thre:
                        cnt = cnt+1
                accuracy = accuracy+cnt/len(ndistance) # 计算每个序列的平均accuracy len(ndistance)相当于帧数
            accuracy_list.append(accuracy/len(algo_distance)) # 计算算法的平均accuracy
        # print(algo_name, accuracy_list)

        y = np.array(accuracy_list)
        x = np.array(threshold)
        area = trapz(y, x, dx=0.001)*2
        area = '%.03f'%area  # 保留三位小数

        algo_accuracy['['+area+']'+algo_name] = accuracy_list
        key_list.append('['+area+']'+algo_name)
        #print(key_list)
    return algo_accuracy, key_list

# 绘制图片
def plot_figure(threshold, list_accuracy, key_list):#,list_name,k
    # 设置图像的大小
    plt.figure(figsize=(10,10))
    # 设置坐标轴上坐标刻度
    plt.xticks(list(np.arange(0,0.6,0.1)),['0','0.1','0.2','0.3','0.4','0.5'])
    plt.yticks(list(np.arange(0,1.0,0.1)), ['0','0.1','0.2','0.3','0.4','0.5','0.6','0.7','0.8','0.9'])
    plt.tick_params(labelsize=13) # 设置坐标轴刻度字体的大小
    plt.grid(alpha=1,ls='--') # 设置网格线背景,并设置透明度为1
    plt.axis([0, 0.5, 0, 0.9])# 设置坐标轴起始点
    colors = [ 'red','darkred', 'yellow', 'dodgerblue', 'black','lime', 'darkorchid','cyan','slategrey','maroon','rosybrown',\
        'deeppink', 'coral', 'tan', 'green','magenta', 'pink', 'olive', 'gold','plum','peru','chocolate','crimson',\
            'crimson', 'deepskyblue',  'springgreen', 'slategrey', 'plum', 'steelblue', 'lawngreen','royalblue']
    #linestyles = ['-','--']
    i = 0
    x = np.array(threshold)
    t = len(key_list
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值