问题描述
使用Pytorch的童鞋在训练完一个模型或者导入模型的时候,会想要查看这个网络的每一层的结构和shape。虽然可以使用model.modules来查看,但是以文本段落来展示就显得有些凌乱。如果用过Tensorflow的话,一定对model.summary情有独钟,简洁的表格形式展示了模型各层的名字、shape等信息。那么如何在torch中同样以表格形式打印呢?这里使用的是torchsummary库来进行实现。
废话少说,上实例
在这里模型构建我就不再赘述,这里我是用torch搭建的Yolo v3的darknet53网络。如果当前环境下没有torchsummary的话,可以通过pip/pip3来进行便捷安装。
pip install torchsummary
torch的查看模型具体代码
# 导入相关包
from torchsummary import summary
import torch
# 定义darknet53模型
class Darknet(torch.nn.Module):
def __init__(self):
super().__init__()
...
def forward(self):
...
# 定义模型,并打印模型结构
model=Darknet()
# yolo3的输入的一个样本的size是(3,416,416),这里需要注意一点,如果模型中的forward设定了CUDA的参数,则要保证在summary中的device选择与model中的CUDA是一样的,否则会报错。
summary(model,(3,416,416),device='cpu')
上述代码运行的部分结果如下:
这是summary模型结果。不过这么显示有个弊端:类似darknet53的模型,其中是由多个Sequential串联的卷积块组成,summary没有相关表示。在torch的默认显示下就体现这些层次关系:
最后嘱咐一句
如果想对模型各层进行统计,可以简单修改一下torchsummary的源文件。
可以看出summary默认是将summary注释掉的。如果想要将表格的内容保存下来或者统计,可以把注释去掉,并将summary赋值到某一个变量中。