之前需要计算模型的浮点操作数和参数量,发现thop对自己设计的模型有时候会漏算,发现还可以用ptflops去计算
from ptflops import get_model_complexity_info
flops, params = get_model_complexity_info(model, (3,640,640), as_strings=True, print_per_layer_stat=True)
print('FLOPs: ', flops)
print('Params: ', params)
from thop import profile
input = torch.randn(1, 3, 640, 640).to(device) # 模型输入的形状,batch_size=1
flops, params = profile(model, inputs=(input,))
print(flops / 1e9, 'G', params / 1e6, 'M') # flops单位G,para单位M