Pytorch如何保存和加载模型参数

pytorch 保存和加载模型的方法有两种:

1.保存网络的参数

import torch
#导入模块

net=Net()
#创建网络,当然还需要损失函数梯度等省略


PATH='state_dict_model.pth'
#先建立路径
torch.save(net.state_dict(),PATH)
#保存:可以是pth文件或者pt文件

model=Net()
model.load_state_dict(torch.load(PATH))
#载入保存的模型参数
model.eval()
#不启用 BatchNormalization 和 Dropout

2.保存整个网络

import torch

PATH = "entire_model.pt"
# Save
torch.save(net, PATH)

# Load
model = torch.load(PATH)
model.eval()

Remember too, that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值