常用的库是
thop
,它是一个用于计算PyTorch模型的FLOPs和参数数量的库。
1.安装 thop
库:
pip install thop
2 导入需要计算 FLOPs 的模型和 profile
函数:.
import torch
from torchvision.models import resnet18
from thop import profile
3. 创建模型并准备输入数据:
model = resnet18()
input_data = torch.randn(1, 28, 256, 256)
我这里是跟随训练一起写的,要调用GPU,所以是如下代码:
# 创建模型并将其放在GPU上
model = resnet18().cuda()
# 创建输入数据并将其放在GPU上
input_data = torch.randn(1, 28, 256, 256).cuda()
4.使用 profile
函数计算 FLOPs:
flops, params = profile(model, inputs=(input_data,))
print(f"FLOPs: {flops}, Params: {params}")