(Pytorch)模型的保存与加载

本文介绍了如何保存和加载训练好的深度学习模型。方法一和二分别展示了完整模型和参数的保存与加载,而方法三则展示了如何将已训练模型的参数转移到新模型中作为初始化权重。这些技巧对于模型的复用和迁移学习至关重要。
摘要由CSDN通过智能技术生成

对于一个已经训练完毕的网络,保存模型可以方便后续直接使用。

再使用时,通过加载方式即可。

方法一、保存、加载整个模型

#mymodel.pkl是生成的文件名
#save
torch.save(model.dtate_dict(),'mymodel.pkl')  #
#load
model=torch.load('mymodel.pkl')    #

方法二、保存、加载模型的参数♥

只保存模型的参数,节省空间,在后续加载时可以去除特定层的参数

#save 
torch.save(model.state_dict(),'mymodel.pkl')
#load
model_object.load_state_dict(torch.load('mymodel.pkl'))   #model_object是新文件中新定义的模型名

例如:

方法三、把 别的模型中相同的网络参数 加载到 新的模型中

可以用已经训练好的网络参数作为自己模型的网络权重的初始化

如下,实现了从model_frommodel to的相同网络参数的拷贝:

def transfer_weights(model_from, model_to):
    wf = copy.deepcopy(model_from.state_dict())  #对model_from 中的模型参数的深层拷贝
    wt = model_to.state_dict()  #获取model_to模型参数
    #for循环的目的是让wf扩充后的结构跟wt一样,即保留model_from的模型参数,又将结构扩充到和model_to的一样。
    for k in wt.keys() :  #若在model_to中出现的网络结构,但在model_from中没有出现,则拷贝一份给wf.
        if (not k in wf)):      
            wf[k] = wt[k]
    model_to.load_state_dict(wf)  #load_state_dict函数加载想要的模型参数到目标模型model_to中

 

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值