使用Pytorch如何查看网络结构

文章介绍了如何在Pytorch中使用torchsummary库以简洁的表格形式展示模型(如Darknet53)的结构和各层shape,类似于Tensorflow的model.summary。通过安装torchsummary并定义模型后,调用summary函数即可查看模型详情。然而,对于包含多个Sequential卷积块的模型,summary无法完全展示层次关系。
摘要由CSDN通过智能技术生成

问题描述

使用Pytorch的童鞋在训练完一个模型或者导入模型的时候,会想要查看这个网络的每一层的结构和shape。虽然可以使用model.modules来查看,但是以文本段落来展示就显得有些凌乱。如果用过Tensorflow的话,一定对model.summary情有独钟,简洁的表格形式展示了模型各层的名字、shape等信息。那么如何在torch中同样以表格形式打印呢?这里使用的是torchsummary库来进行实现。


废话少说,上实例

在这里模型构建我就不再赘述,这里我是用torch搭建的Yolo v3的darknet53网络。如果当前环境下没有torchsummary的话,可以通过pip/pip3来进行便捷安装。

pip install torchsummary

torch的查看模型具体代码

# 导入相关包
from torchsummary import summary
import torch

# 定义darknet53模型
class Darknet(torch.nn.Module):
	def __init__(self):
		super().__init__()
		...
	def forward(self):
		...

# 定义模型,并打印模型结构
model=Darknet()
# yolo3的输入的一个样本的size是(3,416,416),这里需要注意一点,如果模型中的forward设定了CUDA的参数,则要保证在summary中的device选择与model中的CUDA是一样的,否则会报错。
summary(model,(3,416,416),device='cpu')

上述代码运行的部分结果如下:
在这里插入图片描述
这是summary模型结果。不过这么显示有个弊端:类似darknet53的模型,其中是由多个Sequential串联的卷积块组成,summary没有相关表示。在torch的默认显示下就体现这些层次关系:
在这里插入图片描述

最后嘱咐一句

如果想对模型各层进行统计,可以简单修改一下torchsummary的源文件。在这里插入图片描述
可以看出summary默认是将summary注释掉的。如果想要将表格的内容保存下来或者统计,可以把注释去掉,并将summary赋值到某一个变量中。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值