总结一下:
1.加载预训练的模型
ckpt = torch.load(ckpt_path)
2.先取出当前模型的参数字典(model_dict = cnn.state_dict()),再将预训练模型里面不需要的键值对过滤掉:
restore_dict = {k: v for k, v in ckpt.items() if k in model_dict and xxxxxx }
3.更新现有的model_dict
model_dict.update(restore_dict)
4.加载模型
cnn.load_state_dict(model_dict)
example:
def load_ckpt(
    ckpt_path, ckpt_dict,
    skip_patterns=(
        "cab1.conv3", "cab2.conv3", "cab3.conv3", "cab4.conv3", "clf_conv",
    ),
):
    """Partially restore modules from a checkpoint file.

    For every ``name -> module`` pair in ``ckpt_dict`` that also has an entry
    in the checkpoint, the saved tensors are copied into the module, except
    for entries whose key contains one of ``skip_patterns`` and entries that
    the module does not have. Skipped entries keep the module's current
    values.

    Args:
        ckpt_path: Path of the checkpoint file passed to ``torch.load``.
        ckpt_dict: Mapping from checkpoint entry name to an object exposing
            ``state_dict()`` and ``load_state_dict()``.
        skip_patterns: Substrings identifying parameter keys that must not be
            restored from the checkpoint.

    Returns:
        Tuple ``(best_val, epoch_start)`` read from the checkpoint; both are 0
        when the file does not exist or the values are missing.
    """
    best_val = epoch_start = 0
    # Check the path that was actually passed in, not a global ``args``.
    if not os.path.exists(ckpt_path):
        return best_val, epoch_start

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path)
    for name, module in ckpt_dict.items():
        if name not in ckpt:
            continue
        model_dict = module.state_dict()
        # Build the filtered dict per module so that tensors restored for one
        # module never leak into the state dict of the next one.
        restore_ckpt = {
            key: value
            for key, value in ckpt[name].items()
            if key in model_dict
            and not any(pattern in key for pattern in skip_patterns)
        }
        model_dict.update(restore_ckpt)
        module.load_state_dict(model_dict)

    best_val = ckpt.get('best_val', 0)
    epoch_start = ckpt.get('epoch_start', 0)
    logger.info(" Found checkpoint at {} with best_val {:.4f} at epoch {}".
                format(
                    ckpt_path, best_val, epoch_start
                ))
    return best_val, epoch_start