pytorch模型可视化

1. 使用dot

1.1 安装graphviz和torchviz

sudo apt-get install graphviz

sudo pip install torchviz

1.2 使用torchviz

import torch
from torch import nn
from torchviz import make_dot, make_dot_from_trace

# Visualize gradients of simple MLP
# The method below is for building directed graphs of PyTorch operations, built during forward propagation and showing which operations will be called on backward. It omits subgraphs which do not require gradients.

model=nn.Sequential()
model.add_module("W0", nn.Linear(8, 16))
model.add_module("tanh", nn.Tanh())
model.add_module("W1", nn.Linear(16, 1))

x = Variable(torch.randn(1, 8))
y = model(x)

make_dot(y.mean(), params=dict(model.named_parameters()))  # 直接在ipython notebook中显示

dot=make_dot(y.mean(), params=dict(model.named_parameters()))
dot.render("model.pdf")  #保存为pdf

github教程

2. 使用tensorwatch

安装pytorch = 1.2, tensorwatch = 0.8.7

import tensorwatch as tw
import torchvision.models

alexnet_model = torchvision.models.alexnet()
tw.draw_model(alexnet_model, [1, 3, 224, 224]).save("alextnet.png")

教程: tensorwatch

3. 使用hiddenlayer

pip install torch===1.4.0 torchvision===0.5.0 -f https://download.pytorch.org/whl/torch_stable.html

pip install hiddenlayer

import torch
import torchvision.models
import hiddenlayer as hl

# Resnet101
model = torchvision.models.resnet101()

# Rather than using the default transforms, build custom ones to group
# nodes of residual and bottleneck blocks.
transforms = [
    # Fold Conv, BN, RELU layers into one
    hl.transforms.Fold("Conv > BatchNorm > Relu", "ConvBnRelu"),
    # Fold Conv, BN layers together
    hl.transforms.Fold("Conv > BatchNorm", "ConvBn"),
    # Fold bottleneck blocks
    hl.transforms.Fold("""
        ((ConvBnRelu > ConvBnRelu > ConvBn) | ConvBn) > Add > Relu
        """, "BottleneckBlock", "Bottleneck Block"),
    # Fold residual blocks
    hl.transforms.Fold("""ConvBnRelu > ConvBnRelu > ConvBn > Add > Relu""",
                       "ResBlock", "Residual Block"),
    # Fold repeated blocks
    hl.transforms.FoldDuplicates(),
]

# Display graph using the transforms above
dot = hl.build_graph(model, torch.zeros([1, 3, 224, 224]), transforms=transforms)
dot.attr("graph", rankdir="TD")
dot.render("resnet101")  # save as resnet101.pdf

教程: hiddenlayer

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值