PyTorch强化:05.PyTorch 保存和加载模型

本文详细介绍了PyTorch中模型的保存和加载,包括state_dict的概念,如何保存和加载推理模型,保存和加载完整模型,以及在不同设备间保存和加载。重点讲述了state_dict的使用,强调了在加载模型后调用model.eval()进行推理的重要性。
摘要由CSDN通过智能技术生成

 

当保存和加载模型时,需要熟悉三个核心功能:

  1. torch.save:将序列化对象保存到磁盘。此函数使用Python的pickle模块进行序列化。使用此函数可以保存如模型、tensor、字典等各种对象。
  2. torch.load:使用pickle的unpickling功能将pickle对象文件反序列化到内存。此功能还可以有助于设备加载数据。
  3. torch.nn.Module.load_state_dict:使用反序列化函数 state_dict 来加载模型的参数字典。

1.什么是状态字典:state_dict?

在PyTorch中,torch.nn.Module模型的可学习参数(即权重和偏差)包含在模型的参数中,(使用model.parameters()可以进行访问)。state_dict是Python字典对象,它将每一层映射到其参数张量。注意,只有具有可学习参数的层(如卷积层,线性层等)的模型 才具有state_dict这一项。目标优化torch.optim也有state_dict属性,它包含有关优化器的状态信息,以及使用的超参数。

因为state_dict的对象是Python字典,所以它们可以很容易的保存、更新、修改和恢复,为PyTorch模型和优化器添加了大量模块。

下面通过从简单模型训练一个分类器中来了解一下state_dict的使用。

# 定义模型
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
 self.conv1 = nn.Conv2d(3, 6, 5)
 self.pool = nn.MaxPool2d(2, 2)
 self.conv2 = nn.Conv2d(6, 16, 5)
 self.fc1 = nn.Linear(16 * 5 * 5, 120)
 self.fc2 = nn.Linear(120, 84)
 self.fc3 = nn.Linear(84, 10)
def forward(self, x):
 x = self.pool(F.relu(self.conv1(x)))
 x = self.pool(F.relu(self.conv2(x)))
 x = x.view(-1, 16 * 5 * 5)
 x = F.relu(self.fc1(x))
 x = F.relu(self.fc2(x))
 x = self.fc3(x)
return x
# 初始化模型
model = TheModelClass()
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 打印模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 打印优化器的状态字典
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值