捕获神经网络训练过程中的Checkpoints方法

有时候自己写的网络模型在训练时想看某个epoch的模型参数,或者想按照某一epoch的模型参数进行测试,就需要看log相关的checkpoints。

方法如下:

1.要捕获训练过程中的checkpoints权重信息,可以使用Python中的回调函数。

在 PyTorch 中,可以使用 torch.save() 函数来保存模型的状态字典。可以使用一个自定义的 checkpoint 函数来在每个 epoch 结束时保存模型的状态字典。

import torch
import os

def checkpoint(model, epoch, optimizer, loss, checkpoint_dir):
    """
    Saves a checkpoint of the model at a given epoch
    """
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'loss': loss
    }
    filename = os.path.join(checkpoint_dir, f'checkpoint-epoch{epoch}.pt')
    torch.save(state, filename)

# 训练过程中,在每个 epoch 结束时调用 checkpoint 函数保存模型状态字典
for epoch in range(30):
    # 训练过程中的代码
    # ...
    if epoch > 20:
        checkpoint(model, epoch, optimizer, loss, checkpoint_dir)

使用 torch.load() 函数可以从保存的状态字典中恢复模型的权重。示例代码如下:

# 加载某个 epoch 的模型
filename = os.path.join(checkpoint_dir, 'checkpoint-epoch25.pt')
state = torch.load(filename)
model.load_state_dict(state['state_dict'])

2.例子

#写一个新的MLP网络,网络参数读取此前已经存的epoch=25时的网络模型
#假设我们已经将模型保存在名为 model_epoch25.pth 的文件中,以下是加载该模型并用它进行测试的示例
#网络模型和训练时保持一致就行
import torch
import torch.nn as nn

# 定义 MLP 网络
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

# 创建 MLP 模型实例
model = MLP()

# 加载模型状态
checkpoint = torch.load('model_epoch25.pth')
model.load_state_dict(checkpoint['state_dict'])

# 将模型设置为评估模式
model.eval()

with torch.no_grad():
    # 进行测试
    # 这里省略测试代码

在使用 PyTorch 模型进行测试或评估时,通常不需要计算梯度,因为我们不会更新模型的权重。因此,我们可以使用 torch.no_grad() 上下文管理器,将计算图上下文中的梯度计算禁用掉,以提高计算效率。

3.补充说明

在深度学习模型训练过程中,我们通常需要保存模型的状态,以便在需要时可以恢复模型并继续训练或进行推断。通常,我们需要保存以下状态:

  1. 模型的权重或参数;

  1. 优化器的状态,包括学习率、动量等;

  1. 当前的训练 epoch;

  1. 当前的训练损失等。

为了方便地保存这些状态,通常将它们保存在一个 Python 字典中。在 PyTorch 中,通常会使用以下代码创建这个字典:

state = {
    'epoch': epoch,                      # 当前训练 epoch
    'state_dict': model.state_dict(),    # 模型的权重或参数
    'optimizer': optimizer.state_dict(), # 优化器的状态
    'loss': loss                         # 当前训练损失
}

在上面的代码中,epoch 表示当前训练 epoch,model.state_dict() 返回模型的权重或参数的字典表示,optimizer.state_dict() 返回优化器的状态字典表示,loss 表示当前训练损失。将这些状态保存在一个字典中,可以方便地将它们保存到磁盘中,并在需要时加载到模型中。

  • 8
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值