checkpoint=torch.load(weight_path)
new_pth=model.state_dict()# 需要加载参数的模型# pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in new_pth}
pretrained_dict={}# 用于保存公共具有的参数for k,v in checkpoint['state_dict'].items():for kk in new_pth.keys():if kk in k:
pretrained_dict[kk]=v
break
new_pth.update(pretrained_dict)
model.load_state_dict(new_pth)
问题描述在PyTorch训练好模型以后,需要加载模型,加载模型代码如下ckpt = torch.load(model_path_len)model.load_state_dict(ckpt['state_dict'])结果碰到的问题为:Missing key(s) in state_dictraise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(RuntimeError: Error(s) in