【笔记 Pytorch】模型网络结构、网络参数可视化

查看网络结构

打印方式

torchsummary 方式(输入格式不好控制)

参考网址

import torch
import torchvision
from torchsummary import summary          #使用 pip install torchsummary
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg = torchvision.models.vgg16().to(device)

# summary(your_model, input_size=(channels, H, W))
summary(vgg, input_size=(3, 224, 224))

print方式 (简便,存在输出顺序与执行顺序不一致的问题)

for name, parameters in your_model.named_parameters():
    print(name, ':', parameters.size())

可视化方式

HiddenLayer

 pip install hiddenlayer
 import hiddenlayer as h
 vis_graph = h.build_graph(MyConvNet, torch.zeros([1 ,1, 28, 28]))   # 获取绘制图像的对象
 vis_graph.theme = h.graph.THEMES["blue"].copy()     # 指定主题颜色
 vis_graph.save("./demo1.png")   # 保存图像的路径

PytorchVIZ

 pip install torchviz
 from torchviz import make_dot
 x = torch.randn(1, 1, 28, 28).requires_grad_(True)  # 定义一个网络的输入值
 y = MyConvNet(x)    # 获取网络的预测值
 ​
 MyConvNetVis = make_dot(y, params=dict(list(MyConvNet.named_parameters()) + [('x', x)]))
 MyConvNetVis.format = "png"
 # 指定文件生成的文件夹
 MyConvNetVis.directory = "data"
 # 生成文件
 MyConvNetVis.view()

tensorboardX(会存在一些版本的匹配问题,不太直观)

graphviz + torchviz (依赖于graphviz和GitHub第三方库torchviz)

微软的tensorwatch (只能在jupyter notebook中使用)

netron可视化工具(.pt 或者是 .pth 文件)

查看网络参数

params = list(model.parameters())
k = 0
for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
                l *= j
        print("该层参数和:" + str(l))
        k = k + l
print("总参数数量和:" + str(k))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

坚果仙人

谢谢!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值