计算量:
FLOPS,浮点运算次数,指运行一次网络模型需要进行浮点运算的次数。
参数量:
Params,是指网络模型中需要训练的参数总数。
第一步:安装模块(thop)
pip install thop
第二步:计算
import torch
from thop import profile
net = Model() # 定义好的网络模型
inputs = torch.randn(1, 3, 112, 112)
flops, params = profile(net, (inputs,))
print('flops: ', flops, 'params: ', params)
注意:
- 输入input的第一维度是批量(batch size),批量的大小不回影响参数量, 计算量是batch_size=1的倍数
- profile(net, (inputs,))的 (inputs,)中必须加上逗号,否者会报错