加载预训练参数的四种方法

一、


model_weight_path = ""
assert os.path.exists(model_weight_path), "file {} dose not exist.".format(model_weight_path)
pre_weights = torch.load(model_weight_path, map_location=device)

 # delete classifier weights
pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
net.load_state_dict(pre_dict, strict=False)

二、


weights_dict = torch.load("path",map_location=device)
# 删除有关分类类别的权重
for k in list(weights_dict.keys()):
   if "fc" in k:
      del weights_dict[k]
net.load_state_dict(weights_dict, strict=False)

三、


#加载model,mymodel是自己定义好的模型
    pretrainedmodel = resnet18(pretrained=True) 
    # mymodel =Net(...) 
    
    #读取参数 
    pretrained_dict = pretrainedmodel.state_dict() 
    model_dict = net.state_dict() 
    
    #将pretrained_dict里不属于model_dict的键剔除掉 
    pretrained_dict =  {k: v for k, v in pretrained_dict.items() if k in model_dict} 
    
    # 更新现有的model_dict 
    model_dict.update(pretrained_dict) 
    
    # 加载我们真正需要的state_dict 
    net.load_state_dict(model_dict)

四、


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if kwargs['num_classes'] != 1000 and pretrained:
        model = ResNet(BasicBlock, [2, 2, 2, 2])
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
        model.fc = nn.Linear(model.fc.in_features, kwargs['num_classes'])
        print('done load model')
    else:
        model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

    return model
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值