Paddle的实时画图功能

记得第一次跑Paddle的官方例子时,最让我惊奇的是它能实时地显示train loss和val loss随着训练次数的变化图。

之前用Pytorch都是在训练过程中先打印训练过程的loss,并且把他们保存起来。最后等训练完成再作图。所以在训练完成之前只能看着数字的变化,却无法有一个实时的图形化结果。百度的Paddle给了我一个惊喜。所以抱着如何实现的疑问,我去查看了框架源码,才发现,原来只是一个简单技巧就可以达成这种效果。

首先看看一个使用的例子:

from paddle.utils.plot import Ploter
import numpy as np
%matplotlib inline

train_prompt = "train cost"
test_prompt = "test cost"
plot_prompt = Ploter(train_prompt, test_prompt)

for step in range(10):
    plot_prompt.append(title=train_prompt, step=step, value=0.1-np.random.randn()*step) 
    plot_prompt.plot()

下图是官方例子训练出来的图,上述例子只能有个大概的样子
这是官方例子训练出来的图,上述例子只能有个大概的样子
这个功能的实现主要是靠paddle.utils.plot里面的Ploter这个类。

class Ploter(object):
    def __init__(self, *args):
        self.__args__ = args
        self.__plot_data__ = {}
        for title in args:
            self.__plot_data__[title] = PlotData()

从Ploter类的初始化方法,我们可以看到它接受任意个参数。每个参数会作为一个主题,也就是我们要画图的图标,在我们例子的代码里就是train cost和test cost。Ploter类初始化的时候会维护一个字典__plot_data__, 每个主题作为key,新创建的PlotData类作为对应的value。那么这个PlotData是个什么东西?看一下PlotData类的定义:

class PlotData(object):
    def __init__(self):
        self.step = []
        self.value = []

    def append(self, step, value):
        self.step.append(step)
        self.value.append(value)

    def reset(self):
        self.step = []
        self.value = []

非常简单的类结构,维持了两个list分别对应step和value,定义了append和reset方法。这两个方法都是在Ploter内会被调用。

由于不停调用plot多少还是会影响效率,所以在正式训练的时候建议通过设置环境变量来决定是否进行plot. Ploter类中会获取DISABLE_PLOT这个环境变量,如果其为True,则Ploter将不会实时画图。这样子就无需修改我们在测试阶段调用Ploter的plot方法代码了. 设置方式:os.environ["DISABLE_PLOT"]="True"

    def __init__(self, *args):
        self.__args__ = args
        self.__plot_data__ = {}
        for title in args:
            self.__plot_data__[title] = PlotData()
        self.__disable_plot__ = os.environ.get("DISABLE_PLOT")
        if not self.__plot_is_disabled__():
            import matplotlib.pyplot as plt
            from IPython import display
            self.plt = plt
            self.display = display

    def __plot_is_disabled__(self):
        return self.__disable_plot__ == "True"

看到百度Paddle把这个功能增加到框架里面,看来也是遇到过类似的需求:在代码调试阶段进行实时显示loss的变化情况,如果loss没有下降就要回去看看代码哪里有问题了。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值