matplotlib图片消失问题解决方案
起因: 自己实现了一个绘制数据的类用于可视化训练过程中的数据表现,如训练损失 训练准确度 验证准确度
在Jupyter notebook和IDE环境中都满足设计需求,但是将其加载到我的训练函数中却发现:
图片并没有显示出来
经过代码调试,最终锁定问题的产生代码:
def train_func(net, loss, train_data, valid_data, weight_decay, lr, num_epochs, device):
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
plot_tool = PlotTool(num_epochs)
net.to(device)
print('开始训练')
best_acc =0
for i in range(num_epochs):
net.train()
cnt = 0
train_loss, train_acc, valid_acc_temp = [], [], []
for X, y in train_data:#问题发生在这里
optimizer.zero_grad()
X = X.to(device)
y = y.to(device)
assert X.device == y.device, f'X device : {X.device} y device : {y.device} '
out = net(X)
l = loss(out, y)
l.backward()
optimizer.step()
当我试图从DataLoader对象中读取数据,就会使得我show出来的图片消失
我曾测试过是否是for loop
的问题,但是其依旧可以正常绘制
plot_tool=PlotTool(10) #10个size
x=np.linspace(1,10,num=10)
for data in x:
val=np.power(data,2)
plot_tool.add_ele(data,[val]*3)
最终确定 是在加载数据时产生了错误
for X, y in train_data:#问题发生在这里
在未运行这个语句之前,我的图片一直存活,但是一旦运行,我的图片就会瞬间消失,cell的输出窗口也变成了一片空白。
针对这个问题 个人猜测是因为在读取数据时运行了多个进程导致的。
我查询了很多资料,并没有提及到这个问题的.
最终通过摸索找到了解决方案:
不依赖matplotlib的inline魔法函数 手动display目标fig
from IPython import display
def display(self):
if self.mode == 'IDE':
self.fig.show() # 在IDE里进行绘图操作
elif self.mode in ['notebook', 'train']:
display.display(self.fig) # 当在jupyter时需要使用这个进行绘图操作
display.clear_output(wait=True)
最终版本:
import numpy as np
import matplotlib.pyplot as plt
import math
from IPython import display
class PlotTool():
def __init__(self, num_epoch, num_element=3, sub_size=(2, 1), mode='notebook'):
assert mode in ['notebook', 'IDE', 'train'], f'\n参数mode:{mode} 非法 \n' \
f'合法参数:[''notebook'',''IDE'',''train'']'
self.epoch = num_epoch
self.mode = mode
self.xlim = [1, self.epoch] # 设置坐标范围
self.x_data = [[] for _ in range(num_element)] # 预先声明数据空间
self.y_data = [[] for _ in range(num_element)] # 预先声明数据空间
self.fig, self.axes = plt.subplots(sub_size[0], sub_size[1],figsize=(18,10)) # 确定子图尺寸 figsize格式为(图的宽 图的高)
self.fmt = ['r-', 'm--', 'g-.']
self.display()
self.config_axes()
def display(self):
if self.mode == 'IDE':
self.fig.show() # 在IDE里进行绘图操作
elif self.mode in ['notebook', 'train']:
display.display(self.fig) # 当在jupyter时需要使用这个进行绘图操作
display.clear_output(wait=True)
def config_axes(self):
self.axes[0].legend(['train loss'])
self.axes[1].legend(['train acc', 'valid acc'])
# self.axes[1].set_ylim([0,1])
for ax_bj in self.axes:
ax_bj.set_xlim(self.xlim)
ax_bj.grid()
def add_ele(self, x, y): # 增加数据点
if not hasattr(x, '__len__') and hasattr(y, '__len__'):
x = [x] * len(y) # 拓展
# 当x只有一个维度 而y不止时 上述有用
if not hasattr(y, '__len__') and hasattr(x, '__len__'):
y = [y] * len(x) # 拓展维度 同上
for ind, (x, y) in enumerate(zip(x, y)): # 打包再枚举出来
self.x_data[ind].append(x) # 存入数据
self.y_data[ind].append(y) # 存入数据
self.axes[0].plot(self.x_data[0], self.y_data[0], self.fmt[0])
for x_val, y_val, fmt in zip(self.x_data[1:], self.y_data[1:], self.fmt[1:]):
self.axes[1].plot(x_val, y_val, fmt)
self.display()
self.config_axes()
效果:
当追加数据时,将会动态地绘制出所有的数据
具有动画效果