import json
import matplotlib.pyplot as plt
import argparse
'''
解析参数
'''
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str, default='val')
parser.add_argument("--select", type=str, default='AP50')
parser.add_argument("--json_paths", type=str, nargs='+')
parser.add_argument("--line_names", type=str, nargs='+')
parser.add_argument("--out_dir", type=str, default='F:\\fangweijie_weight_file\\retinanet_r50_fpn_1x_voc0712_cocoPretrain_180epoch_lr0.001')
parser.add_argument("--epoch_num", type=int, default=20)
parser.add_argument("--pic_name", type=str,default="result")
args = parser.parse_args()
select=args.select
pic_name=args.pic_name
mode = args.mode # 选择log文件中的模式
json_paths = args.json_paths
line_names = args.line_names
out_dir = args.out_dir
epoch_num = args.epoch_num
plt.figure(figsize=(12, 8), dpi=300)
for i, json_path in enumerate(json_paths):
epoch_now = 0
x = [] # 存放epoch
y = [] # 存放指标
y_min = 1000000 # 存放指标最大值 ap不会超过1 绘制loss可自由更改
y_max = -1 # 存放指标最小值 ap不会小于-1 绘制loss可自由更改
x_min = 0 # 出现最小值的epoch
x_max = 0 # 出现最大值的epoch
isFirst = True
with open(json_path, 'r') as f:
for jsonstr in f.readlines():
if epoch_now == epoch_num:
break
if isFirst: # mmdetection生成的log json文件第一行是配置信息 跳过
isFirst = False
continue
row_data = json.loads(jsonstr)
if row_data['mode'] == mode: # 选择train或者val模式中的指标数据
epoch_now = epoch_now + 1
item_select = float(row_data[select])
x_select = int(row_data['epoch'])
x.append(x_select)
y.append(item_select)
if item_select >= y_max: # 选择最大值 为什么不用numpy.argmin呢? 因为epoch可能不从1开始 xmin和ymin可能匹配错误 比较麻烦
y_max = item_select
x_max = x_select
if item_select <= y_min: # 选择最大值
y_min = item_select
x_min = x_select
plt.grid(True, linestyle='--', alpha=0.5)
plt.plot(x, y, label=line_names[i])
plt.plot(x_min, y_min, 'g-p', x_max, y_max, 'r-p')
show_min = '[' + str(x_min) + ' , ' + str(y_min) + ']'
show_max = '[' + str(x_max) + ' , ' + str(y_max) + ']'
plt.annotate(show_min, xy=(x_min, y_min), xytext=(x_min, y_min))
plt.annotate(show_max, xy=(x_max, y_max), xytext=(x_max, y_max))
plt.xlabel('epoch')
plt.legend()
plt.ylabel(select)
# plt.ylim(0.8, 1.0) # 设置y轴坐标范围
plt.savefig(args.out_dir + '/' + pic_name + '.jpg', dpi=300)
绘制多个mmdetection的log.json日志文件
于 2023-01-16 14:51:45 首次发布
该Python脚本使用argparse解析命令行参数,读取json日志文件,用matplotlib绘制训练或验证模式下的指标(如AP50)随epoch变化的曲线,并标注出最大值和最小值的epoch。用户可以指定模式、指标、日志路径、图例名称、输出目录、世代数和图片名。
摘要由CSDN通过智能技术生成