今天介绍的工具是torchsummary,可以用来统计PyTorch每层的参数情况。一来可以用于参数剪枝优化,二来可以用于了解模型的参数分布。
安装:
pip install torchsummary
源代码链接:https://github.com/sksq96/pytorch-summary
使用:
from torchvision.models.alexnet import alexnet
from torchsummary import summary
model = alexnet(pretrained=True).eval().cuda()
summary(model, input_size=(3, 224, 224), batch_size=-1)
效果:
结束~