pytorch入门——模型保存与加载

转载自这位小姐姐:https://www.jianshu.com/p/55a5e38a6dd2

模型保存与加载

利用PyTorch可以很方便的进行模型的保存和加载,主要有以下两种方式。

方法1:保存加载整个模型

# save model
torch.save(model,'mymodel.pkl')
# load model
model=torch.load('mymodel.pkl')

方法2:仅保存加载模型参数(推荐)

# save model parameters
torch.save(model.state_dict(), 'mymodel.pkl')
# load save model parameters
model_object.load_state_dict(torch.load('mymodel.pkl'))

可以想一下为什么会推荐第二种方式呢?
相比较于保存整个模型而言,仅保存模型参数的做法应该不仅节省空间,更有灵活性的优势。
可以取出特定层的参数,这一点在已经训练好的模型上取与现有模型相同层的参数上应该有帮助。


方法3:加载别的模型中相同的网络参数至新的模型

这个方法在科研上还是很有帮助的,可以用已经训练好的网络参数作为自己模型的网络权重的初始化。
先给出函数代码:

def transfer_weights(model_from, model_to):
    wf = copy.deepcopy(model_from.state_dict())
    wt = model_to.state_dict()
    for k in wt.keys() :
        if (not k in wf)):      
            wf[k] = wt[k]
    model_to.load_state_dict(wf)

以上就实现了从model_frommodel to的相同网络参数的拷贝。

来分析一下程序:

  • wf实现了对 model from中的模型参数的深度拷贝;
  • wt实现了对 model to模型参数的获取;
  • 下面那段for循环就是实现了如果在model to中出现的网络结构,但是在model from中没有出现,那么就拷贝一份给wf。这样做的目的是让wf扩充后的结构跟wt一样,即保留了model from中的模型参数,又将结构扩充到跟 model to的一样。
  • 这样最后一条语句就直接可以通过load_state_dict函数加载我们想要的模型参数到目标模型model to中了。

补充说明:

以上的函数要求两个模型中如果具有相同的名字,那么对应的参数带下应该是一样的。

那么如果出现模型结构名字一样,但是参数大小不一样的情况呢?

例如两个都有fc层,那么要求fc层的参数是一样的。我自己在做的时候刚好就是这种情况。
除了最后一个fc层,其他的结构都是一样的,model_form的fc是[2048,400],而model_to的fc是[2048,101]
所以就会出现如下的错误:
RuntimeError: While copying the parameter named fc.weight, whose dimensions in the model are torch.Size([101, 2048]) and whose dimensions in the checkpoint are torch.Size([400, 2048]).
就是参数维度不匹配啦。
这个时候的做法是让model_from中的该层维度跟model_to一样,在代码上的体现就是:

def transfer_weights(model_from, model_to):
    wf = copy.deepcopy(model_from.state_dict())
    wt = model_to.state_dict()
    for k in wt.keys() :
        #if (not k in wf)):     
        if ((not k in wf) | (k=='fc.weight') | (k=='fc.bias')):     
            wf[k] = wt[k]
    model_to.load_state_dict(wf)

 



作者:与阳光共进早餐
链接:https://www.jianshu.com/p/55a5e38a6dd2
來源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值