python保存模型与参数_Pytorch - 模型和参数的保存与恢复

本文介绍了PyTorch中模型和参数的保存与恢复,包括最佳实践和不同场景的应用,如仅保存参数、保存整个模型、模型用于推断、恢复训练和分享。提供了代码示例,详细解释了如何在不同情况下加载和保存模型的状态字典、优化器状态以及检查点信息。
摘要由CSDN通过智能技术生成

模型训练后,需要保存到文件,以供测试和部署;或,继续之前的训练状态.

1. Best Practices

主要有两种模型序列化保存和加载恢复的方法.

1.1 方法 M1 - 推荐

只保存和加载恢复模型参数(model parameters):import torch

# 保存

torch.save(the_model.state_dict(), PATH)

# 恢复

the_model = TheModelClass(*args, **kwargs)

the_model.load_state_dict(torch.load(PATH))

# 该方法需要自己另导入模型的网络结构信息.

1.2 方法 M2

同时保存模型的参数和网络结构信息:import torch

# 保存

torch.save(the_model, PATH)

# 恢复

the_model = torch.load(PATH)

# 该方法保存的数据绑定着特定的 classes 和所用的确切目录结构. ‘

# 因此,再加载后经过许多重构后,可能会被打乱.

2. Stackoverflow 回答

根据应用场景,选择模型保存和加载恢复方法.

场景 C1 - 模型保存自用于推断

自己保存模型,自己恢复模型,然后,修改模型为 evaluation 模式.

这是因为,默认情况时,网络模型训练时往往有 BatchNorm 和 Dropout 网络层.# 模型保存

torch.save(model.state_dic

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值