使用 Pytorch 训练深度学习模型时常用的功能代码(保持更新)

本文介绍了如何在深度学习项目中通过固定随机种子确保模型一致性,包括seed_torch_everywhere函数的应用。此外,讲解了模型参数的保存与load_ckpt函数,以及如何使用正则化技术EarlyStopping来监控训练过程。还涉及超参数搜索和模型状态管理的关键步骤。
摘要由CSDN通过智能技术生成

固定随机种子以确保模型可复现

import os
import torch
import random
import numpy as np
def seed_torch_everywhere(seed=24):

	random.seed(seed)
	os.environ['PYTHONHASHSEED'] = str(seed)
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed(seed)
	torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
	torch.backends.cudnn.benchmark = False
	torch.backends.cudnn.deterministic = True

保存与加载模型参数

应用场景:模型训练意外中断后,在最后一次保存的模型参数上接续训练

import torch
import shutil
def save_ckpt(state, checkpoint_dir, best_model_dir, is_best=False,  file_name='checkpoint.pt'):
    r"""在训练时将模型参数 state 保存在 checkpoint_dir 文件夹下,
    若当前模型为迄今最优模型则将此时的参数另复制一份到 best_model_dir 下。
    除了保存模型参数外,还可保存优化器、学习率规划器的状态,以及当前 epoch 值等。
    Usage:
    >>> checkpoint = {
    >>>     'epoch': epoch + 1,
    >>>     'state_dict': model.state_dict(),
    >>>     'optimizer': optimizer.state_dict()
    >>> }
    >>> save_ckpt(checkpoint, checkpoint_dir, best_model_dir, is_best)
    """
    f_path = os.path.join(checkpoint_dir, file_name)
    torch.save(state, f_path)
    if is_best:
        best_f_path = os.path.join(best_model_dir, file_name)
        shutil.copyfile(f_path, best_f_path)
def load_ckpt(checkpoint_fpath, model, optimizer=None, lr_scheduler=None):
    r"""从 checkpoint_fpath 中加载模型、优化器、学习率规划器、epoch 值等
    Usage:
    >>> model = MyModel(**kwargs)
    >>> optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    >>> ckpt_path = "path/to/checkpoint/checkpoint.pt"
    >>> model, optimizer, start_epoch = load_ckpt(ckpt_path, model, optimizer) 
    """
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    epoch = checkpoint['epoch']
    outputs = (model, epoch)
    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer'])
        outputs += (optimizer, )
    if lr_scheduler:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        outputs += (lr_scheduler, )

    return outputs

正则化 EarlyStopping

代码与使用方式可参考另一篇博客

超参的随机搜索

伪代码与使用方式可参考另一篇博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值