一、可视化网络结构
本节我们将介绍如何使用torchinfo来可视化网络结构。
1. 使用print函数打印模型基础信息
import torchvision.models as models
model = models.resnet18()
print(model)
print(model),只能得出基础构件的信息,既不能显示出每一层的shape,也不能显示对应参数量的大小。
2.使用torchinfo可视化网络结构
- 安装方法
# 安装方法一 pip install torchinfo # 安装方法二 conda install -c conda-forge torchinfo
- torchinfo的使用
trochinfo的使用也是十分简单,我们只需要使用torchinfo.summary()
就行了,必需的参数分别是model,input_size[batch_size,channel,h,w]。
import torchvision.models as models
from torchinfo import summary
resnet_18 = models.resnet18() # 实例化模型
summary(resnet_18, (1, 3, 224, 224)) # 1:batch_size 3:图片的通道数 224: 图片的高宽
torchinfo提供了更加详细的信息,包括模块信息(每一层的类型、输出shape和参数量)、模型整体的参数量、模型大小、一次前向或者反向传播需要的内存大小等。
注意:
但你使用的是colab或者jupyter notebook时,想要实现该方法,summary()
一定是该单元(即notebook中的cell)的返回值,否则我们就需要使用print(summary(...))
来可视化。
二、CNN可视化
1. CNN卷积核可视化
在PyTorch中可视化卷积核非常方便,核心在于特定层的卷积核即特定层的模型权重,可视化卷积核就等价于可视化对应的权重矩阵。下面给出在PyTorch中可视化卷积核的实现方案,以torchvision自带的VGG11模型为例。
2. CNN特征图可视化方法
在PyTorch中,提供了一个专用的接口使得网络在前向传播过程中能 够获取到特征图,这个接口的名称非常形象,叫做hook。可以想象这样的场景,数据通过网络向前传 播,网络某一层我们预先设置了一个钩子,数据传