最开始
model.load_state_dict(torch.load(model_path), strict=True)
加载模型报错如下
Missing key(s) in state_dict: "conv_first.weight", "conv_first.bias", ....
Unexpected key(s) in state_dict: "params", "params_ema"
训练好的模型键值有许多,只需要模型参数文件,我的参数文件在params_ema内
修改如下即可
state_dict = torch.load(model_path)
model.load_state_dict(state_dict['params_ema'], strict=True)