需求: 只加载预训练模型的某些层,不加载如:分类层的参数
if opt.continue_model != '':
print(f'loading pretrained model from {opt.continue_model}')
pretrained_dict = torch.load(opt.continue_model)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'Prediction' not in k)}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
关键的地方就是pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'Prediction' not in k)}
,这里用了一个if进行参数的筛选。