一、使用第三方工具:
torchstat:
安装:pip install torchstat
torchstat GitHub 源码页面
例子:
from torchstat import stat
model = model()
stat(model, (3, 1280, 1280))
输出:会输出模型各层网络的信息,最后进行总结统计。
ptflops:
安装:pip install ptflops
ptflops GithHub源码页面
例子:
import torchvision.models as models
import torch
from ptflops import get_model_complexity_info
with torch.cuda.device(0):
net = models.densenet161()
macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
print_per_layer_stat=True, verbose=True)
print('{:<30} {:<8}'.format('Computational complexity: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))
输出:同样会输出模型各层的信息,最后总结统计 参数量 和 FLOPs。
注意: 使用第三方工具时, 网络中有些层可能会不支持计算。
其他工具:
- torchsummary
- thop
二、使用函数统计模型参数量:
计算模型参数量 与 可训练参数量:
def get_parameter_number(model):
total_num = sum(p.numel() for p in model.parameters())
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
result = get_parameter_number(model)
print(result['Total'],result['Trainable']) #打印参数量