《深入浅出Pytorch》第七章学习笔记

一、可视化网络结构

        本节我们将介绍如何使用torchinfo来可视化网络结构。

1. 使用print函数打印模型基础信息

import torchvision.models as models
model = models.resnet18()

print(model)

        print(model),只能得出基础构件的信息,既不能显示出每一层的shape,也不能显示对应参数量的大小。

2.使用torchinfo可视化网络结构

  • 安装方法
    # 安装方法一
    pip install torchinfo 
    # 安装方法二
    conda install -c conda-forge torchinfo
  •  torchinfo的使用 

        trochinfo的使用也是十分简单,我们只需要使用torchinfo.summary()就行了,必需的参数分别是model,input_size[batch_size,channel,h,w]。 

import torchvision.models as models
from torchinfo import summary

resnet_18 = models.resnet18()  # 实例化模型
summary(resnet_18, (1, 3, 224, 224))  # 1:batch_size 3:图片的通道数 224: 图片的高宽

        torchinfo提供了更加详细的信息,包括模块信息(每一层的类型、输出shape和参数量)、模型整体的参数量、模型大小、一次前向或者反向传播需要的内存大小等。

注意

        但你使用的是colab或者jupyter notebook时,想要实现该方法,summary()一定是该单元(即notebook中的cell)的返回值,否则我们就需要使用print(summary(...))来可视化。

二、CNN可视化

1. CNN卷积核可视化

        在PyTorch中可视化卷积核非常方便,核心在于特定层的卷积核即特定层的模型权重,可视化卷积核就等价于可视化对应的权重矩阵。下面给出在PyTorch中可视化卷积核的实现方案,以torchvision自带的VGG11模型为例。

2. CNN特征图可视化方法

        在PyTorch中,提供了一个专用的接口使得网络在前向传播过程中能 够获取到特征图,这个接口的名称非常形象,叫做hook。可以想象这样的场景,数据通过网络向前传 播,网络某一层我们预先设置了一个钩子,数据传

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值