code
pretrained_state = torch.load(file_path)
model_dict = net.state_dict()
pretrained_state = {k: v for k, v in pretrained_state.items() if (k in model_dict and v.shape == model_dict[k].shape)}
model_dict.update(pretrained_state)
net.load_state_dict(model_dict)
解释
- 载入预训练的参数
- 新的模型,根据名字和tensor.shape来保留预训练的参数
- 更新新模型的参数