利用HiddenLayer和netron进行pytorch模型结构可视化

0、简介

模型可视化是通过直观方式查看我们模型的结构。通常我们使用pytorch定义的网络模型都是代码堆叠,实现的和我们想象的是否一致呢,除了细致推敲代码外,直接通过图的方式展示出来更加直观。在这里介绍HiddenLayer和netron进行模型可视化,HiddenLayer是可以直接对pt模型进行可视化的,而netron无法直接可视化pt模型,所以我们通过将pt转为onnx模型,再通过netron进行可视化。

1、利用HiddenLayer进行模型可视化

模型可视化的方法有很多,可以看看这篇文章:超实用的7种 pytorch 网络可视化方法,进来收藏一波

这里记录一下HiddenLayer这个工具的使用,先看效果图:
在这里插入图片描述
相比较于其他工具,这个库非常简介,并且只包含给人看的节点,还能展示输入输出的shape,非常的人性化。

首先在环境中安装:pip install hiddenlayer
然后使用代码如下:

import torch
import hiddenlayer as h
from torchvision.models import resnet18

myNet = resnet18()  # 实例化 resnet18
x = torch.zeros(16, 3, 64, 64)  # 随机生成一个输入

myNetGraph = h.build_graph(myNet, x)  # 建立网络模型图
# myNetGraph.theme = h.graph.THEMES['blue']  # blue 和 basic 两种颜色,可以不要
myNetGraph.save(path='./demoModel.png', format='png')  # 保存网络模型图,可以设置 png 和 pdf 等

问题
我遇到的是不显示shape并出现警告
Pango-WARNING **: couldn’t load font “Times Not-Rotated 10”, falling back to “Sans Not-Rotated 10”, expect ugly output.
解决办法:

其他问题,参见:
pytorch 网络可视化(六):hiddenlayer
hiddenlayer库使用出现的一系列问题

2、使用netron进行模型可视化

首先我们需要安装netron:pip install netron
然后使用代码如下:

import netron
import torch
from torch import nn

# 定义我们的模型
class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.block1 = nn.Sequential(nn.Conv2d(3, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))
        self.block2 = nn.Sequential(nn.Conv2d(10, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))
        self.block3 = nn.Sequential(nn.Conv2d(10, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

# 进行模型推理(pt转onnx是通过跟踪计算流实现的,所以需要先推理一下)
net = TestNet()
input = torch.rand([1, 3, 10, 10])
output = net(input)

# 转为onnx模型
torch.onnx.export(net, input, "testnet.onnx", opset_version=11)
netron.start("testnet.onnx")	# 使用netron可视化onnx模型

执行代码会自动弹出web页面:
在这里插入图片描述

3、(高阶)使用netron进行模型可视化

通过上面的方法我们实现了模型的可视化,但是这个模型比较简单,如果来个复杂的模型,那么这个图就会很大很复杂,以至于我们都分不清和TestNet中的对应关系。

TestNet中的conv算子会转换成onnx中的conv算子,那意味着我们可以设计一个特殊的算子,暂且命名为DebugOp,转为onnx后在图中也会出现DebugOp算子,通过找到这个算子就能大致和TestNet进行关系对应。

具体代码如下:

import netron
import torch
from torch import nn

# 这里就是我们定义的特殊算子DebugOP
class DebugOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, name):  # 这个DebugOp算子将输入x直接返回,不做任何云算,插入到网络中也就不改变网络结构
        return x
    @staticmethod
    def symbolic(g, x, name):
        return g.op("my::Debug", x, name_s=name)
# 获取自定义算子的调用接口(用法上相当于实例化),后面就可以用debug_apply(x,name进行使用),在不同的地方可以传入不同的name
debug_apply = DebugOp.apply


class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.block1 = nn.Sequential(nn.Conv2d(3, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))
        self.block2 = nn.Sequential(nn.Conv2d(10, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))
        self.block3 = nn.Sequential(nn.Conv2d(10, 10, 3, 1, 1),
                                    nn.Conv2d(10, 10, 3, 1, 1))

    def forward(self, x):
        x = debug_apply(x, "this is block1")	# 将我们的特殊算子插入到网络中
        x = self.block1(x)
        x = debug_apply(x, "this is block2")
        x = self.block2(x)
        x = debug_apply(x, "this is block3")
        x = self.block3(x)
        return x


net = TestNet()
input = torch.rand([1, 3, 10, 10])
output = net(input)
torch.onnx.export(net, input, "testnet1.onnx", opset_version=11)
netron.start("testnet1.onnx")

可视化结构如下,插入了很多debug算子,并且点击查看算子属性可以看到name是我们传入的name,于是我们就能够很清楚知道下面哪一部分是block1、block2和block3了,这在复杂网络结构中寻找某些层是非常有用的。
在这里插入图片描述

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我是一个对称矩阵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值