【PyTorch框架】模型保存、加载与断点续训练

一、模型保存与加载 Saving & Loading Model

模型的保存与加载,也可以称之为序列化与反序列化。

1. 原因

训练好的模型是为了以后可以更方便的使用它,训练好的模型是被存储在内存当中的,而内存中数据一般不具有这种长久性的存储的功能,但硬盘可以长期的存储数据,所以在训练好模型之后,我们需要将模型从内存中转移到硬盘上进行长期存储。

2. 序列化与反序列化

序列化与反序列化主要描述的是内存与硬盘之间的一个转换关系,训练好的模型在内存中是以对象的形式存储的,在硬盘中则是以二进制序列的形式进行存储的。

  • 序列化(pickling):将模型对象转换成二进制的数据并存储在硬盘中的过程。
  • 反序列化(unpickling):将存储在硬盘中的二进制序列化数据再重新以模型对象的形式存储至内存中的过程。

二者相互对应的两个操作如下:
在这里插入图片描述

3. PyTorch序列化与反序列化

  • 序列化:
// 序列化
// 主要参数:obj-对象,f-输出路径
torch.save

说明:
对象:想要保存的数据,如模型、数据等;
输出路径:指定硬盘的路径。

  • 反序列化:
// 反序列化
// 主要参数:f-文件路径,map-location:指定存放位置,cpu or gpu
torch.load

说明:
map-location:GPU保存的模型时,不能直接load,需要设置map-location;CPU保存时直接load。

4. 模型保存

  • Module的数据结构
    在这里插入图片描述
    说明: Module中有8个有序字典去管理一系列参数,还有一些属性。保存模型的目的是下一次可以继续使用,模型训练之后得到的参数是一系列的可学习参数,而另一种方法就是只保存可学习的参数,即训练之后得到的一些参数,下一次构建模型后可以把保存的可学习参数加载进新模型中,这就完成了模型的保存与加载。
  • 保存整个Module
// 保存整个Module
torch.save(net,path)

优点: 保存整个net,不需要考虑该保存哪些参数
缺点: 占内存,耗时

  • 保存模型参数
// state_dict()保存模型中可学习参数,返回字典形式
state_dict = net.state_dict()
torch.save(state_dict,path)

说明: 官方推荐方法。

  • 实例演示:
    运行 model_save.py文件
# -*- coding: utf-8 -*-
"""
# @brief      : 模型的保存
"""
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools2 import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


net = LeNet2(classes=2019)

# "训练"
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])

# 设置保存整个模型和保存模型参数的路径
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

# 保存整个模型
torch.save(net, path_model)

# 保存模型参数,调用state_dict()方法获取模型参数
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)


运行后:
保存了两个文件,一个是保存整个模型;一个是保存模型中的可学习参数。
在这里插入图片描述

5. 模型加载

  • 加载整个模型
# 加载整个模型
path_model = "./model.pkl"
net_load = torch.load(path_model)

print(net_load)
  • 加载模型参数
# 加载模型参数
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)

print(state_dict_load.keys())

net_new = LeNet2(classes=2019)
net_new.load_state_dict(state_dict_load)
  • 实例演示:
    运行 model_load.py文件
# -*- coding: utf-8 -*-
"""
# @file name  : model_load.py
# @brief      : 模型的加载
"""
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed


class LeNet2(nn.Module):
    def __init__(self, classes):
        super(LeNet2, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

    def initialize(self):
        for p in self.parameters():
            p.data.fill_(20191104)


# ================================== load net ===========================
# flag = 1
flag = 0
if flag:
    # 读取路径,并加载保存的整个模型
    path_model = "./model.pkl"
    net_load = torch.load(path_model)

    print(net_load)

# ================================== load state_dict ===========================

flag = 1
# flag = 0
if flag:

    path_state_dict = "./model_state_dict.pkl"
    state_dict_load = torch.load(path_state_dict)

    print(state_dict_load.keys())

# ================================== update state_dict ===========================
flag = 1
# flag = 0
if flag:
    # 需要新建一个与保存的模型(参数)时结构一样的模型(LeNet2)
    net_new = LeNet2(classes=2019)
    # 获取的是模型参数字典的键值(如’features.0.weight’,‘features.0.bias’…等)
    print("加载前: ", net_new.features[0].weight[0, ...])
    # 加载state_dict_load,放到新模型里
    net_new.load_state_dict(state_dict_load)
    print("加载后: ", net_new.features[0].weight[0, ...])

运行后:

  • 加载模型:(可以单步运行打印出模型的结构。)
    在这里插入图片描述

  • 打印键值:
    在这里插入图片描述
    在这里插入图片描述
    说明:
    features.0、3-第一、二个卷积层的权值和偏置;
    classifier.0-全连接层的权值和偏置。

  • 打印“加载前”和“加载后”的模型参数:(可以Debug查看)在这里插入图片描述
    说明: 保存模型的卷积层的weight全部为20191104,证明模型成功保存并加载。

二、模型段点续训练

1. 原因

因为某种原因如断电、模型大等问题,导致模型训练意外终止。模型断电续训练能够保证模型训练中断之后可以接着(中断点)这个checkpoint继续训练,而不需要从头训练。因此,需要在模型训练过程中保存模型参数。
在这里插入图片描述

2. 模型保存的参数

训练过程中数据和损失函数是不变的,模型和优化器(动量优化器要利用之前的信息不断更新当前值)中的参数会随着迭代而不断变化。因此需要保存的参数有:模型中的参数、优化器中的参数、Epoch(或迭代次数)。

在这里插入图片描述

  • 保存参数的代码片段:
# 需要保存的模型参数
checkpoint = {
    "model_state_dict": net.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(), 
    "epoch": epoch
    path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
}

说明: checkpoint不要写在iteration里面,要写在epoch循环中。

  • 实例演示:
    运行 save_checkpoint.py文件
# -*- coding: utf-8 -*-
"""
# @file name  : save_checkpoint.py
# @brief      : 模拟训练意外停止
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
import torchvision


set_seed(1)  # 设置随机种子
rmb_label = {"1": 0, "100": 1}

# 参数设置
checkpoint_interval = 5  #  每隔5个EPOCH保存一下
MAX_EPOCH = 10  #总共10次个EPOCH
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1


# ============================ step 1/5 数据 ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) #shuffle=False表示数据没有打乱,建议为True
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数

# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)     # 设置学习率下降策略

# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

start_epoch = -1
for epoch in range(start_epoch+1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率

    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
    # 如果epoch在第5个之后中断了
    if epoch > 5:
        print("训练意外中断...")
        break

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val/len(valid_loader), correct / total))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

运行后:

  • 中断在第6个epoch,存储的参数在第5个epoch:
    解释: epoch从0开始,所以第5个epoch=checkpoint_4_epoch.pkl;中断在第6个epoch=5,所以续训练的开始epoch=5。
    在这里插入图片描述
  • 第5个epoch的保存文件:
    在这里插入图片描述

3. 断点续训练

在续训练的代码中,构建的模型、数据、损失函数、优化器都是一样的,只需要把训练好的数据加载到对应位置。

  • 保存参数的代码片段:
# 断点恢复
path_checkpoint = "./checkpoint_4_epoch.pkl"#恢复的文件路径
checkpoint = torch.load(path_checkpoint)#load文件

net.load_state_dict(checkpoint['model_state_dict'])#恢复模型参数

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])#恢复优化器参数

start_epoch = checkpoint['epoch']#设置要恢复的epoch

scheduler.last_epoch = start_epoch#设置学习率
  • 实例演示:
    运行 save_resume.py文件
# -*- coding: utf-8 -*-
"""
# @file name  : checkpoint_resume.py
# @brief      : 模拟训练意外停止
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed
import torchvision


set_seed(1)  # 设置随机种子
rmb_label = {"1": 0, "100": 1}

# 参数设置
checkpoint_interval = 5
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1


# ============================ step 1/5 数据 ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) #shuffle=False表示数据没有打乱,建议为True
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数

# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)     # 设置学习率下降策略


# ============================ step 5+/5 断点恢复 ============================

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']

scheduler.last_epoch = start_epoch

# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率

    if (epoch+1) % checkpoint_interval == 0:

        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dic": optimizer.state_dict(),
                      "loss": loss,
                      "epoch": epoch}
        path_checkpoint = "./checkpint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    # if epoch > 5:
    #     print("训练意外中断...")
    #     break

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss.item())
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val/len(valid_loader), correct / total))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

运行后:

  • 加载保存的参数文件checkpoint_4_epoch.pkl,从第6个epoch开始训练:
    解释: epoch从0开始,所以第5个epoch=checkpoint_4_epoch.pkl。
    在这里插入图片描述
    在这里插入图片描述
  • loss曲线的对比:
    在这里插入图片描述在这里插入图片描述
    说明: 断点前的数据加载器DataLoder中的SHUFFLE=False,数据没有打乱,影响了优化,导致损失函数偏高在3.5左右。
    建议: 在save_checkpoint.py和checkpoint_resume.py中的DataLoder都改为SHUFFLE=True。

三、参考

[1]【深度之眼】【Pytorch打卡第15天】:模型保存与加载
[2] Pytorch系列之——模型保存与加载、finetune
[3] [二十五]深度学习Pytorch-模型保存与加载、断点续训练
[4]【深度之眼】Pytorch框架班第五期-模型保存与加载代码解析
[5] 07-01-模型保存与加载

  • 2
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值