# Partially load pretrained weights into `model`: a checkpoint tensor is
# copied whenever a model parameter name occurs as a substring of a
# checkpoint key (tolerates prefix mismatches such as "module." added by
# DataParallel, or a backbone wrapper prefix).
# NOTE(review): substring matching (`kk in k`) can pair unrelated keys when
# one parameter name is contained in another (e.g. "fc" in "fc2") — verify
# against the actual checkpoint key names.
checkpoint = torch.load(weight_path, map_location="cpu")  # CPU-safe load; load_state_dict moves tensors to the model's device
new_pth = model.state_dict()  # target parameter dict of the model
pretrained_dict = {}  # parameters shared by checkpoint and model
for k, v in checkpoint['state_dict'].items():
    for kk in new_pth.keys():
        # Take only tensors whose shape matches the target slot; a shape
        # mismatch would otherwise make load_state_dict raise.
        if kk in k and v.shape == new_pth[kk].shape:
            pretrained_dict[kk] = v
            break  # each checkpoint key fills at most one model slot
# Overlay the matched tensors on the model's own state and load; keys that
# found no match keep the model's current (e.g. freshly initialized) values.
new_pth.update(pretrained_dict)
model.load_state_dict(new_pth)