# 1. Load a checkpoint, keeping only entries whose size matches net's own parameters.
# Path to the checkpoint file (fill in before running).
model_weight_path = ""
# Raise explicitly rather than `assert`, which is stripped under `python -O`.
if not os.path.exists(model_weight_path):
    raise FileNotFoundError("file {} does not exist.".format(model_weight_path))
pre_weights = torch.load(model_weight_path, map_location=device)
# Keep only checkpoint entries that exist in `net` with an identical shape.
# Anything else (typically the classifier head when the number of classes
# differs) is dropped so that load_state_dict(strict=False) can succeed:
# strict=False tolerates missing/unexpected keys but still raises on a size
# mismatch, which is why full shapes (not just element counts) are compared.
model_state = net.state_dict()  # build once instead of once per key
pre_dict = {
    k: v
    for k, v in pre_weights.items()
    if k in model_state and model_state[k].shape == v.shape
}
net.load_state_dict(pre_dict, strict=False)
# 2. Load a checkpoint after deleting the classifier ("fc") weights.
weights_dict = torch.load("path", map_location=device)
# Remove the classifier weights (every key whose name contains "fc") so the
# remaining entries can be loaded into a model with a different head.
# Deletion is done in place on the loaded mapping rather than building a new
# dict, so the object returned by torch.load is the one passed on.
# NOTE(review): this is a plain substring test; any other parameter whose
# name happens to contain "fc" is removed as well.
head_keys = [name for name in weights_dict if "fc" in name]
for name in head_keys:
    weights_dict.pop(name)
net.load_state_dict(weights_dict, strict=False)
# 3. Copy the matching weights of a pretrained resnet18 into net.
# Build a pretrained resnet18 as the weight donor. `net` is your own model,
# constructed beforehand (e.g. `net = Net(...)`).
pretrainedmodel = resnet18(pretrained=True)
# Read both parameter dictionaries.
pretrained_dict = pretrainedmodel.state_dict()
model_dict = net.state_dict()
# Keep only entries that `net` also has, with the same shape. Without the
# shape check, a layer whose size differs (e.g. a classifier head built for
# another number of classes) would make load_state_dict raise a
# size-mismatch error.
pretrained_dict = {
    k: v
    for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
# Overlay the transferable weights on net's current state ...
model_dict.update(pretrained_dict)
# ... and load the merged dict. Its keys and shapes are exactly net's own,
# so the (default, strict) load succeeds.
net.load_state_dict(model_dict)
# 4. resnet18 factory that swaps in a new fc layer when num_classes != 1000.
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 model.

    Args:
        pretrained (bool): If True, load the weights from
            ``model_urls['resnet18']`` (ImageNet-pretrained, 1000-way ``fc``).
        **kwargs: Forwarded to ``ResNet``. When ``pretrained`` is True,
            ``num_classes`` (treated as 1000 if omitted) is handled here: the
            checkpoint is loaded into the default network, and the ``fc``
            layer is then replaced by a fresh ``nn.Linear`` if
            ``num_classes != 1000``.

    Returns:
        The constructed ``ResNet`` model.
    """
    if not pretrained:
        return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

    # The checkpoint's fc layer has 1000 outputs, so the network is built
    # without `num_classes` (as the original pretrained path did), the weights
    # are loaded, and the head is swapped afterwards. Other kwargs are still
    # forwarded instead of being silently ignored.
    num_classes = kwargs.pop('num_classes', 1000)
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    if num_classes != 1000:
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    print('done load model')
    return model