Pytorch模型保存与加载

本文详细介绍了PyTorch中模型的序列化与反序列化,包括torch.save和torch.load的使用,以及模型的保存与加载方法。重点讲解了如何进行断点续训,需要保存模型状态字典、优化器状态字典和当前训练轮数,以便在训练中断后恢复。此外,还简要提到了模型构建的基本步骤。
摘要由CSDN通过智能技术生成

序列化与反序列化

模型训练时,模型是在内存中的,而内存中的数据不具备长久性的存储功能,因此需要把模型从内存中搬到硬盘中进行长久的存储。
序列化与反序列化主要指内存与硬盘之间的数据转换关系,模型在内存中是以一个对象形式存储的,但是在内存当中对象不能长久的保存,因此需要保存到硬盘中。而在硬盘中,是以二进制的形式进行保存的,即二进制序列。
因此序列化是指将内存中的某一个对象保存到硬盘中,以二进制序列的形式存储。对应于Pytorch中的模型,可以理解为将模型转换成二进制的数存储到硬盘中进行长久的存储。
在这里插入图片描述
反序列化是指将我们存储的二进制数再反序列化的放到内存中,得到一个对象,这样就可以使用模型了。
在这里插入图片描述
序列化与反序列化的目的是将数据/模型可以长久的保存。

Pytorch中的序列化与反序列化

torch.save

主要参数:

  • obj:对象
  • f:输出路径

torch.load

主要参数:

  • f:文件路径
  • map_location:指定存放位置,cpu or gpu

模型保存与加载的两种方式

保存

方法一:保存整个Module

torch.save(net, path)

方法二:保存模型参数

state_dict = net.state_dict()  #字典类型
torch.save(state_dict, path)

state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用;
load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用。
【说明】 state_dict是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项。
如何仅加载某一层的训练到的参数(某一层的state)?

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load("./model_state_dict.pt")['conv1.weight']

保存模型是为了下次继续使用模型,什么东西是在模型训练之后才得到的?即一系列的可学习参数。
将这些模型训练得到的参数保存下来,下一次构建模型时,再把这些参数放回到模型中,完成模型的保存与加载。

举例

# LeNet2为一个模型

# 模型初始化
net = LeNet2(classes=2019)
#模拟训练
print("训练前:", net.features[0].weight[0, ...]) #第一个卷积核的第一个参数
net.initialize() #模拟模型中的参数改变了
print("训练后:", net.features[0].weight[0, ...])

path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# 保存整个模型
torch.save(net, path_model)

#保存模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

# 模型加载---load net
path_model = "./model.pkl"
net_load = torch.load(path_model)

# 模型加载---load state_dict
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)
# state_dict_load为字典类型
# state_dict_load.keys()为参数的名称吧
print(state_dict_load.keys()) # odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', ……])
#加载整个state_dict, 还需要将它放到一个模型中,这样才算完成一个模型的重新加载,所以通常需要再重新构建一个模型,这个模型里面的参数可以不用管,可以通过load_state_dict()这个方法将加载进来的state_dict_load字典放到新的网络中,这样这个网络就与之前保存下来的网络一样了
net_new = LeNet2(classes = 2019)
print("加载前:", net_new.features[0].weight[0,...])
net_new.load_state_dict(state_dict_load)
print("加载后:", net_new.features[0].weight[0,...])

模型断点续训练

在这里插入图片描述
断点续训练需要保存哪些数据呢?
训练的过程只需要四个模块的东西:数据、模型、损失函数、优化器。那这四个东西当中,哪些东西会随着模型的迭代训练而变化呢?其实,只有模型和优化器随着迭代不断的变化,数据是不变的,而损失函数只是一个函数,里面也没有可变的参数。而模型当中的权值和可学习参数是会变化的,以及优化器当中也有数据会变化,例如优化器当中会有一些buffer缓存会发生变化。

checkpoint = {
			"model_state_dict": net.state_dict(),
			"optimizer_state_dict": optimizer.state_dict(),
			"epoch": epoch
}

模拟意外中断,然后进行续训练:

set_seed(1) # 设置随机种子
rmb_label = {"1":0, "100":1}

#参数设置
checkpoint_interval = 5 #每隔5个epoch就保存一下
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

#中间一些代码省略

if (epoch+1) % checkpoint_interval == 0:
	checkpoint = {
				"model_state_dict": net.state_dict(),
				"optimizer_state_dict":optimizer.state_dict(),
				"epoch":epoch}
	path_checkpoint = './checkpoint_{}_epoch.pkl'.format(epoch)
	torch.save(checkpoint, path_checkpoint)

if epoch>5:
	print('训练意外中断……')
	break


续训练,前面四个步骤(数据、模型、损失函数、优化器)都是一样的,不需要改动,只需要构建好就行。主要是在训练的时候,把这些已经训练好的数据再加载到所对应的地方,需要进行的操作如下:

#优化器(与之前一样)
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9) #选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1) #设置学习率下降策略

#断点恢复
path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)
net.load_state_dict(checkpoint['model_state_dict']) #模型数据更新
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) #优化器数据更新
start_epoch = checkpoint['epoch'] #设置起始epoch
scheduler.last_epoch = start_epoch #需要注意!学习率也需要修改last_epoch。last_epoch是指上一个迭代的次数

以上就完成了断点恢复,下面就可以接着训练了。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

番外篇

网络模型创建步骤

机器学习模型训练步骤:
在这里插入图片描述

构建网络层:例如卷积神经网络中的卷积层、池化层、全连接层等。
(首先要构建这些子模块网络层,将这些网络层构建完成后,再把它们按照一定的顺序、一定的拓扑结构进行拼接,拼接完成后,就构成了复杂的神经网络)
在这里插入图片描述
【举例】
在这里插入图片描述
在这里插入图片描述
模型构建两要素:

  • 构建子模块 ( init() )
  • 拼接子模块 ( forward() )
class LeNet(nn.Module):
	def __init__(self, classes):
		super(LeNet, self).__init__()
		self.conv1 = nn.Conv2d(3, 6, 5)
		self.conv2 = nn.Conv2d(6, 16, 5)
		self.fc1 = nn.Linear(16*5*5, 120)
		self.fc2 = nn.Linear(120, 84)
		self.fc3 = nn.Linear(84, classes)
	
	def forward(self, x):
		out = F.relu(self.conv1(x))
		out = F.max_pool2d(out, 2)
		out = F.relu(self.conv2(out))
		out = F.max_pool2d(out, 2)

nn.Module属性

所有的模型以及所有的网络层都继承nn.Module类。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值