PyTorch 在 finetune 重训练时,直接采用 torch.load() 方式载入模型经常会报错(例如模型结构与 checkpoint 的 key 不完全一致)。
这里给出一种load方式,若模型中存在相同的tensor(名字和大小一致)则载入,否则对模型中tensor只做初始化处理。
代码如下:
def check_keys(model, pretrained_state_dict):
    """Sanity-check a checkpoint against a model's state dict.

    Args:
        model: module exposing ``state_dict()`` (e.g. an ``nn.Module``).
        pretrained_state_dict: dict of tensors loaded from a checkpoint.

    Returns:
        True when the checkpoint and the model share at least one key.

    Raises:
        AssertionError: if no checkpoint key matches any model key, i.e.
            nothing at all would be loaded.
    """
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    # Bug fix: the original asserted on the misspelled name
    # 'uesd_pretrained_keys', which raised NameError on every call.
    assert len(used_pretrained_keys) > 0, 'load None from pretrained checkpoint'
    return True
def remove_prefix(state_dict, prefix):
    """Return a copy of *state_dict* with *prefix* stripped from each key.

    Keys that do not start with *prefix* (e.g. checkpoints saved without
    ``nn.DataParallel``'s ``'module.'`` wrapper) are kept unchanged.
    """
    stripped = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key.split(prefix, 1)[-1]
        stripped[key] = value
    return stripped
def load_model(model, pretrained_path, load_to_cpu):
    """Load a checkpoint into *model*, tolerating partial key overlap.

    Args:
        model: the network to populate.
        pretrained_path: path to a checkpoint file readable by ``torch.load``.
        load_to_cpu: when truthy, map all tensors onto the CPU; otherwise
            map them onto the current CUDA device.

    Returns:
        The same *model*, with every matching tensor loaded; keys missing
        from the checkpoint keep their initialized values (``strict=False``).

    NOTE(review): ``torch.load`` unpickles the file — only use it on
    trusted checkpoints.
    """
    # Build the device-mapping callable once, then load a single time.
    if load_to_cpu:
        map_loc = lambda storage, loc: storage
    else:
        device = torch.cuda.current_device()
        map_loc = lambda storage, loc: storage.cuda(device)
    pretrained_dict = torch.load(pretrained_path, map_location=map_loc)

    # Some checkpoints wrap the weights under a 'state_dict' entry; either
    # way, drop any 'module.' prefix left over from nn.DataParallel.
    if "state_dict" in pretrained_dict:
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')

    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model