load model

def load_weights(self, base_file):
        #pretrain_dict = model_zoo.load_url(model_url)
        print('Loading weights into state dict...')
        pretrain_dict = torch.load(base_file, map_location=torch.device('cpu'))
        model_dict = self.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and model_dict[k].size() == v.size()
        }
        model_dict.update(pretrain_dict)
        self.load_state_dict(model_dict)
        print('Finished!')

def load_weight(new_model, model_dir):
    #当前网络权重字典
    orginal_dict = new_model.state_dict()
    #读取的网络权重字典
    weight_dict = torch.load(model_dir, map_location=torch.device('cpu'))
    for key, value in orginal_dict.items():
        for key2,vlaue2 in weight_dict.items():
            if key==key2 and value.size() == vlaue2.size():
                print("key对应且形状相同!")
                orginal_dict[key] = weight_dict[key2]
    new_model.load_state_dict(orginal_dict)
    print('load model Finished!')

按key读取已有的模型参数,继续训练,在网络结构略作修改的时候可以使用

加载模型,简单复现

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值