import csv
import matplotlib.pyplot as plt
def load_results(filename):
with open(filename, 'r') as file:
reader = csv.reader(file)
header = next(reader)
data = list(reader)
return header, data
def plot_results(results, filenames):
num_plots = len(results[0][0]) - 1 # 列的数量,除去第一列
num_rows = 4 # 每行显示的子图数量
num_cols = (num_plots + num_rows - 1) // num_rows # 计算总列数
fig, axs = plt.subplots(num_rows, num_cols, figsize=(18, 14))
fig.subplots_adjust(hspace=0.6, wspace=0.4) # 调整子图之间的垂直和水平间距
axs = axs.flatten() # 展平子图数组以便索引
for i in range(num_plots):
ax = axs[i]
for j, result in enumerate(results):
header, data = result
column_name = header[i+1] # 列名称,从第二列开始
values = [float(row[i+1]) for row in data] # 列数据,从第二列开始
x = [float(row[0]) for row in data] # 每个文件的第一列作为横坐标
ax.plot(x, values, label=f"{filenames[j]} - {column_name}") # 添加文件名到标签中
ax.set_xlabel('Epoch')
ax.set_ylabel('')
ax.set_title(column_name)
ax.legend(loc='lower right') # 将标签固定显示在右下角
ax.grid(True)
# 自动调整纵坐标刻度
ax.autoscale(axis='y')
# 隐藏未使用的子图
for j in range(num_plots, num_rows * num_cols):
axs[j].axis('off')
plt.suptitle('Comparison of Results', fontsize=16) # 添加整幅图的标题并指定字号
plt.tight_layout()
plt.savefig('comparison_plot.png') # 保存图像
plt.show()
# 读取多个结果文件
results_files = ['result1.csv', 'result2.csv', 'result3.csv']
results = []
filenames = []
for filename in results_files:
header, data = load_results(filename)
results.append((header, data))
filenames.append(filename)
# 绘制对比图并保存
plot_results(results, filenames)
绘制yoloV5的多个模型训练结果
最新推荐文章于 2024-04-01 23:31:58 发布