300行代码实现一个优雅的PyTorch Trainer

如果觉得下面哪部分讲的不够清楚,欢迎评论指出。

2022.5.30 更新:支持CI自动化测试,代码质量更加可靠

2022.5.29 更新:新增详细的API文档

为什么需要Trainer?

在江湖上混,没有一把趁手的兵器怎么能行。对于一个深度学习算法工程师来说,一个强大的Trainer就是趁手的兵器,在复现顶会论文的时候不用每次都要写大量的模版代码,可以更加关注模型实现本身。我平时也开源了一些深度学习项目,累计获赞400+,最近整理了一些之前的代码,也看了很多优秀的开源实现,然后从0实现了一个小而美的trainer,主要有以下两个优点吧:

  • 轻量级:代码量很少,易上手也易改动。
  • 可拓展:能够满足工程师各种魔改网络/训练流程的需求。

https://pica.zhimg.com/80/v2-eee3688917e0be64a71105f789a185bf_1440w.png?source=d16d100b

为什么不选择pytorch-lightning / detectron2 / mmcv?

现有的detectron2、mmcv、pytorch-lightning中的trainer虽然也很优雅,但是看了源码就能感受到代码中的抽象层次太多,看的有些吃力。例如下面这个文件是detectron2中保存checkpoint的代码,共594行,考虑到了很多场景,功能非常完善。

fvcore/checkpoint.py at main · facebookresearch/fvcore

但是对于我们平时使用来说,可能需要的只是如下几行代码。

https://pic2.zhimg.com/80/v2-55c30e012c4453e40d209b36faaf131a_1440w.png?source=d16d100b

此外detectron2是基于迭代而不是基于epoch来训练的,和我平时的习惯有些许出入。mmcv虽然是基于epoch的,但是整体设计上要略微逊色于detectron2。

最重要的一点:虽然以上这些开源库提供的trainer足够强大,学会之后能满足所有需求,但是,我还是想造个轮子。一方面是实现过程中会深入思考一些工程细节,对于代码水平会有提高。另一方面是写一个trainer也不咋耗时间,最近没什么事情,一周还是拿得出来的。

代码已经放到github上了,仓库名字叫:Core-PyTorch-Utils,简称CPU。代码注释和文档都非常详细。

GitHub - serend1p1ty/core-pytorch-utils: Core components for deep learning.

https://pic1.zhimg.com/80/v2-3ad5f8776d0742baea5c4e6c2857b70a_1440w.png?source=d16d100b

https://pic2.zhimg.com/80/v2-be75e4a1c44bfb810da475761cd146e7_1440w.png?source=d16d100b

里面大多数文件都不会依赖其它模块,可以直接copy作为小工具使用。下面将逐个介绍这里的每个模块。

[1/6]灵活的配置系统:config_argparse.py

一般在一个深度学习项目中,有数十上百个需要调节的参数,cpu/config_argparse.py这个文件提供了一个简易灵活的配置系统,实现细节参考自pytorch-image-models

目前有两种比较流行的参数配置系统:

  1. 利用 argparse库,将所有参数添加到 argparse.ArgumentParser中,detrmoco就是这么干的。
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--learning-rate', type=float, default=0.001)
print(parser.parse_args())
  1. 利用 yacs库创建默认配置,然后从命令行传递 --config-file参数来覆盖默认配置。

default_config.py保存了项目的默认配置。

from yacs.config import CfgNode as CN

_C = CN()
_C.BATCH_SIZE = 1
_C.LEARNING_RATE = 0.001

def get_default_cfg():
    """返回一个默认配置的备份,防止默认配置被修改。"""
    return _C.clone()

config.yaml为某次实验使用的配置文件,用来覆盖默认配置。

BATCH_SIZE: 5

main.py中创建一个包含 --config-fileArgumentParser来实现加载配置文件。

import argparse
from default_config import get_default_cfg

parser = argparse.ArgumentParser()
parser.add_argument("--config-file", help="配置文件的路径。")
parser.add_argument("opts", nargs=argparse.REMAINDER, help="通过命令行修改配置。")
args = parser.parse_args()

# 最终的配置 = 默认配置 + 配置文件
cfg = get_default_cfg()
cfg.merge_from_file(args.cfg_file)
cfg.merge_from_list(args.opts)
cfg.freeze()

print(cfg)

然后运行 python main.py --config-file config.yaml,最终得到的配置为

BATCH_SIZE = 5
LEARNING_RATE = 0.001

如果配置文件中遗漏了一些内容,我们也可以在命令行中实时地修改配置。运行 python main.py --config-file config.yaml LEARNING_RATE 0.002,最终得到的配置为

BATCH_SIZE = 5
LEARNING_RATE = 0.002

第一种方式的缺点是无法加载 yaml配置文件,所有参数都要在命令行中输入(虽然 argparse自带 fromfile_prefix_chars特性,但不好用)。第二种方法的缺点是,一个配置系统被割裂成了两个部分, yacs部分+ argparse部分,且在命令行中通过 python main.py --help无法查看每个参数的说明(只能显示由 argparse管理的那部分参数,由 yacs管理的那部分参数则无能为力)。

因此我们的 config_argparse.py就实现了一个可以加载配置的文件的 ArgumentParser类。使用方式和 argparse完全相同,只是多了一个 --config的选项用来指定配置文件的路径。下面讲讲一些实现细节。

我们先定义了一个内部 parser用来解析用户指定的配置文件的路径。

self.config_parser.add_argument("-c", "--config", default=None, metavar="FILE",
                                help="where to load YAML configuration")

然后根据用户指定的配置文件覆盖相应选项的默认值。

# 先解析出用户指定的配置文件的路径,保存在res.config中
res, remaining_argv = self.config_parser.parse_known_args(args)

if res.config is not None:
    # 加载配置文件
    with open(res.config, "r") as f:
        config_vars = yaml.safe_load(f)
    # 判断配置文件中是否含有错误的键值对
    for key in config_vars:
        if key not in self.option_names:
            self.error(f"unexpected configuration entry: {key}")
    # ----将配置文件提供的值用作默认值----
    # 例如在定义某个选项的时候默认值为1,parser.add_argument("--batch-size", default=1)
    # 配置文件的内容为"batch_size: 2"
    # 那么就用配置文件提供的值覆盖该选项的默认值,即目前batch_size的默认值为2
    self.set_defaults(**config_vars)

# parse其余命令行参数
# 如果在命令中没有指定batch_size,那么最终的batch_size就是覆盖之后的默认值2
# 如果在命令行中指定了--batch-size=3,那么最终的batch_size就是3,因为命令行具有最高优先级!
return super().parse_args(remaining_argv)

目前我们实现的 ConfigArgumentParser类,还有一些小问题,例如:

如果指定了一个必选参数, parser.add_argument("--batch-size", required=True),且在配置文件中也已经设置了 batch_size的值,这个时候在命令行中理应不需要再设置 batch_size了。但是目前我们的实现仍然需要在命令行中指定 batch_size的值。这是因为我们将配置文件视为选项的默认值,而不是用户的实际输入值。

如果非常介意这个问题,可以选用ConfigArgParse这个库(实际上我们的实现方式就是ConfigArgParse的早期实现的简单版本)。但是在常用的深度学习系统中,我们提供的这个类已经足够了。

[2/6]将信息显示在终端和文件中:logger.py

https://pic2.zhimg.com/80/v2-429133b5d82e199211c0aef0a9ad8209_1440w.png?source=d16d100b

在深度学习项目中经常需要把终端显示的内容保存到文件中,这里我们选择使用python自带的 logging模块来作为信息记录系统。

我们在 logger.py中实现了一个 setup_logger()函数,用来对 logging模块返回的 logger做了一些简单配置

  • 设置终端中显示的信息的格式
  • 终端信息高亮,让warning和error信息更加醒目
  • 将终端显示的信息同步保存到文件中

下面是核心代码的讲解。

# 保证一个logger只会被初始化一次
if name in logger_initialized:
    return logger_initialized[name]

# 如果name等于None,则返回的是全局的root logger
# 初始化完root logger之后,所有logger的信息(除非这个logger的propagate为False)都会被传播到root logger,
# 然后root logger再根据配置好的格式将消息显示在终端中。
# 所以最常见的用法就是在程序的最开始调用一次setup_logger()配置好root logger,
# 然后其它所有logger都免配置直接使用。
logger = logging.getLogger(name)
logger.setLevel(log_level)
# 如果同时创建了子logger和父logger,
# 那么子logger产生的信息会向上传递给父logger,
# 终端中就会把同一段信息显示两次!因此在初始化logger的时候要设置logger.propagate=False,
# 防止子logger的信息向上传递。
logger.propagate = False

plain_formatter = logging.Formatter(
    "[%(asctime)s %(name)s %(levelname)s]: %(message)s", datefmt="%m/%d %H:%M:%S"
)

# 在分布式训练中只初始化主进程的logger
if rank == 0:
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(log_level)
    if color:
        # _ColorfulFormatter的功能是,如果当前信息的级别是INFO,
        # 那么就不展示级别,如果是WARNING、DEBUG、ERROR,
        # 就在信息的开头显示对应级别。这样的好处是,因为绝大多数信息都是INFO,
        # 这样以来可以节省空间显示更多文字。
        formatter = _ColorfulFormatter(
            colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
            datefmt="%m/%d %H:%M:%S",
        )
    else:
        formatter = plain_formatter
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")

        # 如果单个文件名作为参数传递,则os.path.dirname将返回一个空字符串。
        # 例如,os.path.dirname("log.txt")== ""。这将导致os.makedirs()出错。
        # 所以我们需要用绝对路径来获取目录名。
        os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)

        fh = logging.FileHandler(filename)
        fh.setLevel(log_level)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

logger_initialized[name] = logger
return logger

[3/6]平滑显示,让loss不要再上下乱跳了:history_buffer.py

如果想在终端中输出训练过程的loss值,相信有些同学会像下面这么写:

for i, batch in enumerate(data_loader):
    output = model(batch)
    loss = criterion(output, label)
    if i % 10 == 0:
        print(f"loss: {loss.item()}")

但是这样一来输出的loss跳变很大不利于观察,可能你看到的是这样一组数字:0.5、0.8、0.3…

为了更加平滑的显示loss,我们不应该直接输出loss的当前值,而是应该输出过去一段时间内loss的平均值。假设在5次迭代中loss的值分别为0.5、0.8、0.3、0.4、0.3,窗口大小为2(我们每次输出的是过去两个迭代的loss的平均值)。

  • 在第一次迭代时,loss队列如下:
0.5

数量太少,无法填满窗口,直接输出0.5

  • 第二次迭代,loss队列如下,窗口为0.5和0.8。
--------
0.5  0.8
--------

因此输出窗口内的平均值0.65

  • 第二次迭代,loss队列如下,窗口右移(因为窗口的含义是最近两个迭代)
     --------
0.5  0.8  0.3
     --------

输出窗口内的平均值0.55

  • 剩下的迭代以此类推。

history_buffer中提供了一个HistoryBuffer类,可以帮助我们输出窗口内的平均值(也能输出全局平均值、全局和、最近的值)。用法如下:

loss_buffer = HistoryBuffer()
for i, batch in enumerate(data_loader):
    output = model(batch)
    loss = criterion(output, label)
    loss_buffer.update(loss.item())
    if i % 10 == 0:
        print(f"loss: {loss_buffer.avg}")

pytorch-image-models中也有类似的用法,不过使用的是只能输出全局平均值的 AverageMeter。个人认为局部平均值比全局平均值更好,能够更明显地反映出当前的变化。所以不太理解为什么pytorch-image-models为什么这样做,知道的朋友烦请评论区告知。

[4/6]让PyTorch自带的lr scheduler支持warmup:lr_scheduler.py

这部分内容太长了,单独放到一篇文章中。

一文看懂学习率warmup及各主流框架实现差异

[5/6]巧用钩子赋予Trainer更强的可拓展性

CPU库的代码是写死的,而用户会有各种各样的需求,比如想每隔3个epoch就打印一次当前的学习率。为了给Trainer带来更强的拓展性,CPU采用了钩子(hook)机制。

我们首先定义了钩子的基类Hook,包含了几个内置的方法。以 before_epoch()为例,这个方法将在每个epoch开始之前调用。

class Hook:
    def before_train(self) -> None:
        """Called before the first iteration."""
        pass

    def after_train(self) -> None:
        """Called after the last iteration."""
        pass

    def before_epoch(self) -> None:
        """Called before each epoch."""
        pass

    def after_epoch(self) -> None:
        """Called after each epoch."""
        pass

    def before_iter(self) -> None:
        """Called before each iteration."""
        pass

如果想要实现自定义的功能(例如每隔三个epoch保存一次模型的权重),用户可以继承这个基类。

class CheckpointHook(Hook):
    def after_epoch(self):
        # self.trainer是在钩子注册的时候定义的
        epoch = self.trainer.epoch
        if epoch % 3 == 0:
           save_checkpoint()

再写一个简单的Trainer类。

class Trainer:
    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        # 将钩子注册到trainer中,实际上就是放到trainer的_hooks列表里以便后续调用
        for hook in hooks:
           # 这里为每个钩子创建一个类变量,指向当前trainer。
           # 这样就可以访问trainer内部的model、optimizer、epoch,iter等。
           hook.trainer = self
           self._hooks.append(hook)

    def before_train(self):
        # trainer的before_train()函数就是
        # 调用每个注册的hook的before_train()函数
        for hook in self._hooks:
            hook.before_train()

    def after_train(self):
        for hook in self._hooks:
            hook.after_train()

    def before_epoch(self):
        for hook in self._hooks:
            hook.before_epoch()

    def after_epoch(self):
        for hook in self._hooks:
            hook.after_epoch()

    def before_iter(self):
        for hook in self._hooks:
            hook.before_iter()

    def after_iter(self):
        for hook in self._hooks:
            hook.after_iter()

    def train(self):
        self.before_train()
        for self.epoch in range(max_epochs):
            self.before_epoch()
            for self.iter, data in enumerate(data_loader):
                self.before_iter()
                self.train_one_iter(data)
                self.after_iter()
            self.after_epoch()
        self.after_train()

然后实例化我们刚才写的 CheckpointHook类,并将其注册到 Trainer中。

ckpt_hook = CheckpointHook()
trainer = Trainer()
trainer.register_hooks([ckpt_hook])
trainer.train()

上面这些是伪代码,不能运行,只是为了方便描述钩子的原理。

[6/6]开始写 Trainer

本节我们将逐行讲解 trainer.py的内容。先来看构造函数。

class Trainer:
    def __init__(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        lr_scheduler: optim.lr_scheduler._LRScheduler,
        data_loader: DataLoader,
        max_epochs: int,
        work_dir: str = "work_dir",
        max_num_checkpoints: int = None,
        checkpoint_period: int = 1,
        log_period: int = 50,
        clip_grad_norm: float = 0.0,
        enable_amp: bool = False,
        # 以下是关于lr warmup的参数
        by_epoch: bool = True,
        warmup_t: int = 0,
        warmup_by_epoch: bool = False,
        warmup_mode: str = "fix",
        warmup_factor: float = 0.0,
    ):
        """
        Args:
            model (torch.nn.Module)
            optimizer (torch.optim.Optimizer)
            lr_scheduler (optim.lr_scheduler._LRScheduler)
            data_loader (torch.utils.data.DataLoader): 训练数据。
            max_epochs (int): 训练多少个epoch。
            work_dir (str): 保存checkpoint和log文件的目录,默认是"work_dir"。
            max_num_checkpoints (int): 最多保存多少个checkpoint,如果是None,则保存所有的checkpoint。
            checkpoint_period (int): 每隔多少个epoch保存一次checkpoint,默认是1。
            log_period (int): 每隔多少次迭代输出一次log,默认是50。
            clip_grad_norm (float): 梯度裁剪使用的范数阈值,如果<=0,则不使用梯度裁剪。
            enable_amp (bool): 是否开启混合精度训练,默认不开启。
            by_epoch, warmup_t, warmup_by_epoch, warmup_mode, warmup_factor: 参考 lr_scheduler.py的文档。
        """
        self.model = model
        self.optimizer = optimizer
        # 将pytorch自带的lr scheduler转化为带warmup的scheduler
        self.lr_scheduler = LRWarmupScheduler(
            torch_scheduler=lr_scheduler, by_epoch=by_epoch, epoch_len=len(data_loader),
            warmup_t=warmup_t, warmup_by_epoch=warmup_by_epoch, warmup_mode=warmup_mode,
            warmup_factor=warmup_factor)
        self.data_loader = data_loader
        self.work_dir = work_dir
        # 我们使用HistoryBuff来平滑训练过程中产生的指标。
        # MetricStorage等下会详细讲,这里可以简单地理解成是一个字典。
        # 字典的key是指标的名字,value是保存这个指标的HistoryBuffer。
        self.metric_storage = MetricStorage()

        # 一些有用的计数器,在hook中可能会用到
        self.inner_iter: int  # [0, epoch_len - 1]
        self.epoch: int  # [0, max_epochs - 1]
        self.start_epoch = 0  # [0, max_epochs - 1]
        self.max_epochs = max_epochs

        # 保存所有已注册的hook
        self._hooks: List[HookBase] = []
        # 因为想记录从data_loader中加载每个batch需要的时间,所以我们
        # 没有选择用for循环,而是用迭代器的方式来遍历data_loader
        self._data_iter = iter(data_loader)
        self._max_num_checkpoints = max_num_checkpoints
        self._checkpoint_period = checkpoint_period
        self._log_period = log_period
        self._clip_grad_norm = clip_grad_norm
        self._enable_amp = enable_amp

        # 注册一些默认的hooks
        self.register_hooks(self._build_default_hooks())
        logger.info(f"Registered default hooks: {self.registered_hook_names}")

        if self._enable_amp:
            logger.info("Automatic Mixed Precision (AMP) training is on.")
            self._grad_scaler = GradScaler()

然后我们又定义了一些有用的属性,注意这里使用了python的 @property装饰器。以 self.max_iters为例,当更改 self.max_epochs时, self.max_iters的数值也会自动改变,不需要我们手动更新。

    @property
    def lr(self) -> float:
        return self.optimizer.param_groups[0]["lr"]

    @property
    def epoch_len(self) -> int:
        """每个epoch包含多少个iter(iteration,迭代)"""
        return len(self.data_loader)

    @property
    def max_iters(self) -> int:
        """一共需要训练多少个iter"""
        return self.max_epochs * self.epoch_len

    @property
    def cur_iter(self) -> int:
        """当前iter的索引,范围是[0, max_iters - 1]"""
        return self.epoch * self.epoch_len + self.inner_iter

    @property
    def start_iter(self) -> int:
        """训练从哪个iter开始,最小值是0"""
        return self.start_epoch * self.epoch_len

    @property
    def ckpt_dir(self) -> str:
        """保存checkpoints的目录"""
        return osp.join(self.work_dir, "checkpoints")

    @property
    def tb_log_dir(self) -> str:
        """保存tensorboard log文件的目录"""
        return osp.join(self.work_dir, "tb_logs")

    @property
    def model_or_module(self) -> nn.Module:
        if isinstance(self.model, (DistributedDataParallel, DataParallel)):
            return self.model.module
        return self.model

    @property
    def registered_hook_names(self) -> List[str]:
        """所有已经注册的钩子的名称"""
        return [h.__class__.__name__ for h in self._hooks]

然后是一些和钩子相关的方法。

    def register_hooks(self, hooks: List[Optional[HookBase]]) -> None:
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            # 这里使用weakref是为了避免循环引用导致的内存泄漏
            h.trainer = weakref.proxy(self)
            # LoggerHook旨在记录训练过程中产生的一些指标(模型的loss、其它钩子产生的指标)
            # 为了不遗漏其它钩子产生的信息,LoggerHook应该被最后执行(注册的时候放到末尾)
            if self._hooks and isinstance(self._hooks[-1], LoggerHook):
                self._hooks.insert(len(self._hooks) - 1, h)
            else:
                self._hooks.append(h)

    def _call_hooks(self, stage: str) -> None:
        """_call_hooks("before_epoch")可以调用所有钩子的before_epoch()方法"""
        for h in self._hooks:
            getattr(h, stage)()

    def _build_default_hooks(self) -> List[HookBase]:
        return [
            LRUpdateHook(),  # 更新学习率应该在iter结束之后立马执行,因此要放在第一位
            CheckpointerHook(self._checkpoint_period, self._max_num_checkpoints),
            LoggerHook(self._log_period, tb_log_dir=self.tb_log_dir),
        ]

然后就是核心的训练环节。

    def train(self) -> None:
        logger.info(f"Start training from epoch {self.start_epoch}")
        self._prepare_for_training()
        self._call_hooks("before_train")
        for self.epoch in range(self.start_epoch, self.max_epochs):
            self._call_hooks("before_epoch")
            self._train_one_epoch()
            self._call_hooks("after_epoch")
        self._call_hooks("after_train")

    def _train_one_epoch(self) -> None:
        # 测试hook会把模型设置为eval模式,因此这里要设置为train模式
        self.model.train()
        for self.inner_iter in range(self.epoch_len):
            self._call_hooks("before_iter")
            self.train_one_iter()
            self._call_hooks("after_iter")
        # 这里要手动更新迭代器,不然迭代到data_loader末尾时会引发异常
        self._data_iter = iter(self.data_loader)

    def train_one_iter(self) -> None:
        iter_start_time = time.perf_counter()

        # 1. 从data_loader中加载batch
        # 为了计算数据加载时间,我们选择通过迭代器读取数据,而不是"for data in data_loader"
        start = time.perf_counter()
        batch = next(self._data_iter)
        data_time = time.perf_counter() - start

        # 2. 计算loss
        if self._enable_amp:
            with autocast():
                loss_dict = self.model(batch)
        else:
            loss_dict = self.model(batch)
        if isinstance(loss_dict, torch.Tensor):
            losses = loss_dict
            loss_dict = {"total_loss": loss_dict}
        else:
            losses = sum(loss_dict.values())

        # 3. 计算梯度
        self.optimizer.zero_grad()
        if self._enable_amp:
            self._grad_scaler.scale(losses).backward()
        else:
            losses.backward()
        if self._clip_grad_norm > 0:
            if self._enable_amp:
                self._grad_scaler.unscale_(self.optimizer)
            clip_grad_norm_(self.model.parameters(), self._clip_grad_norm)

        # 4. 更新模型权重
        if self._enable_amp:
            self._grad_scaler.step(self.optimizer)
            self._grad_scaler.update()
        else:
            self.optimizer.step()

        # 记录此次迭代产生的各项指标
        self._log_iter_metrics(loss_dict, data_time, time.perf_counter() - iter_start_time, self.lr)

接下来详细讲讲 Trainer中如何自动记录训练过程中产生的各项指标。首先回顾一下 HistoryBuffer的功能,如果内容为

     --------
0.5  0.8  0.3
     --------

HistoryBuffer同时提供了 avg()方法用来获得窗口内的平均值0.55,和 latest()方法获得最新的数值0.3。

随着训练的进行,会把每次迭代的指标保存到字典 metric_storage中。假设已经进行了3次迭代,此时 metric_storage的内容为

"loss":
     --------
0.5  0.8  0.3
     --------
"lr":
     ---------
0.1  0.1  0.08
     ---------

LoggerHook会每隔几个iter来从 metric_storage中取出各项指标的值,并打印至终端。此时对于需要平滑的指标"loss", LoggerHook打印的是 metric_storage["loss"].avg,对于不需要平滑的指标"lr",打印的是 metric_storage["lr"].latest

那么 LoggerHook又该如何知道哪个指标需要平滑呢?这就需要我们在更新 metric_storage的时候指定一个 smooth参数,如果 smooth为True代表该指标需要平滑,否则不需要。

metric_storage.update(loss=loss_value, smooth=True)
metric_storage.update(lr=lr_value, smooth=False)

我们没有采用简单的字典,而是设计了 MetricStorage类,就是为了可以自动确定哪些指标需要平滑。 MetricStorage的代码如下。

class MetricStorage(dict):
    """该类在训练过程中存储多个度量的值(其中一些可能有噪声,例如损失、批处理时间),并提
    供对平滑值的访问,以便更好地记录。

    该类是为自动记录指标而设计的。用户在调用update()方法时应该指定smooth参数,
    以便可以确定在执行日志记录时应该平滑哪些指标。

    Example::

        >>> metric_storage = MetricStorage()
        >>> metric_storage.update(iter=0, loss=0.2)
        >>> metric_storage.update(iter=0, lr=0.01, smooth=False)
        >>> metric_storage.update(iter=1, loss=0.1)
        >>> metric_storage.update(iter=1, lr=0.001, smooth=False)
        >>> # loss将会被平滑, 但lr不会
        >>> metric_storage.values_maybe_smooth
        {"loss": (1, 0.15), "lr": (1, 0.001)}
        >>> # 类似于字典,可以直接用字符串索引
        >>> metric_storage["loss"].avg
        0.15
    """

    def __init__(self, window_size: int = 20) -> None:
        self._window_size = window_size
        self._history: Dict[str, HistoryBuffer] = self
        self._smooth: Dict[str, bool] = {}
        self._latest_iter: Dict[str, int] = {}

    def update(self, iter: Optional[int] = None, smooth: bool = True, **kwargs) -> None:
        """添加在特定迭代中生成的多个指标的数值。

        Args:
            iter (int): 这些数值是在哪个迭代生成的。如果为None,则使用内置的从0开始的计数器。
            smooth (bool): 如果为True,则在调用values_maybe_smooth()时返回这些指标的平滑值。
                   否则,返回最新的值。在对update()的不同调用中,同一指标必须具有相同的smooth参数值.
        """
        for key, value in kwargs.items():
            if key in self._smooth:
                assert self._smooth[key] == smooth
            else:
                self._smooth[key] = smooth
                self._history[key] = HistoryBuffer(window_size=self._window_size)
                self._latest_iter[key] = -1
            if iter is not None:
                assert iter > self._latest_iter[key]
                self._latest_iter[key] = iter
            else:
                self._latest_iter[key] += 1
            self._history[key].update(value)

    @property
    def values_maybe_smooth(self) -> Dict[str, Tuple[int, float]]:
        """返回平滑值或最新值。具体行为取决于更新指标时的smooth参数。"""
        return {
            key: (self._latest_iter[key], his_buf.avg if self._smooth[key] else his_buf.latest)
            for key, his_buf in self._history.items()
        }

在每次迭代结束时,我们都使用 _log_iter_metrics()函数来记录本次迭代产生的一些信息。

    def log(self, *args, **kwargs) -> None:
        """调用metric_storage字典的update函数"""
        self.metric_storage.update(*args, **kwargs)

    def _log_iter_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float,
                          iter_time: float, lr: float) -> None:
        """
        Args:
            loss_dict (dict): 模型产生的损失,默认是一个字典,{"loss1": 0.5, "loss2": 0.6}.
            data_time (float): 从data_loader中加载batch需要的时间.
            iter_time (float): 完成一次iter需要的时间.
            lr (float): 本次迭代使用的学习率.
        """
        self.log(self.cur_iter, data_time=data_time, iter_time=iter_time)
        self.log(self.cur_iter, lr=lr, smooth=False)

        loss_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()}
        loss_value = sum(loss_dict.values())
        if not np.isfinite(loss_value):
            raise FloatingPointError(
                f"Loss became infinite or NaN at epoch={self.epoch}! loss_dict = {loss_dict}."
            )

        self.log(self.cur_iter, total_loss=loss_value)
        if len(loss_dict) > 1:
            self.log(self.cur_iter, **loss_dict)

总结

感谢诸位读者的时间,你们的阅读让我的创作有了意义。断断续续这么久,终于更新完了。文章只能挑一些重点来讲,不可避免的会遗漏一些细节。感兴趣的朋友可以去github中查看完整的最新代码,欢迎批评指正+点赞fork!

路漫漫其修远,希望CPU这个库能在日后的工作学习中陪着我一起披荆斩棘,炼出有用的丹药。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值