matplotlib当在Jupyter读取数据导致图片消失

matplotlib图片消失问题解决方案

起因: 自己实现了一个绘制数据的类用于可视化训练过程中的数据表现,如训练损失 训练准确度 验证准确度

在Jupyter notebook和IDE环境中都满足设计需求,但是将其加载到我的训练函数中却发现:

图片并没有显示出来

经过代码调试,最终锁定问题的产生代码:

def train_func(net, loss, train_data, valid_data, weight_decay, lr, num_epochs, device):
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    plot_tool = PlotTool(num_epochs)
    net.to(device)
    print('开始训练')
    best_acc =0
    for i in range(num_epochs):
        net.train()
        cnt = 0
        train_loss, train_acc, valid_acc_temp = [], [], []
        for X, y in train_data:#问题发生在这里
            optimizer.zero_grad()
            X = X.to(device)
            y = y.to(device)
            assert X.device == y.device, f'X device : {X.device} y device : {y.device} '
            out = net(X)
            l = loss(out, y)
            l.backward()
            optimizer.step()

当我试图从DataLoader对象中读取数据,就会使得我show出来的图片消失

我曾测试过是否是for loop的问题,但是其依旧可以正常绘制

plot_tool=PlotTool(10) #10个size
x=np.linspace(1,10,num=10)
for data in x:
    val=np.power(data,2)
    plot_tool.add_ele(data,[val]*3)

测试示例

最终确定 是在加载数据时产生了错误

for X, y in train_data:#问题发生在这里

在未运行这个语句之前,我的图片一直存活,但是一旦运行,我的图片就会瞬间消失,cell的输出窗口也变成了一片空白。

针对这个问题 个人猜测是因为在读取数据时运行了多个进程导致的。

我查询了很多资料,并没有提及到这个问题的.

最终通过摸索找到了解决方案

不依赖matplotlib的inline魔法函数 手动display目标fig

from IPython import display
def display(self):
    if self.mode == 'IDE':
        self.fig.show()  # 在IDE里进行绘图操作
        elif self.mode in ['notebook', 'train']:
            display.display(self.fig)  # 当在jupyter时需要使用这个进行绘图操作
            display.clear_output(wait=True)

最终版本:

import numpy as np
import matplotlib.pyplot as plt
import math
from IPython import display
class PlotTool():
    def __init__(self, num_epoch, num_element=3, sub_size=(2, 1), mode='notebook'):
        assert mode in ['notebook', 'IDE', 'train'], f'\n参数mode:{mode} 非法 \n' \
                                                    f'合法参数:[''notebook'',''IDE'',''train'']'
        self.epoch = num_epoch
        self.mode = mode
        self.xlim = [1, self.epoch]  # 设置坐标范围
        self.x_data = [[] for _ in range(num_element)]  # 预先声明数据空间
        self.y_data = [[] for _ in range(num_element)]  # 预先声明数据空间
        self.fig, self.axes = plt.subplots(sub_size[0], sub_size[1],figsize=(18,10))  # 确定子图尺寸 figsize格式为(图的宽 图的高)
        self.fmt = ['r-', 'm--', 'g-.']
        self.display()
        self.config_axes()

    def display(self):
        if self.mode == 'IDE':
            self.fig.show()  # 在IDE里进行绘图操作
        elif self.mode in ['notebook', 'train']:
            display.display(self.fig)  # 当在jupyter时需要使用这个进行绘图操作
            display.clear_output(wait=True)

    def config_axes(self):
        self.axes[0].legend(['train loss'])
        self.axes[1].legend(['train acc', 'valid acc'])
#         self.axes[1].set_ylim([0,1])
        for ax_bj in self.axes:
            ax_bj.set_xlim(self.xlim)
            ax_bj.grid()

    def add_ele(self, x, y):  # 增加数据点
        if not hasattr(x, '__len__') and hasattr(y, '__len__'):
            x = [x] * len(y)  # 拓展
            # 当x只有一个维度 而y不止时 上述有用
        if not hasattr(y, '__len__') and hasattr(x, '__len__'):
            y = [y] * len(x)  # 拓展维度 同上
        for ind, (x, y) in enumerate(zip(x, y)):  # 打包再枚举出来
            self.x_data[ind].append(x)  # 存入数据
            self.y_data[ind].append(y)  # 存入数据
        self.axes[0].plot(self.x_data[0], self.y_data[0], self.fmt[0])
        for x_val, y_val, fmt in zip(self.x_data[1:], self.y_data[1:], self.fmt[1:]):
            self.axes[1].plot(x_val, y_val, fmt)
        self.display()
        self.config_axes()

效果:

当追加数据时,将会动态地绘制出所有的数据

具有动画效果

绘制效果

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值