给定一个预训练模型,如果你对模型结构做了一定的修改,那么可以只加载未改变的模型参数,从而加快模型的训练。代码如下:
pretrained_dict = ‘…….pkl’#预训练模型参数保存地址
model_dict = model.state_dict() #自己的模型参数变量
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#去除一些不需要的参数
model_dict.update(pretrained_dict)#参数更新
model.load_state_dict(model_dict)#加载