方法1:统计模型参数量
total = sum([param.nelement() for param in model.parameters()])
print("Number of parameter: %.2fM" % (total/1e6))
方法2:统计flops和参数量
pip install thop
from thop import profile
dummy_input = torch.randn(1, 3, 32, 32)#.to(device)
flops, params = profile(model, (dummy_input,))
print('flops: ', flops, 'params: ', params)
print('flops: %.2f M, params: %.2f M' % (flops / 1000000.0, params / 1000000.0))
问题:当网络中有自定义参数时,flops和params就很有可能漏掉那部分参数
方法3:统计macs指标和参数量
pip install ptflops
from ptflops import get_model_complexity_info
macs, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_la