一、注意事项
- plt.legend() 是展示图例的,就是 ax.plot() 里面的label值,plt.legend() 必须在 ax.plot() 之后调用才生效,并且因为 plt.legend() 每调用一次就会生成一次图例,所以仅在首次绘制时调用即可。
- plt.pause(0.001) 不可省略
- 绘制时 x、y 数组的长度必须相等且值一一对应,长度可逐步增加,但是不要出现 None 值,否则绘制的线条会出现断点
二、实现
import matplotlib.pyplot as plt
import time
from matplotlib_inline import backend_inline
class TrainVision(object):
def __init__(self):
# svg模式
backend_inline.set_matplotlib_formats('svg')
# 用于显示正常中文标签
plt.rcParams['font.sans-serif'] = ['SimHei']
# 在 1*1 的画布 fig 上创建图纸 ax
self.fig, self.ax = plt.subplots(1, 1, figsize=(3.5, 2.5))
# 展示网格线
self.ax.grid()
# x轴标签
self.ax.set_xlabel('epochs')
# y轴标签
self.ax.set_ylabel('acc & loss')
# epoch、训练精度、训练损失、测试精度 的累加数组
# y是精度或者损失值,x是y对应的epoch
self.train_acc = {'x': [], 'y': []}
self.train_loss = {'x': [], 'y': []}
self.test_acc = {'x': [], 'y': []}
# 是否加载过图例的标记
self.is_init_legend = False
def draw(self, epoch_x, train_acc_y1, train_loss_y2, test_acc_y3):
# 加入位置信息数组
if epoch_x:
if train_acc_y1: # 训练精度 为 None 时不加入数组
self.train_acc['x'].append(epoch_x)
self.train_acc['y'].append(train_acc_y1)
if train_loss_y2: # 训练损失 为 None 时不加入数组
self.train_loss['x'].append(epoch_x)
self.train_loss['y'].append(train_loss_y2)
if test_acc_y3: # 测试精度 为 None 时不加入数组
self.test_acc['x'].append(epoch_x)
self.test_acc['y'].append(test_acc_y3)
# 绘制
self.ax.plot(self.train_acc['x'], self.train_acc['y'], color='blue', label='train loss')
self.ax.plot(self.train_loss['x'], self.train_loss['y'], color='red', label='train acc')
self.ax.plot(self.test_acc['x'], self.test_acc['y'], color='green', label='test acc')
# 图例仅在首次加载时创建
if not self.is_init_legend:
plt.legend()
self.is_init_legend = True
plt.draw()
plt.pause(0.001)
if __name__ == '__main__':
# 初始化
vt = TrainVision()
# 随便写的 训练精度、训练损失 和 测试精度
x = [a for a in range(50)]
train_acc_list = [a**2 for a in x]
train_loss_list = [a**2+a*2 for a in x]
test_acc_list = [a**2+a*4 for a in x]
# 模拟训练
for epochs, (train_acc, train_loss, test_acc) in enumerate(zip(train_acc_list, train_loss_list, test_acc_list)):
print(epochs)
# 训练耗时
time.sleep(0.3)
# 每个 epoch 绘制 训练精度 和 训练损失
vt.draw(epochs, train_acc, train_loss, None)
# 每5个 epoch 绘制 测试精度
if (epochs+1) % 5 == 0:
vt.draw(epochs, None, None, test_acc)
跑一下康康