绘制yoloV5的多个模型训练结果

import csv
import matplotlib.pyplot as plt

def load_results(filename):
    with open(filename, 'r') as file:
        reader = csv.reader(file)
        header = next(reader)
        data = list(reader)
    return header, data

def plot_results(results, filenames):
    num_plots = len(results[0][0]) - 1  # 列的数量,除去第一列
    num_rows = 4  # 每行显示的子图数量
    num_cols = (num_plots + num_rows - 1) // num_rows  # 计算总列数

    fig, axs = plt.subplots(num_rows, num_cols, figsize=(18, 14))
    fig.subplots_adjust(hspace=0.6, wspace=0.4)  # 调整子图之间的垂直和水平间距

    axs = axs.flatten()  # 展平子图数组以便索引

    for i in range(num_plots):
        ax = axs[i]

        for j, result in enumerate(results):
            header, data = result

            column_name = header[i+1]  # 列名称,从第二列开始
            values = [float(row[i+1]) for row in data]  # 列数据,从第二列开始

            x = [float(row[0]) for row in data]  # 每个文件的第一列作为横坐标

            ax.plot(x, values, label=f"{filenames[j]} - {column_name}")  # 添加文件名到标签中

        ax.set_xlabel('Epoch')
        ax.set_ylabel('')
        ax.set_title(column_name)
        ax.legend(loc='lower right')  # 将标签固定显示在右下角
        ax.grid(True)

        # 自动调整纵坐标刻度
        ax.autoscale(axis='y')

    # 隐藏未使用的子图
    for j in range(num_plots, num_rows * num_cols):
        axs[j].axis('off')

    plt.suptitle('Comparison of Results', fontsize=16)  # 添加整幅图的标题并指定字号
    plt.tight_layout()
    plt.savefig('comparison_plot.png')  # 保存图像
    plt.show()

# 读取多个结果文件
results_files = ['result1.csv', 'result2.csv', 'result3.csv']
results = []
filenames = []

for filename in results_files:
    header, data = load_results(filename)
    results.append((header, data))
    filenames.append(filename)

# 绘制对比图并保存
plot_results(results, filenames)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值