The original resnet152 function in torchvision's resnet.py is:
def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
Modify it as follows, commenting out the two lines that download the weights:
def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): no longer used; the download below is disabled
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    # download disabled: the weights are loaded from a local file instead
    #if pretrained:
    #    model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
Then load the pretrained weights from a locally saved file with the following code:
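import torch

# pretrained no longer triggers a download, so it can be omitted
model = resnet152()
model_dict = model.state_dict()
# weight file downloaded in advance (the file behind model_urls['resnet152'])
state_dict = torch.load('resnet152-b121ed2d.pth', map_location='cpu')
# keep only parameters whose names and shapes match the current model
new_state_dict = {k: v for k, v in state_dict.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(new_state_dict)
model.load_state_dict(model_dict)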
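The same filtering can be wrapped in a small reusable function. The sketch below is hypothetical (`load_matching_weights` is not a PyTorch API) and uses two toy `nn.Sequential` models instead of ResNet-152 so it runs without the weight file; their last layers differ in output size, as they would after changing `num_classes`.

```python
import torch
import torch.nn as nn


def load_matching_weights(model: nn.Module, state_dict: dict) -> list:
    """Hypothetical helper: copy tensors whose key and shape match, return skipped keys."""
    model_dict = model.state_dict()
    matched = {k: v for k, v in state_dict.items()
               if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    # keys from the checkpoint that were not used
    return [k for k in state_dict if k not in matched]


# Toy stand-ins: a "pretrained" model with 10 classes and a target with 3
source = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10))
target = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

skipped = load_matching_weights(target, source.state_dict())
print(skipped)  # ['2.weight', '2.bias']
```

Returning the skipped keys makes it easy to confirm that only the intended layers (here the classification head) keep their random initialization.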
Alternatively, try this directly:
model.load_state_dict(torch.load("model.th"), strict=False)
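That is, set strict to False: keys missing from, or unexpected in, the checkpoint are then ignored instead of raising an error, and load_state_dict returns them as missing_keys and unexpected_keys. Tensors whose shapes do not match still raise a RuntimeError.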
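A minimal sketch of what `strict=False` does and does not cover, using a toy model (`Net` and the extra `aux.weight` key are made up for illustration): missing and unexpected keys are reported through the returned object, while a shape mismatch still raises.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.features = nn.Linear(4, 8)
        self.fc = nn.Linear(8, num_classes)


# Toy checkpoint: backbone weights only, plus a key the model does not have
checkpoint = {k: v for k, v in Net(10).state_dict().items() if k.startswith("features.")}
checkpoint["aux.weight"] = torch.zeros(2, 2)

model = Net(3)
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['fc.weight', 'fc.bias']
print(result.unexpected_keys)  # ['aux.weight']

# strict=False does not cover shape mismatches (fc is 10x8 vs 3x8 here)
try:
    model.load_state_dict(Net(10).state_dict(), strict=False)
    raised = False
except RuntimeError as err:
    raised = "size mismatch" in str(err)
print(raised)  # True
```

For a classifier with a different number of classes, combining `strict=False` with the shape filter shown earlier avoids the size-mismatch error.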