首先我们把预训练模型load下来:
pretrained_dict = torch.load('/content/drive/MyDrive/010000.tar')
我们原来的model的参数给他搞出来:
model = NeRF()
model_dict = model.state_dict()
它现在是一个state_dict类型,也是我们经常save的一个模型类型,我们load出来的也是这个类型
接下来我们把load出来的与这个model的进行对比
pretrained_dict = {key: value for key, value in pretrained_dict['network_fn_state_dict'].items() if (key in model_dict and 'Prediction' not in key)}
对比出来了以后就更新我们的Model
model_dict.update(pretrained_dict)
最后!这个类型格式是不能直接当成model的,所以咱需要:
model.load_state_dict(model_dict)
load_state_dict一下就可以啦