As models get larger and larger, training one from scratch in a single run becomes increasingly impractical on low-compute platforms. It is therefore well worth saving the model during training so that a later run can pick up where the previous one left off, saving time. The code is as follows:
Import packages
import os  # used later to create the checkpoint directory
import torch
from torch import nn
import numpy as np
Define the model
Define a small MLP classification model (two linear layers with a ReLU in between):
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)   # 64 input features -> 32 hidden units
        self.linear1 = nn.Linear(32, 10)  # 32 hidden units -> 10 classes
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear1(x)
        return x
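A quick sanity check of the forward pass (not in the original post): a batch of four random 64-dimensional inputs should come out as four rows of 10 class logits.
m = MyModel()
print(m(torch.rand(4, 64)).shape)  # torch.Size([4, 10])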
## Randomly generate two groups of labeled data
rand1 = torch.rand((100, 64)).to(torch.float)
label1 = np.random.randint(0, 10, size=100)
label1 = torch.from_numpy(label1).to(torch.long)
rand2 = torch.rand((100, 64)).to(torch.float)
label2 = np.random.randint(0, 10, size=100)
label2 = torch.from_numpy(label2).to(torch.long)
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss = nn.CrossEntropyLoss()
## Train for 10 epochs
epoch = 10
for i in range(epoch):
    output = model(rand1)
    my_loss = loss(output, label1)
    optimizer.zero_grad()   # clear stale gradients
    my_loss.backward()      # backpropagate
    optimizer.step()        # update parameters
    print("epoch:{} loss:{}".format(i, my_loss.item()))
The output is below. Note these loss values so they can be compared with the initial loss when training resumes later:
epoch:0 loss:2.3494179248809814
epoch:1 loss:2.287858009338379
epoch:2 loss:2.2486231327056885
epoch:3 loss:2.2189149856567383
epoch:4 loss:2.193182945251465
epoch:5 loss:2.167125940322876
epoch:6 loss:2.140075206756592
epoch:7 loss:2.1100614070892334
epoch:8 loss:2.0764594078063965
epoch:9 loss:2.0402779579162598
Saving the model
Models are saved with the torch.save function, generally in one of two modes: the simple one saves the entire model object, while the second saves the state of each component (model, optimizer, epoch) into a dictionary.
# Save the entire model object
save_path = r'model_para/'
os.makedirs(save_path, exist_ok=True)  # torch.save does not create missing directories
torch.save(model, save_path + 'model_full.pth')
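For this first mode, the whole module is serialized, so loading it back gives a ready-to-use model (the MyModel class definition must still be importable). A minimal sketch, not shown in the original post:
full_model = torch.load(save_path + 'model_full.pth')  # returns a MyModel instance
full_model.eval()  # switch to inference mode if it is only used for prediction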
Save the model parameters, optimizer state, and epoch count:
def save_model(save_path, epoch, optimizer, model):
    # Bundle the epoch counter and both state dicts into a single checkpoint file
    torch.save({'epoch': epoch + 1,
                'optimizer_dict': optimizer.state_dict(),
                'model_dict': model.state_dict()},
               save_path)
    print("model save success")

save_model(save_path + 'model_dict.pth', epoch, optimizer, model)
Loading the model
A saved .pth checkpoint file is loaded with torch.load, as follows:
def load_model(save_name, optimizer, model):
    model_data = torch.load(save_name)
    model.load_state_dict(model_data['model_dict'])
    optimizer.load_state_dict(model_data['optimizer_dict'])
    print("model load success")
    return model_data['epoch']  # hand back the stored epoch counter for resuming
Inspect the weight parameters of the model we just trained:
print(model.state_dict()['linear.weight'])
tensor([[-0.0215, 0.0299, -0.0255, ..., -0.0997, -0.0899, 0.0499],
[-0.0113, -0.0974, 0.1020, ..., 0.0874, -0.0744, 0.0801],
[ 0.0471, 0.1373, 0.0069, ..., -0.0573, -0.0199, -0.0654],
...,
[ 0.0693, 0.1900, 0.0013, ..., -0.0348, 0.1541, 0.1372],
[ 0.1672, -0.0086, 0.0189, ..., 0.0926, 0.1545, 0.0934],
[-0.0773, 0.0645, -0.1544, ..., -0.1130, 0.0213, -0.0613]])
Create a new model instance, load the previously saved checkpoint into it, and print the same layer's parameters:
new_model = MyModel()
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01)
load_model(save_path+'model_dict.pth', new_optimizer, new_model)
print(new_model.state_dict()['linear.weight'])
The new model's parameters are identical to the current model's, which shows the parameters were loaded successfully:
model load success
tensor([[-0.0215, 0.0299, -0.0255, ..., -0.0997, -0.0899, 0.0499],
[-0.0113, -0.0974, 0.1020, ..., 0.0874, -0.0744, 0.0801],
[ 0.0471, 0.1373, 0.0069, ..., -0.0573, -0.0199, -0.0654],
...,
[ 0.0693, 0.1900, 0.0013, ..., -0.0348, 0.1541, 0.1372],
[ 0.1672, -0.0086, 0.0189, ..., 0.0926, 0.1545, 0.0934],
[-0.0773, 0.0645, -0.1544, ..., -0.1130, 0.0213, -0.0613]])
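Instead of comparing printed tensors by eye, the check can also be done programmatically; a quick sketch, not part of the original post:
# Verify that every parameter tensor matches between the two models
for name, param in model.state_dict().items():
    assert torch.allclose(param, new_model.state_dict()[name]), name
print("all parameters match")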
Resume training
With the saved parameters loaded into the new model, continue training and watch the loss: it keeps falling from the final loss of the previous run, which shows that training resumed successfully.
epoch = 10
for i in range(epoch):
    output = new_model(rand1)
    my_loss = loss(output, label1)
    new_optimizer.zero_grad()
    my_loss.backward()
    new_optimizer.step()
    print("epoch:{} loss:{}".format(i, my_loss.item()))
epoch:0 loss:2.0036799907684326
epoch:1 loss:1.965193271636963
epoch:2 loss:1.924098253250122
epoch:3 loss:1.881495714187622
epoch:4 loss:1.835693359375
epoch:5 loss:1.7865667343139648
epoch:6 loss:1.7352293729782104
epoch:7 loss:1.6832704544067383
epoch:8 loss:1.6308385133743286
epoch:9 loss:1.5763107538223267
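The checkpoint also stores the epoch counter ('epoch'), so a resumed run can continue the numbering instead of restarting at 0. A minimal sketch using the value returned by load_model above (an addition, not in the original post):
start_epoch = load_model(save_path + 'model_dict.pth', new_optimizer, new_model)
for i in range(start_epoch, start_epoch + 10):  # printed epochs now continue from 10
    output = new_model(rand1)
    my_loss = loss(output, label1)
    new_optimizer.zero_grad()
    my_loss.backward()
    new_optimizer.step()
    print("epoch:{} loss:{}".format(i, my_loss.item()))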
Problems caused by mismatched data distributions
I also noticed a problem here. Two groups of data were generated at random earlier, and the previous model was trained on rand1, so its saved parameters only help if we keep training on rand1. If we resume on rand2 instead, the model effectively starts from scratch (see the losses below). This is because both groups were generated randomly, so their input-label distributions have almost nothing in common, and a model trained on the first group is nearly useless on the second.
epoch:0 loss:2.523787498474121
epoch:1 loss:2.469816207885742
epoch:2 loss:2.4141526222229004
epoch:3 loss:2.379054069519043
epoch:4 loss:2.3563807010650635
epoch:5 loss:2.319946765899658
epoch:6 loss:2.271805763244629
epoch:7 loss:2.2274367809295654
epoch:8 loss:2.186885118484497
epoch:9 loss:2.144239902496338
In real training, however, the batches are assumed to come from the same distribution, so this is not normally a concern; see the sketch below.
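For completeness, a minimal sketch of the usual setup, in which mini-batches are drawn from one fixed dataset so every batch shares the same distribution (standard TensorDataset/DataLoader usage, not from the original post):
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(rand1, label1)                     # one fixed dataset
loader = DataLoader(dataset, batch_size=20, shuffle=True)  # batches share its distribution
for x, y in loader:
    output = new_model(x)
    my_loss = loss(output, y)
    new_optimizer.zero_grad()
    my_loss.backward()
    new_optimizer.step()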
That completes the three important steps of saving, loading, and resuming training of a PyTorch model. I hope this helps!
Happy training!