Pytorch高级训练框架Ignite详细介绍与常用模版

引言

Ignite是Pytorch配套的高级框架,我们可以借其构筑一套标准化的训练流程,规范训练器在每个循环、轮次中的行为。本文将不再赘述Ignite的具体细节或者API,详见官方教程和其他博文。本文将分析Ignite的运行机制、如何将Pytorch训练代码转为Ignite范式,最后给出个人设计的标准化Ignite训练模版。

Ignite简介

 Ignite所做的事情就是我们在pytorch里常写的范式用更加机械、更加标注格式展现出来,这也就是为啥其核心被称为–Engine,高效而精密。Pytorh里常用的训练范式如下:

for ep in Epoch:
	for batch in train_loader:
	    model.train()
	    inputs, targets = batch
	    optimizer.zero_grad()
	    outputs = model(inputs)
	    loss = criterion(outputs, targets)
	    loss.backward()
	    optimizer.step()
	    
		if it%log_period:
			print()
	if ep%save_perid:
		torch.save()

具体而言,可以拆解为批训练、批完结处理、轮次完结处理三个组成部分,批训练部分是网络训练的基础单元,完成数据当前批次读取、前向传播、反向传播等步骤,批完结处理负责在每个批次结束后输出模型训练的相关信息,轮次完结处理负责在每个epoch结束进行模型的保存、对模型的训练参数进行更新。这三个模型训练的主要组成部分在ignite中得到了完整的封装,围绕批训练构造了一个核心的Engine,将批完结处理轮次完结处理附加该Engine运行的时间轴中,形成了批训练->批完结处理->轮次完结处理的流水线作业范式,更为详细的时间轴如下1
在这里插入图片描述
以下将从实用性的角度出发给出Ignite的建设框架,最终给出个人设计的Ignite使用模版,后续直接在train.py文件里直接调用do_train()函数即可利用Ignite进行模型训练。为讲解需要,中间每个子部分的代码为最终代码中相应部分重新排序得到,最终代码中其顺序会进行调整。

批训练

 批训练的代码较为简单,只需将原本的Pytorch版本批处理流程复制粘贴,最后将该过程函数化,并且实例化成Engine即可,代码如下所示,最终启动Engine,即可进行模型的训练。到此为止,实际上已经完成了狭义上的“模型”训练部分。

 def create_supervised_trainer(model,optimizer,criterion,
                              device=None, non_blocking=False,
                              prepare_batch=_prepare_batch,
                              output_transform=lambda x, y, y_pred, loss: loss.item()):
      """
      有监督模型的Engine创建

      Args:
          model (`torch.nn.Module`):
          optimizer (`torch.optim.Optimizer`):
          loss_fn (torch.nn loss function):
          device (str, optional):
          non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
              with respect to the host. For other cases, this argument has no effect.
          prepare_batch (callable, optional): 批处理函数,对dataloader的输出进行处理
          output_transform (callable, optional): 输出变换函数,设定输出,默认情况下,输入为x,y,y_pred,loss,输出loss.item()

      Note: engine在每个batch下的最终输出由transform所指定,默认传回loss.item

      Returns:
          Engine: 有监督任务的engine实例
      """
      if device:
          model.to(device)

      def _update(engine, batch):
          model.train()
          optimizer.zero_grad()
          x,y = prepare_batch(batch, device=device, non_blocking=non_blocking)
          output = model(x)
          loss=criterion(output,y)
          loss.backward()
          optimizer.step()
          return output_transform(x, y, None, loss)

      return Engine(_update)

trainer=create_supervised_trainer(model,optimizer,criterion,device)  # 建立ignite的engine
trainer.run(train_loader,max_epochs=cfg['max_epochs'])

批完结处理

 批完结处理部分我们常做的操作是输出模型在当前批的损失,Ignite中这一过程通过在Engine上附着于ITERAION_COMPLETEDE时触发的回调函数实现。实际上这只是限定了触发时间,具体进行何种操作,完全依赖于个人的选择。我们只需要知道该函数可以利用engine保留的当前批属性信息进行各种操作即可,具体可以利用哪些属性,见官方API2,本文只利用了常用的几个。

    ##########################################################################################
    ###########                    Events.ITERATION_COMPLETED                    #############
    ##########################################################################################

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        """
        隔一定iteration输出模型损失
        """
        log_period=int(cfg["log_period"]*len(train_loader))  # 跑了log_period*len输出一次,取值<=1
        if engine.state.iteration%log_period==0:
            pbar.write(f"Epoch {engine.state.epoch}, iter {engine.state.iteration}: Loss {engine.state.output:.2f}")
            pbar.update(log_period)


    @trainer.on(Events.ITERATION_COMPLETED)
    def scheduler_update(engine):
        """
        optional 每个ITER更新学习率
        """
        scheduler.step()

轮次完结处理

  轮次完结处理和批完结处理相同,也是通过回调函数实现,我们通常在轮次完结处理要进行模型的保存,这里就要做两件事:

  1. 在val_loader上验证模型效果
  2. 保留迄今为止效果最好的模型

针对于第一个要求,这里我同样采用了Ignite风格的Engine驱动范式,读者可以自行选择在这里切换为Pytorch范式,构建验证集Engine的代码如下:

    def create_supervised_evaluator(model, metric,
                                device=None, non_blocking=False,
                                prepare_batch=_prepare_batch,
                                output_transform=lambda x, y, y_pred: (y_pred,y)):
        """
        构造evaluator
        :param model:
        :param metric: dict,key为metric名字,value为Metric类
        :param device:
        :param non_blocking:
        :param prepare_batch:
        :param output_transform:
        :return:
        """
        if device:
            model.to(device)

        def _inference(engine, batch):
            model.eval()
            with torch.no_grad:
                x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
                output = model(x)
            return output_transform(x, y, output)

        engine=Engine(_inference)

		# 附着metric
        for name, metric in metric.items():
            metric.attach(engine,name)

        return engine

可以看到和trainer较为不同的点在于去除了opt等等选项,此外,由于保存模型时我们要依据验证集上的metric来判断是否要保存当前模型还是沿用此前的模型,因此额外将一个Metric类附着在了Engine上,它使得模型可以自动收集eval_engine每个轮次的输出,并进行metric的计算,ignite中提供了许多metric选项3,这里笔者给出自己定制mertric的范式如下,主要由reset(),update()commpute()组成,reset()完成每个epoch的记录状态重置,update()则接受某一批次engine的输出值,commpute()完成最终的metric计算。值得一提的是,在trainer上我们并没有额外附着Loss类,而是直接用engine输出了loss,实际上或许你也可以用相同的方式对eval_engine进行处理。

class CustomMetric(Metric):
    def __init__(self):
        super(CustomMetric,self).__init__()

    def reset(self) -> None:
        self._num_correct=0
        self._num_examples=0

    def update(self, output) -> None:
        '''
        保存该轮次的输出
        :param output: 每个batch engine的输出
        :return:
        '''
        pred,label=output
        pred=pred.detach()
        label=label.detach()

        indices=torch.argmax(pred,dim=1)
        correct=torch.eq(indices,label).view(-1)

        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0]

    def compute(self):
        '''
        计算总ACC
        :return:
        '''
        return self._num_correct/self._num_examples

 完成第二步的方法Ignite同样进行了封装,即Checkpoint4,但笔者也进行了自己的定制化,如下:

class BestCheckPoint():
    def __init__(self,save_path,n_saved,model_name):
        '''
        建立存档点类
        :param save_path: 存档点保存路径
        :param n_saved:  保留的存档点数目
        '''
        self.save_path=save_path
        self.n_save=n_saved
        self.model_name=model_name
        self.score=[]
        if not os.path.exists(save_path):
            os.mkdir(self.save_path)

    def update(self,score):
        '''
        更新最优记录
        :param score: 当前模型的metric
        :return:
        '''
        if type(self.score)==torch.Tensor:
            score=score.item()
        if len(self.score)<self.n_save:
            self.score.append(score)
            self.score.sort()
            self.removed=[]
            self._in=[]
            return True
        else:
            value=self.score[0]
            if score>value:
                self.score.remove(value)
                self.score.append(score)
                self.score.sort()
                return value
            else:
                return False

    def save(self,score,model):
        '''
        视当前得分判断是否保存当前模型并删除
        :param score: 当前模型得分
        :param model: 模型
        :return:
        '''
        is_save=self.update(score)
        if is_save:
            torch.save(model.state_dict(), os.path.join(self.save_path, self.model_name + f"_{score:.4f}.pth"))
            # pop的存档要删除
            if not isinstance(is_save,bool):
                # 似乎ignite存在并行机制,单步运行的时候没问题,多步就会发生早就remove的错误,可以通过保存每一次更新后的score验证
                try:
                    os.remove(os.path.join(self.save_path,self.model_name+f"_{is_save:.4f}.pth"))
                except:
                    print("already removed")

主要就是设置了一个metric池,新的metric进来后判断是否优于池子里最烂的模型,并以此判断是否进行保存。将这个CheckPoint类实例化,并且在trainer每个Epoch完成后遍历eval_engine得到当前模型在验证集上的metric,对其进行更新即可完成模型的保存,代码如下:

    evaluator=create_supervised_evaluator(model,{"ACC":CustomMetric()})
    CP=BestCheckPoint(cfg['save_path'],cfg['n_saved'],cfg['model_name'])
   ##########################################################################################
    ###########                    Events.EPOCH_COMPLETED                    #############
    ##########################################################################################

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_model(engine):
        '''
        保存模型
        :param engine:
        :return:
        '''
        if engine.state.epoch % cfg['save_period']==0:
            evaluator.run(val_loader)
            metrics=evaluator.state.metrics
            print(f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['Acc']:.2f}")
            CP.save(metrics['ACC'],model)

实际上这里还附加了判断,每隔 cfg['save_period']个轮次才进行验证集上的评估和模型保存。

运行框架

 将上述模块封装在一起,我们就可以得到了最终的ignite运行框架,而后只需导入该文件,并运行其中的do_train()函数即可轻松完成模型训练,其中为了方便模型进程的可视化,使用了pbar模块来进行显示,pbar在固定iter次后输出当前训练信息并更新进度条,在epoch完成后重置,同样通过回调函数的形式附加在了trainer上。

整体代码:

# -*- coding: utf-8 -*-
# ---
# @File: trainer.py
# @Author: sgdy3
# @E-mail: sgdy03@163.com
# @Time: 2023/5/9 19:44
# Describe: 
# ---
import os

from tqdm import tqdm
import ignite
import torch
from ignite.engine import Engine
from ignite.utils import convert_tensor
from ignite.engine.engine import Engine, State, Events
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Metric,Accuracy


class BestCheckPoint():
    def __init__(self,save_path,n_saved,model_name):
        '''
        建立存档点类
        :param save_path: 存档点保存路径
        :param n_saved:  保留的存档点数目
        '''
        self.save_path=save_path
        self.n_save=n_saved
        self.model_name=model_name
        self.score=[]
        if not os.path.exists(save_path):
            os.mkdir(self.save_path)

    def update(self,score):
        '''
        更新最优记录
        :param score: 当前模型的metric
        :return:
        '''
        if type(self.score)==torch.Tensor:
            score=score.item()
        if len(self.score)<self.n_save:
            self.score.append(score)
            self.score.sort()
            self.removed=[]
            self._in=[]
            return True
        else:
            value=self.score[0]
            if score>value:
                self.score.remove(value)
                self.score.append(score)
                self.score.sort()
                return value
            else:
                return False

    def save(self,score,model):
        '''
        视当前得分判断是否保存当前模型并删除
        :param score: 当前模型得分
        :param model: 模型
        :return:
        '''
        is_save=self.update(score)
        if is_save:
            torch.save(model.state_dict(), os.path.join(self.save_path, self.model_name + f"_{score:.4f}.pth"))
            # pop的存档要删除
            if not isinstance(is_save,bool):
                # 似乎ignite存在并行机制,单步运行的时候没问题,多步就会发生早就remove的错误,可以通过保存每一次更新后的score验证
                try:
                    os.remove(os.path.join(self.save_path,self.model_name+f"_{is_save:.4f}.pth"))
                except:
                    print("already removed")



class CustomMetric(Metric):
    def __init__(self):
        super(CustomMetric,self).__init__()

    def reset(self) -> None:
        self._num_correct=0
        self._num_examples=0

    def update(self, output) -> None:
        '''
        保存该轮次的输出
        :param output: 每个batch engine的输出
        :return:
        '''
        pred,label=output
        pred=pred.detach()
        label=label.detach()

        indices=torch.argmax(pred,dim=1)
        correct=torch.eq(indices,label).view(-1)

        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0]

    def compute(self):
        '''
        计算总ACC
        :return:
        '''
        return self._num_correct/self._num_examples


def do_train(model,optimizer,criterion,scheduler,device,train_loader,val_loader,cfg):

    def _prepare_batch(batch, device=None, non_blocking=False):
        """
        对dataloader每个batch的输出进行进一步的处理
        :param batch: dataloader输出
        :param device:
        :param non_blocking:
        :return:
        """
        device = "cuda:" + str(device)
        x, y = batch
        x = convert_tensor(x,device=device,non_blocking=non_blocking)
        y = convert_tensor(y,device=device,non_blocking=non_blocking)
        return x,y


    def create_supervised_trainer(model,optimizer,criterion,
                                device=None, non_blocking=False,
                                prepare_batch=_prepare_batch,
                                output_transform=lambda x, y, y_pred, loss: loss.item()):
        """
        有监督模型的Engine创建

        Args:
            model (`torch.nn.Module`):
            optimizer (`torch.optim.Optimizer`):
            loss_fn (torch.nn loss function):
            device (str, optional):
            non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
                with respect to the host. For other cases, this argument has no effect.
            prepare_batch (callable, optional): 批处理函数,对dataloader的输出进行处理
            output_transform (callable, optional): 输出变换函数,设定输出,默认情况下,输入为x,y,y_pred,loss,输出loss.item()

        Note: engine在每个batch下的最终输出由transform所指定,默认传回loss.item

        Returns:
            Engine: 有监督任务的engine实例
        """
        if device:
            model.to(device)

        def _update(engine, batch):
            model.train()
            optimizer.zero_grad()
            x,y = prepare_batch(batch, device=device, non_blocking=non_blocking)
            output = model(x)
            loss=criterion(output,y)
            loss.backward()
            optimizer.step()
            return output_transform(x, y, None, loss)

        return Engine(_update)

    def create_supervised_evaluator(model, metric,
                                device=None, non_blocking=False,
                                prepare_batch=_prepare_batch,
                                output_transform=lambda x, y, y_pred: (y_pred,y)):
        """
        构造evaluator
        :param model:
        :param metric: dict,key为metric名字,value为Metric类
        :param device:
        :param non_blocking:
        :param prepare_batch:
        :param output_transform:
        :return:
        """
        if device:
            model.to(device)

        def _inference(engine, batch):
            model.eval()
            with torch.no_grad:
                x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
                output = model(x)
            return output_transform(x, y, output)

        engine=Engine(_inference)

        for name, metric in metric.items():
            metric.attach(engine,name)

        return engine



    trainer=create_supervised_trainer(model,optimizer,criterion,device)  # 建立ignite的engine
    evaluator=create_supervised_evaluator(model,{"ACC":CustomMetric()})
    CP=BestCheckPoint(cfg['save_path'],cfg['n_saved'],cfg['model_name'])
    pbar=tqdm(total=len(train_loader))  # 为训练器迭代器建立进度条




    ##########################################################################################
    ###########                    Events.ITERATION_COMPLETED                    #############
    ##########################################################################################

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        """
        隔一定iteration输出模型损失
        """
        log_period=cfg["log_period"]
        if engine.state.iteration%log_period==0:
            pbar.write(f"Epoch {engine.state.epoch}, iter {engine.state.iteration}: Loss {engine.state.metrics['avg_loss']:.2f}")
            pbar.update(log_period)

    @trainer.on(Events.ITERATION_COMPLETED)
    def scheduler_update(engine):
        """
        optional 每个ITER更新学习率
        """
        scheduler.step()

    ##########################################################################################
    ###########                    Events.EPOCH_COMPLETED                    #############
    ##########################################################################################

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_model(engine):
        '''
        保存模型
        :param engine:
        :return:
        '''
        if engine.state.epoch % cfg['save_period']==0:
            evaluator.run(val_loader)
            metrics=evaluator.state.metrics
            print(f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['Acc']:.2f}")
            CP.save(metrics['ACC'],model)

    @trainer.on(Events.EPOCH_COMPLETED)
    def reset_bar(engine):
        '''
        重置进度条
        :param engine:
        :return:
        '''
        pbar.reset()

    ##########################################################################################
    #################                    training Start                    ###################
    ##########################################################################################
    trainer.run(train_loader,max_epochs=cfg['max_epochs'])
    pbar.close()

参考


  1. Events | Pytorch-Ignite ↩︎

  2. State | Ignite ↩︎

  3. IGNITE.METRICS ↩︎

  4. CHECKPOINT ↩︎

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值