在pytorch环境下,有两个计算FLOPs和参数量的包thop和ptflops,结果基本是一致的。
thop
参考https://github.com/Lyken17/pytorch-OpCounter
安装方法:pip install thop
使用方法:
from torchvision.models import resnet18
from thop import profile
model = resnet18()
input = torch.randn(1, 3, 224, 224) #模型输入的形状,batch_size=1
flops, params = profile(model, inputs=(input, ))
print(flops/1e9,par