仅作记录，大佬请跳过。
感谢老师的示范。删除预训练权重中与 fc（全连接）层相关的参数：
# Remove every entry whose key contains the substring "fc" (presumably the
# fully-connected head) so the remaining weights can be loaded into a network
# whose head differs.
# NOTE(review): assumes `state_dict` is a dict-like loaded earlier (e.g. via
# torch.load) — it is not defined in this snippet.
# Collect the keys first: deleting while iterating the dict itself would raise.
fc_keys = list(filter(lambda name: "fc" in name, state_dict))
for k in fc_keys:
    del state_dict[k]
查看设计的网络与加载的网络权重是否有不同的层：
def load_from_pretrained(self, ckpt_path, *, drop_keys_containing=("fc",)):
    """Load pretrained weights into this network for fine-tuning.

    The checkpoint is read onto the CPU. Every entry whose key contains one of
    ``drop_keys_containing`` (by default ``"fc"``) is discarded. The rest is
    loaded with ``strict=False``, so layers that differ between the checkpoint
    and this network are printed instead of raising an error.

    Args:
        self: the network to load into (expected to be a ``torch.nn.Module``).
        ckpt_path: path to a file readable by ``torch.load``. It may contain a
            raw state dict, or a dict that wraps one under ``"model"`` or
            ``"state_dict"``.
        drop_keys_containing: substring(s). Any checkpoint key containing one
            of them is removed before loading. A single string is treated as
            one substring. This is a plain substring test, so ``"fc"`` also
            matches keys such as ``"layer.fc1.weight"``.
    """
    from pprint import pprint

    print(f"==============> Loading weight {ckpt_path} for fine-tuning......")
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu')

    state_dict = ckpt
    # Some checkpoints wrap the weights, e.g. {"model": state_dict, ...}.
    # Passing such a wrapper straight to load_state_dict(strict=False) would
    # load nothing while still reporting success, so unwrap it first. The
    # value must itself be a dict, so a raw state dict (whose values are
    # tensors) is never mistaken for a wrapper.
    if isinstance(state_dict, dict):
        for wrapper_key in ("model", "state_dict"):
            inner = state_dict.get(wrapper_key)
            if isinstance(inner, dict):
                state_dict = inner
                break

    if isinstance(drop_keys_containing, str):
        # Avoid iterating a bare string character by character.
        drop_keys_containing = (drop_keys_containing,)
    # Materialize the key list before deleting; a dict must not change size
    # while it is being iterated.
    dropped_keys = [
        k for k in state_dict if any(s in k for s in drop_keys_containing)
    ]
    for k in dropped_keys:
        del state_dict[k]

    # strict=False reports mismatched layers instead of raising.
    missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
    print('missing_keys = ')
    pprint(missing_keys)
    print('unexpected_keys = ')
    pprint(unexpected_keys)
    print(f"=> loaded successfully '{ckpt_path}'")
    print('ok')
其中，self 指设计的网络。