在 finetune 时,我们经常需要对 pretrained model 进行裁剪(去掉不需要的层),到底该如何做呢?下面直接上代码:
import torch
from collections import OrderedDict
import os
import torch.nn as nn
def copyStateDict(state_dict):
    """Return a copy of *state_dict* with any leading ``module.`` prefix stripped.

    Checkpoints saved from a ``torch.nn.DataParallel``-wrapped model prefix
    every parameter name with ``module.``; removing that prefix lets the
    weights load into an unwrapped model.

    Args:
        state_dict: mapping of parameter name -> tensor (insertion-ordered).

    Returns:
        OrderedDict with the same values; keys de-prefixed when needed.
    """
    # Empty checkpoint: nothing to rename (original code raised IndexError here).
    if not state_dict:
        return OrderedDict()
    # Decide once, from the first key, whether names carry the wrapper prefix.
    first_key = next(iter(state_dict))
    start_idx = 1 if first_key.startswith('module') else 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict['.'.join(k.split('.')[start_idx:])] = v
    return new_state_dict
# ---- Prune the pretrained checkpoint: drop the 'arc' head, save the rest ----
model = MobileFaceNet1(512)

# map_location='cpu' makes the load work even on a machine without the GPU
# the checkpoint was saved from.
state_dict = torch.load('./model-180.pth', map_location='cpu')
new_dict = copyStateDict(state_dict)

# Filter out every layer whose key starts with 'arc' — these are the layers
# (the classification head) we do not want to keep for fine-tuning.
new_dict = {k: v for k, v in new_dict.items() if not k.startswith('arc')}
print(new_dict)

model.load_state_dict(new_dict)
torch.save(model.state_dict(), 'new-model.pth')