1.安装torchsummary
Windows+r,打开cmd命令,使用pip下载安装包
pip install torchsummary
torchsummary:能够查看模型的输入和输出的形状,可以更加清楚地输出模型的结构
下面是torchsummary的结构:
torchsummary.summary(model, input_size, batch_size=-1, device="cuda")
功能:查看模型的信息,便于调试
1.model: pytorch 模型,必须继承自 nn.Module
2.input_size: 模型输入 size,形状为 C,H ,W
3.batch_size:batch_size,默认为 -1,在展示模型每层输出的形状时显示的batch_size
4.device:“cuda"或者"cpu”
5.使用时需要注意,默认device=‘cuda’,如果是在‘cpu’,那么就需要更改。不匹配就会出现下面的错误:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
2.输出网格结构
完成以上步骤后,进入自己的python环境,运行代码如下:
from torchsummary import summary # 导入summary
from torchvision.models import vgg16 # 导入vgg16,以 vgg16 为例
model = vgg16() # 实例化网络,可以换成自己的网络
summary(model, (3, 64, 64)) # 输出网络结构
from torchsummary import summary
summary(model, input_size=(channels, H, W))
torchsummary的使用基于下述核心API,只要提供给summary函数模型以及输入的size就可以了。来源
运行代码,结果如下:
经上图可见,torchsummary可以查看每一层的网格参数量,网络模型大小等信息。