原报错代码:
from torchvision.models import resnet50
import torch
import torchvision.models as models
# import torch
from ptflops import get_model_complexity_info
# model = models.resnet50() #调用官方的模型,
checkpoints = '自己模型的path'
model = torch.load(checkpoints)
model_name = 'yolov3 cut'
flops, params = get_model_complexity_info(model, (3,320,320),as_strings=True,print_per_layer_stat=True)
print("%s |%s |%s" % (model_name,flops,params))
解决方案:
因为我保存的模型是只有参数的,要使用load_state_dict函数来加载。然后为什么加[‘state_dict’]可以参考我前阵子一篇文章“Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: “module.conv0.weight”,里面有详细说明。完整代码可以见前几篇文章“使用pytorch计算模型的参数量和计算量”。
model.load_state_dict(torch.load("C:\\Users\\83543\\Desktop\\model_best.pth.tar")['state_dict'])