『ignite』模型的训练过程

trainer的父类

from typing import Mapping, Dict, Optional

import torch
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events
from ignite.handlers import TerminateOnNan, EarlyStopping
from ignite.metrics import Loss, RunningAverage

from src.exception import ModelNotFoundException
from src.experiment import Number


class Trainer(object):
    """Base class wrapping an ignite training loop.

    Subclasses must implement :meth:`set_dataset`, :meth:`create_trainer`
    and :meth:`create_evaluator`; :meth:`run` then wires everything
    together (NaN guard, progress bars, early stopping, model saving).
    """

    def __init__(self, *, model: torch.nn.Module = None, file: str = None, save: str = None, device: str = None):
        """
        :param model:  an already constructed ``torch.nn.Module``
        :param file:   path of a saved model to load when ``model`` is None
        :param save:   path the trained model will be saved to (required)
        :param device: "cuda"/"cpu"; defaults to "cuda" when available
        :raises ModelNotFoundException: when neither ``model`` nor ``file`` is given
        :raises ValueError: when ``save`` is not given
        """
        # Resolve the device BEFORE loading the model so that a model read
        # from file is mapped onto the device we actually train on.
        # (Previously map_location could be None and the loaded tensors
        # stayed wherever they were saved.)
        if device is not None:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        if model is not None:
            self.model = model
        elif file is not None:
            self.model = torch.load(file, map_location=self.device)
        else:
            raise ModelNotFoundException("模型未定义,请传入 torch.nn.Module 对象或可加载的模型的文件路径.")

        if save is None:
            raise ValueError("模型存储路径未定义!")
        self.save: str = save

        # Evaluation metrics; MSE is always included by default.
        self.metrics: Dict = {"MSE": Loss(torch.nn.MSELoss())}

        # Engines are created lazily by create_trainer()/create_evaluator().
        self.trainer: Optional[Engine] = None
        self.evaluator: Optional[Engine] = None

    def set_dataset(self, train_batch_size, val_batch_size=1) -> None:
        """Set ``self.train_set`` / ``self.test_set``; must be overridden."""
        raise NotImplementedError("请重写 set_dataset.")

    def set_metrics(self, metric: Mapping) -> None:
        """
        Merge user-supplied evaluation metrics (a mapping of name -> ignite
        metric) into the default metric dict.
        """
        self.metrics.update(metric)

    @staticmethod
    def score_function(engine: Engine) -> Number:
        # EarlyStopping maximizes the score, so negate MSE (lower is better).
        return -engine.state.metrics["MSE"]

    def early_stop(self, every: int = 1, patience: int = 10, min_delta: float = 0,
                   output_transform=lambda x: {'MSE': torch.nn.MSELoss()(*x)}) -> None:
        """
        Stop training early when the test-set score stops improving.

        :param every:            validate the test set every ``every`` epochs
        :param patience:         number of validations without improvement before stopping
        :param min_delta:        minimum score increase counted as an improvement
        :param output_transform: converts the engine output into the values logged by the bar
        :return:
        """
        evaluator_bar_format = "\033[0;32m 测试集验证:{percentage:3.0f}%|{bar}{postfix} 【已执行时间:{elapsed},剩余时间:{remaining}】\033[0m"
        bar = ProgressBar(persist=True, bar_format=evaluator_bar_format)
        bar.attach(self.evaluator, output_transform=output_transform)

        handler = EarlyStopping(patience=patience, score_function=self.score_function,
                                trainer=self.trainer, min_delta=min_delta)
        self.evaluator.add_event_handler(Events.COMPLETED, handler)
        # Run the evaluator on the test set every `every` epochs.
        # noinspection PyUnresolvedReferences
        self.trainer.add_event_handler(Events.EPOCH_COMPLETED(every=every), lambda: self.evaluator.run(self.test_set))

    def create_trainer(self) -> None:
        """
        Create the trainer engine; must be overridden.
        """
        raise NotImplementedError("请重写 create_trainer.")

    def create_evaluator(self) -> None:
        """
        Create the evaluator engine; must be overridden.
        """
        raise NotImplementedError("请重写 create_evaluator.")

    def set_trainer(self):
        """
        Attach cross-cutting handlers (NaN guard, progress bar, running
        loss average, model saving, start/end logging) to the trainer.
        :return:
        """
        # Terminate training as soon as the loss becomes NaN.
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

        # Console progress bar; bar_format controls the rendered layout.
        trainer_bar_format = "\033[0;34m{desc}【{n_fmt:0>5s}/{total_fmt:0>5s}】 {percentage:3.0f}%|{bar}{postfix} 【已执行时间:{elapsed},剩余时间:{remaining}】\033[0m"
        bar = ProgressBar(persist=True, bar_format=trainer_bar_format)

        # Exponential running average of the per-iteration loss, shown as 'loss'.
        RunningAverage(output_transform=lambda x: x, alpha=0.98).attach(self.trainer, 'loss')
        bar.attach(self.trainer, metric_names=["loss"])

        # Persist the full model object when training completes.
        self.trainer.add_event_handler(Events.COMPLETED, lambda: torch.save(self.model, self.save))
        self.trainer.add_event_handler(Events.COMPLETED, lambda: print("训练结束...."))
        self.trainer.add_event_handler(Events.STARTED, lambda: print("训练开始...."))

    def run(self, max_epochs, test_frequency=10) -> None:
        """
        Assemble the engines and run training.

        :param max_epochs:     number of epochs to train
        :param test_frequency: validate the test set every N epochs
        """
        # NOTE(review): FileExistsError is a misleading exception type for a
        # missing dataset, but it is kept for backward compatibility.
        if not hasattr(self, "train_set") or not hasattr(self, "test_set"):
            raise FileExistsError("请先通过 set_dataset 方法设置数据集.")
        self.create_trainer()
        self.create_evaluator()
        self.set_trainer()
        self.early_stop(every=test_frequency)
        # noinspection PyUnresolvedReferences
        self.trainer.run(self.train_set, max_epochs=max_epochs)


trainer的实现类

import torch
from ignite.contrib.handlers import LRScheduler
from ignite.engine import create_supervised_trainer, Events, create_supervised_evaluator
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR

from src.data import get_data_loaders
from src.experiment.Trainer import Trainer
from src.model import ConvLSTM
from src.util import config
from src.util.patch import reshape_patch_back

# Hyper-parameters for the ConvLSTM model, loaded once from the project config.
cfg = config.load_model_parameters("ConvLSTM")


class ConvLSTMTrainer(Trainer):
    """Concrete :class:`Trainer` for the ConvLSTM model."""

    def __init__(self, *, model: torch.nn.Module = None, file: str = None, save: str = None, device: str = None):
        super().__init__(model=model, file=file, save=save, device=device)
        # Use the device resolved by the base class: with the original
        # ``self.model.to(device)``, passing device=None left the model
        # where it was, even though the base class had picked cuda/cpu.
        self.model.to(self.device)

    def create_trainer(self) -> None:
        """
        Build the supervised trainer (Adam + MSE) with per-iteration
        exponential learning-rate decay. The LR-decay handler is created
        here rather than in set_trainer() because it owns the optimizer.
        """
        criterion = nn.MSELoss()

        optimizer = optim.Adam(self.model.parameters(), lr=0.01)

        self.trainer = create_supervised_trainer(model=self.model, optimizer=optimizer, loss_fn=criterion,
                                                 device=self.device)

        # Learning-rate decay: lr is multiplied by 0.98 after every iteration.
        step_scheduler = ExponentialLR(optimizer=optimizer, gamma=0.98)
        scheduler = LRScheduler(step_scheduler)
        self.trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

    def create_evaluator(self) -> None:
        """
        Build the evaluator. Predictions and targets are un-patched via
        reshape_patch_back (patch_size=4) before the metrics see them.
        """
        self.evaluator = create_supervised_evaluator(model=self.model, metrics=self.metrics, device=self.device,
                                                     output_transform=lambda x, y, y_pred: (
                                                         reshape_patch_back(y_pred, patch_size=4),
                                                         reshape_patch_back(y, patch_size=4)
                                                     ))

    def set_dataset(self, train_batch_size, val_batch_size=1) -> None:
        """
        Load the ConvLSTM data loaders and expose them as the
        ``train_set`` / ``test_set`` attributes that Trainer.run() checks.
        """
        train_set, test_set = get_data_loaders("ConvLSTM", train_batch_size, val_batch_size)
        self.train_set = train_set
        self.test_set = test_set

测试代码

if __name__ == '__main__':
    # Build the network from the configured hyper-parameters.
    # in_channels is scaled by 4 * 4 — presumably to match the
    # patch_size=4 reshaping used by the evaluator; TODO confirm.
    model = ConvLSTM(
        in_channels=cfg["in_channels"] * 4 * 4,
        hidden_channels_list=cfg["hidden_channels_list"],
        kernel_size_list=cfg["kernel_size_list"],
        forget_bias=cfg["forget_bias"],
    )
    # Smoke-test run: tiny batches, 3 epochs, validate every epoch.
    convlstm_trainer = ConvLSTMTrainer(model=model, save="test.pth", device="cuda")
    convlstm_trainer.set_dataset(train_batch_size=2, val_batch_size=1)
    convlstm_trainer.run(max_epochs=3, test_frequency=1)

控制台显示

(图片缺失:此处原为控制台训练输出的截图)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值