根据Results.csv文件画性能曲线图

import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
import os
import math

# 训练结果列表 (请确保这里的路径是正确的)
# Training results file list (Please ensure the paths here are correct)
results_files = [
    r'F:\ultralyticsPro\yolocsv\ours.csv',
    r'F:\ultralyticsPro\yolocsv\yolov5.csv',
    r'F:\ultralyticsPro\yolocsv\yolov8.csv',
    r'F:\ultralyticsPro\yolocsv\yolov9.csv',
    r'F:\ultralyticsPro\yolocsv\yolov10.csv',
    r'F:\ultralyticsPro\yolocsv\yolov11.csv',
    r'F:\ultralyticsPro\yolocsv\yolov12.csv',
]

# 与results_files顺序对应的自定义标签
# Custom labels corresponding to the order of results_files
custom_labels = [
    'ours',
    'YOLOv5n',
    'YOLOv8n',
    'YOLOv9n',
    'YOLOv10n',
    'YOLOv11n',
    'YOLOv12n',
]

# 定义不同的线型和标记点用于区分曲线
# Define different linestyles and markers to distinguish curves
linestyles = ['-', '--', '-.', ':'] # 实线, 虚线, 点划线, 点线 Solid, dashed, dash-dot, dotted
markers = ['o', 's', '^', 'v', 'D', '*', 'p'] # 圆, 方块, 上三角, 下三角, 菱形, 星号, 五边形 Circle, square, triangle_up, triangle_down, diamond, star, pentagon

def plot_metrics_grid(
        metrics_row1,
        labels_row1,
        metrics_row2,
        labels_row2,
        custom_labels,
        results_files,
        output_path='metrics_grid_plot.png',
        max_x_axis_epoch=200,
        sigma=1,
        dpi=300
):
    """
    绘制性能指标的网格图,分为两行显示,每个指标一个独立的坐标框。
    使用不同的颜色、线型和标记点区分不同模型的曲线。
    网格布局会根据指标数量动态调整。

    Draws a grid plot of performance metrics, split into two rows,
    with each metric in its own subplot.
    Uses different colors, linestyles, and markers to distinguish curves for different models.
    The grid layout adjusts dynamically based on the number of metrics per row.

    Args:
        metrics_row1 (list): 第 1 行性能指标的键名列表.
        labels_row1 (list): 第 1 行性能指标的显示标签列表.
        metrics_row2 (list): 第 2 行性能指标的键名列表.
        labels_row2 (list): 第 2 行性能指标的显示标签列表.
        custom_labels (list): 每个结果文件的自定义标签.
        results_files (list): CSV 文件路径列表.
        output_path (str): 输出图像的文件路径.
        max_x_axis_epoch (int): X 轴(Epoch)的最大值.
        sigma (float): 高斯平滑的 sigma 值.
        dpi (int): 保存图像的分辨率.
    """
    num_metrics_r1 = len(metrics_row1)
    num_metrics_r2 = len(metrics_row2)
    ncols = max(num_metrics_r1, num_metrics_r2)

    # 创建图形和 GridSpec (2 行) - 增大 figsize 使子图更大
    # Create figure and GridSpec (2 rows) - Increase figsize for larger subplots
    fig = plt.figure(figsize=(max(14, 7 * ncols), 14)) # 增大图形尺寸 Increase figure size
    gs = fig.add_gridspec(2, ncols)

    # --- 绘制第 1 行性能指标 ---
    print("--- 开始绘制第 1 行性能指标 ---")
    for i, (metric_key, metric_label) in enumerate(zip(metrics_row1, labels_row1)):
        ax = fig.add_subplot(gs[0, i])
        print(f"绘制指标: {metric_label} ({metric_key})")
        plot_count = 0
        for idx, (file_path, custom_label) in enumerate(zip(results_files, custom_labels)):
            try:
                if not os.path.exists(file_path):
                    print(f"警告: 文件未找到,跳过: {file_path}")
                    continue
                df = pd.read_csv(file_path)
                df.columns = df.columns.str.strip()
                if df.empty:
                    print(f"警告: 文件为空或读取失败,跳过: {file_path}")
                    continue
                if 'epoch' not in df.columns or metric_key not in df.columns:
                    print(f"警告: 文件 {os.path.basename(file_path)} 缺少列 'epoch' 或 '{metric_key}',跳过。")
                    continue
                if df[metric_key].isnull().all() or not pd.api.types.is_numeric_dtype(df[metric_key]):
                     print(f"警告: 文件 {os.path.basename(file_path)} 中的列 '{metric_key}' 数据无效或非数值,跳过。")
                     continue

                epochs_to_plot = df['epoch']
                y_data = df[metric_key].astype(float)
                y_smooth = gaussian_filter1d(y_data, sigma=sigma)
                max_epoch_data = epochs_to_plot.max()

                # 应用不同的线型和标记点
                current_linestyle = linestyles[idx % len(linestyles)]
                current_marker = markers[idx % len(markers)]
                line, = ax.plot(epochs_to_plot, y_smooth, label=f'{custom_label}',
                                linewidth=2.5,
                                linestyle=current_linestyle,
                                marker=current_marker,
                                markersize=4,
                                markevery=10
                               )
                plot_count += 1
                print(f"  已绘制: {custom_label} (来自 {os.path.basename(file_path)}, 线型: {current_linestyle}, 标记: {current_marker})")

                if max_epoch_data < max_x_axis_epoch:
                    last_y_value = y_smooth[-1]
                    last_epoch_value = epochs_to_plot.iloc[-1]
                    ax.plot(
                        [last_epoch_value, last_epoch_value],
                        [last_y_value, 0],
                        linestyle='dotted',
                        color=line.get_color(),
                    )
            except FileNotFoundError:
                print(f"错误: 文件未找到: {file_path}")
            except pd.errors.EmptyDataError:
                print(f"错误: 文件为空: {file_path}")
            except Exception as e:
                print(f"错误: 读取或处理文件 {os.path.basename(file_path)} 时发生错误: {e}")

        ax.set_title(f'{metric_label}', fontsize=14)
        ax.set_xlabel('Epochs', fontsize=12)
        ax.set_ylabel('Value', fontsize=12)
        if plot_count > 0:
             ax.legend(fontsize=10, loc='best')
        ax.set_ylim(bottom=0)
        ax.set_xlim(left=0, right=max_x_axis_epoch)
        ax.grid(True)

    # --- 绘制第 2 行性能指标 ---
    print("\n--- 开始绘制第 2 行性能指标 ---")
    for i, (metric_key, metric_label) in enumerate(zip(metrics_row2, labels_row2)):
        ax = fig.add_subplot(gs[1, i])
        print(f"绘制指标: {metric_label} ({metric_key})")
        plot_count = 0
        for idx, (file_path, custom_label) in enumerate(zip(results_files, custom_labels)):
            try:
                if not os.path.exists(file_path):
                    print(f"警告: 文件未找到,跳过: {file_path}")
                    continue
                df = pd.read_csv(file_path)
                df.columns = df.columns.str.strip()
                if df.empty:
                    print(f"警告: 文件为空或读取失败,跳过: {file_path}")
                    continue
                if 'epoch' not in df.columns or metric_key not in df.columns:
                    print(f"警告: 文件 {os.path.basename(file_path)} 缺少列 'epoch' 或 '{metric_key}',跳过。")
                    continue
                if df[metric_key].isnull().all() or not pd.api.types.is_numeric_dtype(df[metric_key]):
                     print(f"警告: 文件 {os.path.basename(file_path)} 中的列 '{metric_key}' 数据无效或非数值,跳过。")
                     continue

                epochs_to_plot = df['epoch']
                y_data = df[metric_key].astype(float)
                y_smooth = gaussian_filter1d(y_data, sigma=sigma)
                max_epoch_data = epochs_to_plot.max()

                # 应用不同的线型和标记点
                current_linestyle = linestyles[idx % len(linestyles)]
                current_marker = markers[idx % len(markers)]
                line, = ax.plot(epochs_to_plot, y_smooth, label=f'{custom_label}',
                                linewidth=2.5,
                                linestyle=current_linestyle,
                                marker=current_marker,
                                markersize=4,
                                markevery=10
                               )
                plot_count += 1
                print(f"  已绘制: {custom_label} (来自 {os.path.basename(file_path)}, 线型: {current_linestyle}, 标记: {current_marker})")

                if max_epoch_data < max_x_axis_epoch:
                    last_y_value = y_smooth[-1]
                    last_epoch_value = epochs_to_plot.iloc[-1]
                    ax.plot(
                        [last_epoch_value, last_epoch_value],
                        [last_y_value, 0],
                        linestyle='dotted',
                        color=line.get_color(),
                    )
            except FileNotFoundError:
                print(f"错误: 文件未找到: {file_path}")
            except pd.errors.EmptyDataError:
                print(f"错误: 文件为空: {file_path}")
            except Exception as e:
                print(f"错误: 读取或处理文件 {os.path.basename(file_path)} 时发生错误: {e}")

        ax.set_title(f'{metric_label}', fontsize=14)
        ax.set_xlabel('Epochs', fontsize=12)
        ax.set_ylabel('Value', fontsize=12)
        if plot_count > 0:
             ax.legend(fontsize=10, loc='best')
        ax.set_ylim(bottom=0)
        ax.set_xlim(left=0, right=max_x_axis_epoch)
        ax.grid(True)

    # 调整布局并保存图像
    plt.tight_layout(pad=2.0, h_pad=3.0)
    if output_path:
        try:
            output_dir = os.path.dirname(output_path)
            if output_dir:
                 os.makedirs(output_dir, exist_ok=True)
            plt.savefig(output_path, dpi=dpi, bbox_inches='tight')
            print(f"\n图像已成功保存到: {output_path}, dpi={dpi}")
        except Exception as e:
            print(f"\n错误: 保存图像到 {output_path} 时失败: {e}")
    else:
        print("\n警告: output_path 为空或无效,图像未保存。")

    # plt.show()


if __name__ == '__main__':
    metrics_row1 = [
        'metrics/precision(B)',
        'metrics/recall(B)',
    ]
    labels_row1 = ['Precision', 'Recall']

    metrics_row2 = [
        'metrics/mAP50(B)',
        'metrics/mAP50-95(B)',
    ]
    labels_row2 = ['mAP@50', 'mAP@50-95']

    output_path_grid = r'F:\ultralyticsPro\metrics_grid_plot_perf_markers_larger.png' # 输出路径

    plot_metrics_grid(
        metrics_row1,
        labels_row1,
        metrics_row2,
        labels_row2,
        custom_labels,
        results_files,
        output_path=output_path_grid,
        max_x_axis_epoch=200,
        sigma=1,
        dpi=300
    )
### YOLO模型训练过程中的损失曲线绘制 为了从YOLO训练生成的 `.csv` 文件中提取并绘制损失函数的变化曲线,可以利用 Python 的 `pandas` 和 `matplotlib` 库完成此操作。以下是具体实现方法: #### 数据读取与处理 `.csv` 文件通常由 YOLO 训练过程中自动生成,其中包含了多个指标列,例如 `train/box_loss`, `val/box_loss`, `train/obj_loss`, `val/obj_loss` 等。这些列分别表示不同阶段的损失值。 通过 Pandas 加载 CSV 文件,并选取特定的损失列进行可视化[^1]。 ```python import pandas as pd import matplotlib.pyplot as plt # 读取 .csv 文件 data = pd.read_csv('results.csv') # 查看数据结构以便确认所需列名 print(data.columns) # 提取所需的损失列 (例如 train/box_loss, val/box_loss) epochs = data['epoch'] # 假设存在 epoch 列用于标记迭代次数 train_box_loss = data['train/box_loss'] val_box_loss = data['val/box_loss'] # 如果不存在 'epoch' 列,则创建默认索引作为横坐标 if 'epoch' not in data.columns: epochs = range(len(train_box_loss)) ``` #### 可视化损失曲线 使用 Matplotlib 创建折线图以展示损失随时间的变化趋势[^2]。 ```python plt.figure(figsize=(10, 6)) # 绘制训练集上的 box loss 曲线 plt.plot(epochs, train_box_loss, label='Train Box Loss', color='blue') # 绘制验证集上的 box loss 曲线 plt.plot(epochs, val_box_loss, label='Validation Box Loss', color='red') # 添加图表标题和轴标签 plt.title('Box Loss Over Epochs') plt.xlabel('Epochs') plt.ylabel('Loss Value') # 显示图例 plt.legend() # 展示图形 plt.grid(True) plt.show() ``` 上述代码片段展示了如何针对 `box_loss` 进行绘图。如果需要绘制其他类型的损失(如对象检测损失或分类损失),只需替换对应的列名称即可。 --- #### 注意事项 - **CSV 文件路径**:确保指定正确的文件路径。 - **列名匹配**:不同的 YOLO 版本可能具有略微不同的列命名方式,请先打印 `data.columns` 来确认实际可用的列名。 - **多条曲线比较**:当涉及多种实验条件下的对比时,可以通过叠加更多线条来增强直观效果。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值