Paddle模型结构可视化

该文章写于学习paddle框架时发现框架下面的visualdl可以将模型结构可视化出来,但是根据官方例子下面的add_graph接口并未将结构可视化出来,于是根据paddle官网保存模型api做了一定修改,最终实现效果。

实现两种方式:

一、先保存动态图模式模型再可视化

对于动态图模式相比于静态图模型可以更加便于调试。保存动态图模型可以直接使用paddle.jit.save保存模型和参数,直接保存动态图需要注意两点:
1、确保该模型的forward方法仅实现预测功能,避免将训练所需loss计算逻辑写入forward方法

2、使用paddle.jit.save方法保存模型时,需要指定InputSpec(用于描述模型输入的签名信息,包括 shape、dtype 和 name)。Layer对象forward方法的每一个参数均需要对应的InputSpec进行描述,不能省略。

input_spec参数支持两种类型的输入:

        1)InputSpec列表:

                使用InputSpec描述forward输入参数的shape,dtype和name。name可以省略name省略的情况下会使用forward的对应参数名作为name,例子中的name为x

paddle.jit.save(net,save_path,input_spec=[InputSpec(shape=[1,8,24,24],dtype='float32')])

2)Example Tensor列表:

                直接使用forward训练时的示例输入。例如,直接使用DataLoader迭代得到的 image

paddle.jit.save(
    layer=layer,
    path=path,
    input_spec=[image])

本次根据动态图模式实现结构可视化流程:

1、先搭建一个深度学习模型(此处为举例子,随便构建一个模型)

class SelfNet(nn.Layer):
    def __init__(self):
        super(SelfNet,self).__init__()
        self.conv1 = nn.Conv2D(8,16,3)
        self.max1 = nn.MaxPool2D(2)
        self.conv2 = nn.Conv2D(16,32,3)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.max1(x)
        x = F.relu(self.conv2(x))
        return x
    

2、测试模型及保存模型

net = SelfNet()
input = paddle.ones([1,8,24,24])
output = net(input)

save_path = './graph/model'
paddle.jit.save(net,save_path,input_spec=[InputSpec(shape=[1,8,24,24],dtype='float32')])

3、使用visualdl加载保存模型

在命令行中,使用

visualdl --model ./graph/model.pdmodel --port 8080

启动相关服务,下图为服务启动成功,点击地址进入web

模型可视化成功!

右边可以选择显示对应参数,对模型结构进行缩放和保存模型结构png和svg格式。

二、先保存动转静模型再可视化

动转静保存模型参数有以下注意点:

1、Layer对方的forward方法需要经由 paddle.jit.to_static 装饰(例子如下)

    @paddle.jit.to_static(input_spec=[InputSpec(shape=[1,8,24,24],dtype='float32')])
    #@paddle.jit.to_static
    
    def forward(self, x):

若最终需要生成的描述模型的Program支持动态输入,可以同时指明模型的 InputSepc

2、确保该模型的forward方法仅实现预测功能,避免将训练所需loss计算逻辑写入forward方法

3、如果你需要保存多个方法,需要用 paddle.jit.to_static 装饰每一个需要被保存的方法。

注:只有在forward之外还需要保存其他方法时才用这个特性,如果仅装饰非forward的方法,而forward没有被装饰,是不符合规范的。此时 paddle.jit.save 的 input_spec 参数必须为None。

实现代码如下(实现代码仅与动态图有两处不同):

class SelfNet(nn.Layer):
    def __init__(self):
        super(SelfNet,self).__init__()
        self.conv1 = nn.Conv2D(8,16,3)
        self.max1 = nn.MaxPool2D(2)
        self.conv2 = nn.Conv2D(16,32,3)
    
    @paddle.jit.to_static(input_spec=[InputSpec(shape=[1,8,24,24],dtype='float32')])
    #@paddle.jit.to_static
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.max1(x)
        x = F.relu(self.conv2(x))
        return x
    
net = SelfNet()
input = paddle.ones([1,8,24,24])
output = net(input)

save_path = './graph_static/model'
# paddle.jit.save(net,save_path,input_spec=[InputSpec(shape=[1,8,24,24],dtype='float32')])
paddle.jit.save(net,save_path)

在运行代码之后,再次启动visualdl可以查看对应模型

完整代码如下:

import paddle
import paddle.nn.functional as F
import paddle.nn as nn
from visualdl import LogWriter
from paddle.static import InputSpec
class SelfNet(nn.Layer):
    def __init__(self):
        super(SelfNet,self).__init__()
        self.conv1 = nn.Conv2D(8,16,3)
        self.max1 = nn.MaxPool2D(2)
        self.conv2 = nn.Conv2D(16,32,3)
    
    #@paddle.jit.to_static(input_spec=[InputSpec(shape=[1,8,24,24],dtype='float32')])
    #@paddle.jit.to_static
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.max1(x)
        x = F.relu(self.conv2(x))
        return x
    
net = SelfNet()
input = paddle.ones([1,8,24,24])
output = net(input)

save_path = './graph/model'
paddle.jit.save(net,save_path,input_spec=[InputSpec(shape=[1,8,24,24],dtype='float32')])

    
visualdl --model ./graph/model.pdmodel --port 8080

对于未能实现visualdl的add_graph接口实现模型可视化,这个地方可能是我哪个地方没处理好,没有实现效果,等以后有时间看看代码在进行剖析。

  • 23
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值