这里参考了一种加载部分模型权重的方法：从预训练的 ResNet-50 中取出与当前网络键名相同的参数，再加载到当前网络中。
def initialize_weights(self):
    """Initialize ``self.resnet`` from a pretrained ResNet-50.

    Builds ``models.resnet50(pretrained=True)``, selects from its state dict
    the entry for every key of ``self.resnet.state_dict()``, and loads the
    selection into ``self.resnet``. Keys must match the pretrained names
    exactly; no renaming is attempted.

    Raises:
        AssertionError: If any key of ``self.resnet.state_dict()`` is absent
            from the pretrained state dict. The exception type matches the
            ``assert`` this check replaces, but the check now also runs
            under ``python -O`` and reports the offending keys.
    """
    # NOTE(review): `models` is presumably `torchvision.models`; newer
    # torchvision releases prefer the `weights=` argument over
    # `pretrained=True` -- confirm against the pinned version before changing.
    pretrained_state = models.resnet50(pretrained=True).state_dict()

    # Build the target key list once instead of re-creating the state dict
    # for both the selection and the completeness check.
    own_keys = list(self.resnet.state_dict())

    missing = [key for key in own_keys if key not in pretrained_state]
    if missing:
        # Fail before touching the model so a partial match never leaves
        # `self.resnet` silently half-initialized.
        raise AssertionError(
            'pretrained resnet50 has no weights for {} of {} keys: {}'.format(
                len(missing), len(own_keys), ', '.join(missing)
            )
        )

    # Preserve the ordering of the target network's own state dict.
    matched = {key: pretrained_state[key] for key in own_keys}
    self.resnet.load_state_dict(matched)
    print('[INFO] initialize weights from resnet50')