Purpose:
Print a summary of a model's layers, output shapes, and parameter counts, which is useful for debugging.
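Main parameters:
- model: the PyTorch model (an nn.Module)
- input_size: shape of a single input, excluding the batch dimension, e.g. (3, 32, 32)
- batch_size: batch size shown in the Output Shape column (default -1)
- device: "cuda" or "cpu" (default "cuda"); it must match the device the model is on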
Installation:
pip install torchsummary
Example:
from torchsummary import summary
from model.lenet import LeNet  # project-local LeNet definition (model/lenet.py)
lenet = LeNet(classes=2)
# summary() prints the table itself and returns None, so call it directly;
# wrapping it in print() only adds a trailing "None" line
summary(lenet, (3, 32, 32), device="cpu")
Output:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [-1, 6, 28, 28]             456
            Conv2d-2           [-1, 16, 10, 10]           2,416
            Linear-3                  [-1, 120]          48,120
            Linear-4                   [-1, 84]          10,164
            Linear-5                    [-1, 2]             170
================================================================
Total params: 61,326
Trainable params: 61,326
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.05
Params size (MB): 0.23
Estimated Total Size (MB): 0.30
----------------------------------------------------------------
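To see where the numbers in the table come from, the plain-Python sketch below re-derives them without PyTorch. It assumes LeNet is Conv2d(3,6,5) -> MaxPool2d(2) -> Conv2d(6,16,5) -> MaxPool2d(2) -> Linear(400,120) -> Linear(120,84) -> Linear(84,2). This layout is inferred from the printed shapes and parameter counts, not taken from model/lenet.py. The pooling and activation layers are presumably missing from the table because this LeNet applies them functionally rather than as modules. The size estimates assume float32 (4 bytes per value), with layer outputs counted twice. That factor of 2 appears to account for gradients, and it is consistent with the printed 0.05 MB. The helper names (conv2d_params, linear_params, conv_out) are hypothetical.

```python
# Hypothetical re-derivation of the torchsummary figures for LeNet(classes=2).
# ASSUMPTION: Conv2d(3,6,5) -> MaxPool2d(2) -> Conv2d(6,16,5) -> MaxPool2d(2)
# -> Linear(400,120) -> Linear(120,84) -> Linear(84,2), inferred from the table.
from math import prod

MB = 1024 ** 2
BYTES_PER_VALUE = 4  # float32


def conv2d_params(c_in, c_out, k):
    """Weights (c_out * c_in * k * k) plus one bias per output channel."""
    return c_out * c_in * k * k + c_out


def linear_params(n_in, n_out):
    """Weight matrix (n_out * n_in) plus one bias per output unit."""
    return n_out * n_in + n_out


def conv_out(size, k):
    """Spatial output size of a k x k convolution with stride 1 and no padding."""
    return size - k + 1


s1 = conv_out(32, 5)        # 32 -> 28
s2 = conv_out(s1 // 2, 5)   # 2x2 pool: 28 -> 14, then conv: 14 -> 10
flat = 16 * (s2 // 2) ** 2  # 2x2 pool: 10 -> 5, so 16 * 5 * 5 = 400 features

# (name, output shape without the batch dimension, parameter count)
layers = [
    ("Conv2d-1", (6, s1, s1), conv2d_params(3, 6, 5)),
    ("Conv2d-2", (16, s2, s2), conv2d_params(6, 16, 5)),
    ("Linear-3", (120,), linear_params(flat, 120)),
    ("Linear-4", (84,), linear_params(120, 84)),
    ("Linear-5", (2,), linear_params(84, 2)),
]

total_params = sum(p for _, _, p in layers)
total_outputs = sum(prod(shape) for _, shape, _ in layers)

input_mb = prod((3, 32, 32)) * BYTES_PER_VALUE / MB
# Layer outputs counted twice: once for the activations, once for their gradients.
fwd_bwd_mb = 2 * total_outputs * BYTES_PER_VALUE / MB
params_mb = total_params * BYTES_PER_VALUE / MB
total_mb = input_mb + fwd_bwd_mb + params_mb

for name, shape, params in layers:
    # -1 stands in for the batch dimension, as in the summary table
    print(f"{name:>12}  {str([-1, *shape]):>20}  {params:>8,}")
print(f"Total params: {total_params:,}")
print(f"Input size (MB): {input_mb:.2f}")
print(f"Forward/backward pass size (MB): {fwd_bwd_mb:.2f}")
print(f"Params size (MB): {params_mb:.2f}")
print(f"Estimated Total Size (MB): {total_mb:.2f}")
```

Under these assumptions the script reproduces the totals above: 61,326 parameters, and 0.01 / 0.05 / 0.23 / 0.30 MB. Note that Params size is simply total_params * 4 bytes, so it grows linearly with the parameter count.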