matplotlib动态绘制训练进度【训练精度、训练损失、测试精度】

27 篇文章 8 订阅
20 篇文章 1 订阅


一、注意事项

  1. plt.legend() 是展示图例的,就是 ax.plot() 里面的label值,plt.legend() 必须在 ax.plot() 之后调用才生效,并且因为 plt.legend() 每调用一次就会生成一次图例,所以仅在首次绘制时调用即可。
  2. plt.pause(0.001) 不可省略
  3. 绘制时 x、y 数组的长度必须相等且值一一对应,长度可逐步增加,但是不要出现 None 值,否则绘制的线条会出现断点

二、实现

import matplotlib.pyplot as plt
import time
from matplotlib_inline import backend_inline


class TrainVision(object):
    def __init__(self):

        # svg模式
        backend_inline.set_matplotlib_formats('svg')

        # 用于显示正常中文标签
        plt.rcParams['font.sans-serif'] = ['SimHei']

        # 在 1*1 的画布 fig 上创建图纸 ax
        self.fig, self.ax = plt.subplots(1, 1, figsize=(3.5, 2.5))
        # 展示网格线
        self.ax.grid()
        # x轴标签
        self.ax.set_xlabel('epochs')
        # y轴标签
        self.ax.set_ylabel('acc & loss')
        # epoch、训练精度、训练损失、测试精度 的累加数组
        # y是精度或者损失值,x是y对应的epoch
        self.train_acc = {'x': [], 'y': []}
        self.train_loss = {'x': [], 'y': []}
        self.test_acc = {'x': [], 'y': []}
        # 是否加载过图例的标记
        self.is_init_legend = False

    def draw(self, epoch_x, train_acc_y1, train_loss_y2, test_acc_y3):

        # 加入位置信息数组
        if epoch_x:
            if train_acc_y1:	# 训练精度 为 None 时不加入数组
                self.train_acc['x'].append(epoch_x)
                self.train_acc['y'].append(train_acc_y1)
            if train_loss_y2:	# 训练损失 为 None 时不加入数组
                self.train_loss['x'].append(epoch_x)
                self.train_loss['y'].append(train_loss_y2)
            if test_acc_y3:		# 测试精度 为 None 时不加入数组
                self.test_acc['x'].append(epoch_x)
                self.test_acc['y'].append(test_acc_y3)

        # 绘制
        self.ax.plot(self.train_acc['x'], self.train_acc['y'], color='blue', label='train loss')
        self.ax.plot(self.train_loss['x'], self.train_loss['y'], color='red', label='train acc')
        self.ax.plot(self.test_acc['x'], self.test_acc['y'], color='green', label='test acc')

        # 图例仅在首次加载时创建
        if not self.is_init_legend:
            plt.legend()
            self.is_init_legend = True

        plt.draw()
        plt.pause(0.001)


if __name__ == '__main__':

    # 初始化
    vt = TrainVision()

    # 随便写的 训练精度、训练损失 和 测试精度
    x = [a for a in range(50)]
    train_acc_list = [a**2 for a in x]
    train_loss_list = [a**2+a*2 for a in x]
    test_acc_list = [a**2+a*4 for a in x]

    # 模拟训练
    for epochs, (train_acc, train_loss, test_acc) in enumerate(zip(train_acc_list, train_loss_list, test_acc_list)):

        print(epochs)

        # 训练耗时
        time.sleep(0.3)

        # 每个 epoch 绘制 训练精度 和 训练损失
        vt.draw(epochs, train_acc, train_loss, None)

        # 每5个 epoch 绘制 测试精度
        if (epochs+1) % 5 == 0:
            vt.draw(epochs, None, None, test_acc)

跑一下康康

在这里插入图片描述

  • 15
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 10
    评论
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

什么都干的派森

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值