pytorch学习笔记(二十): ignite (写更少的代码训练模型)

最近自己想写一个高级一点的抽象来更方便的训练 pytorch 网络, 无意间发现, pytorch 用户下面有个 ignite repo, 好奇就看了一下这是个什么东西. 原来是 pytorch 已经提供了一个高级抽象库来训练 pytorch模型了, 既然有了轮子, 那就没必要自己造了, 好好用着就行了. 没事读读源码, 也可以学习一下大佬们是怎么抽象的. 本博文主要是对 ignite 做一个宏观上的介绍.

官方文档

例子

为了减少源码篇幅, 特地将与 ignite 关系不大的代码给删除了, 如果想跑完整示例的话, 可以查看上面提到的链接.

from argparse import ArgumentParser
from torch import nn
from torch.optim import SGD
from torchvision.transforms import Compose, ToTensor, Normalize

from ignite.engines import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import CategoricalAccuracy, Loss

def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    cuda = torch.cuda.is_available()
    device = torch.device("gpu") if torch.cuda.is_avaliable() else ("cpu")
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)

    model = Net()
    
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': CategoricalAccuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(engine.state.epoch, iter, len(train_loader), engine.state.output))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        metrics = evaluator.run(val_loader).metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
              .format(engine.state.epoch, avg_accuracy, avg_nll))

    trainer.run(train_loader, max_epochs=epochs)

先对流程做一下总结, 再看API做了些什么

  • 创建模型, 创建 Dataloader
  • 创建 trainer
  • 创建 evaluator
  • 为一些事件注册事件处理函数, @trainer.on()
  • trainer.run()

Event

"""
类似枚举类, 定义了几个事件
"""
class Events(Enum):
    EPOCH_STARTED = "epoch_started"               # 当一个新的 epoch 开始时会触发此事件
    EPOCH_COMPLETED = "epoch_completed"           # 当一个 epoch 结束时, 会触发此事件
    STARTED = "started"                           # 开始训练模型是, 会触发此事件
    COMPLETED = "completed"                       # 当训练结束时, 会触发此事件
    ITERATION_STARTED = "iteration_started"       # 当一个 iteration 开始时, 会触发此事件
    ITERATION_COMPLETED = "iteration_completed"   # 当一个 iteration 结束时, 会触发此事件
    EXCEPTION_RAISED = "exception_raised"         # 当有异常发生时, 会触发此事件

State

class State(object):
    def __init__(self, **kwargs):
        self.iteration = 0            # 记录 当前的 iteration
        self.output = None            # 当前iteration(process_function)的 输出. 对于 Supervised Trainer 来说, 是 loss.
        self.batch = None             # 本次 iteration 的 mini-batch 样本
        for k, v in kwargs.items():   # 其它一些希望 State 记录下来的 状态,metrics等.
            setattr(self, k, v)

Engine

def __init__(self, process_function):
	pass 

"""
对于 训练过程 来说,process_function 是一个 前向+反向+参数更新 过程
process_function 的 signature 是 func(batch)->anything
def func(self, batch): # batch会保存在 state.batch 中, self用来接收当前对象
	1. process batch
	2. forward compution
	3. compute loss
	4. computer gradient
	5. update parameters
	6. return loss or else # 返回的值会被保存在 state.output 中

对于 评估过程来说, process_function 是一个 前向+计算 metrics 的过程。
def func(self, batch): # batch会保存在 state.batch 中
	1. process batch
	2. forward compution
	3. return something # 返回的值会被保存在 state.output 中,
	#  用来计算 Metric
"""


""" 为某事件注册处理函数,注册在当前engine上, 当事件发生时, 此函数就会被调用
函数的 signature 必须是 def func(engine)
"""
@engine.on(...)
def some_func(trainer):
    pass

Engine.run() # 训练/评估 模型

Metric

定义了一些模型评估标准

  • 在创建 evaluator 的时候会指定一些 metric,这些metric 会在create_supervised_evaluator 中通过调用Metric.attach 方法 将Metric 中的一些方法 注册成为 evaluator 的 event handler。会在相应的事件发生时调用相应的处理函数。
  • Metric的一些接口如下所示,可以通过重写以下接口进行自定义Metric
# reset : epoch 开始之前调用一次
# update: 每次iteration 结束时调用
# compute : epoch 结束时调用

工厂方法,用来创建 Engine

可以根据自己的需要改写以下两个函数

create_supervised_trainer

def create_supervised_trainer(model, optimizer, loss_fn, cuda=False):
    """
    Factory function for creating a trainer for supervised models

    Args:
        model (torch.nn.Module): the model to train
        optimizer (torch.optim.Optimizer): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        cuda (bool, optional): whether or not to transfer batch to GPU (default: False)

    Returns:
        Trainer: a trainer instance with supervised update function
    """

create_supervised_evaluator

def create_supervised_evaluator(model, metrics={}, cuda=False):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (torch.nn.Module): the model to train
        metrics (dict of str: Metric): a map of metric names to Metrics
        cuda (bool, optional): whether or not to transfer batch to GPU (default: False)

    Returns:
        Evaluator: a evaluator instance with supervised inference function
    """
  • 11
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
Pytorch是机器学习中的一个重要框架,它与TensorFlow一起被认为是机器学习的两大框架。Pytorch学习可以从以下几个方面入手: 1. Pytorch基本语法:了解Pytorch的基本语法和操作,包括张量(Tensors)的创建、导入torch库、基本运算等\[2\]。 2. Pytorch中的autograd:了解autograd的概念和使用方法,它是Pytorch中用于自动计算梯度的工具,可以方便地进行反向传播\[2\]。 3. 使用Pytorch构建一个神经网络:学习使用torch.nn库构建神经网络的典型流程,包括定义网络结构、损失函数、反向传播和新网络参数等\[2\]。 4. 使用Pytorch构建一个分类器:了解如何使用Pytorch构建一个分类器,包括任务和数据介绍、训练分类器的步骤以及在GPU上进行训练等\[2\]。 5. Pytorch的安装:可以通过pip命令安装Pytorch,具体命令为"pip install torch torchvision torchaudio",这样就可以在Python环境中使用Pytorch了\[3\]。 以上是一些关于Pytorch学习笔记,希望对你有帮助。如果你需要详细的学习资料,可以参考引用\[1\]中提到的网上帖子,或者查阅Pytorch官方文档。 #### 引用[.reference_title] - *1* [pytorch自学笔记](https://blog.csdn.net/qq_41597915/article/details/123415393)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [Pytorch学习笔记](https://blog.csdn.net/pizm123/article/details/126748381)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值