torch 中 模型的保存和恢复

保存和加载模型

'''https://pytorch-cn.readthedocs.io/zh/latest/notes/serialization/'''
import torch
import torchvision
model = torchvision.models.resnet18(pretrained=True)
"""---------------保存和加载模型参数(推荐)----------------------------"""
torch.save(model.state_dict(),"temp_file/resnet18")

model.state_dict() :
保存模型的状态字典
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.
Example:
>>> module.state_dict().keys()

model.state_dict()['conv1.weight'].shape
torch.Size([64, 3, 7, 7])
"""加载模型  和对应的参数"""
the_model = torchvision.models.resnet18()
the_model.load_state_dict(torch.load("temp_file/resnet18"))
<All keys matched successfully>

“”“保存和加载整个模型(不推荐)”“”

torch.save(the_model, PATH)
the_model = torch.load(PATH)
ps :在这种情况下,序列化的数据被绑定到特定的类和固定的目录结构,所以当在其他项目中使用时,或者在一些严重的重构器之后它可能会以各种方式break

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值