Action Recognition 0-07: mmaction2 (SlowFast) - Complete Source Code Walkthrough (3) - Training Architecture Overview - 2

The links below collect all of my notes on mmaction2 (SlowFast action recognition). If you spot any mistakes, please point them out and I will correct them right away. Readers who want to discuss the technology are welcome to add me on WeChat: 17575010159. If this post helped you, please remember to give it a like — that is the biggest encouragement for me. A public account with plenty of resources is linked at the end of the post.

Action Recognition 0-00: mmaction2 (SlowFast) - Table of Contents - Complete Walkthrough

Highly recommended production-grade project: this is a behavior-analysis project I shipped, consisting of three modules (1. pedestrian detection, 2. pedestrian tracking, 3. action recognition): Behavior Analysis (Production Grade) 00 - Table of Contents - Complete Walkthrough

Preface

When training a model, we run the following command:

python tools/train.py configs/recognition/slowfast/my_slowfast_r50_4x16x1_256e_ucf101_rgb.py   --work-dir work_dirs/my_slowfast_r50_4x16x1_256e_ucf101_rgb    --validate --seed 0 --deterministic

In the previous post we analyzed tools/train.py and saw that it eventually calls the train_model function in mmaction/apis/train.py under the project root. That function runs the following code:

    # create the runner object that drives training
    runner = EpochBasedRunner(model, optimizer=optimizer, work_dir=cfg.work_dir, logger=logger, meta=meta)
    # record the timestamp
    runner.timestamp = timestamp

    # register hooks: training components such as the LR schedule,
    # optimizer step, checkpointing, and logging
    runner.register_training_hooks(cfg.lr_config, optimizer_config, cfg.checkpoint_config, cfg.log_config, cfg.get('momentum_config', None))

    # evaluate the model on the validation set
    if validate:
        # register the evaluation hook
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # resume from / load a pretrained checkpoint
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)

    # officially start training
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
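Before reading the real source below, it helps to see the scheduling behavior of `runner.run` in isolation. The following is a toy simulation of its workflow loop only (a hypothetical `simulate_run` helper, not the real mmcv API): it walks the `(phase, epochs)` list repeatedly until `max_epochs` training epochs have run, and validation phases do not count toward the epoch budget.

```python
# Toy simulation of EpochBasedRunner.run's scheduling logic (simplified
# sketch, not the real mmcv API). Only training phases advance the epoch
# counter, mirroring the `self._epoch += 1` at the end of train().
def simulate_run(workflow, max_epochs):
    epoch = 0
    schedule = []  # record of which phase ran, in order
    while epoch < max_epochs:
        for mode, epochs in workflow:
            for _ in range(epochs):
                # same guard as the real run(): stop extra train epochs
                if mode == 'train' and epoch >= max_epochs:
                    break
                schedule.append(mode)
                if mode == 'train':
                    epoch += 1
    return schedule

print(simulate_run([('train', 2), ('val', 1)], max_epochs=4))
# → ['train', 'train', 'val', 'train', 'train', 'val']
```

So with workflow `[('train', 2), ('val', 1)]`, validation runs once after every two training epochs, exactly as the docstring of `run` describes.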

From the summary above, you can see that everything revolves around the runner object created from the EpochBasedRunner class. My annotated version of EpochBasedRunner follows.

EpochBasedRunner

# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import warnings

import torch

import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .utils import get_host_info


class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """

    def train(self, data_loader, **kwargs):
        """
        :param data_loader: training data loader
        :param kwargs: extra keyword arguments for the model
        :return:
        """
        # put the model into training mode
        self.model.train()
        self.mode = 'train'
        # store the training data loader
        self.data_loader = data_loader
        # total number of iterations across all epochs
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        # iterate over the training data
        for i, data_batch in enumerate(data_loader):
            # record the iteration index within this epoch
            self._inner_iter = i
            self.call_hook('before_train_iter')

            # without a batch_processor, train directly on the current batch
            if self.batch_processor is None:
                outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
            # otherwise let the batch_processor handle the batch
            else:
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=True, **kwargs)

            # raise an error if the output is not a dict
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.train_step()"'
                                ' must return a dict')
            # update the log buffer
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            # post-iteration processing (e.g. back-propagation) runs in hooks
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1
        # post-processing after one full epoch
        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        # put the model into evaluation mode
        self.model.eval()
        self.mode = 'val'
        # store the validation data loader
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        # iterate over the validation data
        for i, data_batch in enumerate(data_loader):
            # record the iteration index within this epoch
            self._inner_iter = i
            self.call_hook('before_val_iter')
            # disable gradient computation during validation
            with torch.no_grad():
                # without a batch_processor, call val_step directly;
                # otherwise let the batch_processor handle the batch
                if self.batch_processor is None:
                    outputs = self.model.val_step(data_batch, self.optimizer,
                                                  **kwargs)
                else:
                    outputs = self.batch_processor(
                        self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.val_step()"'
                                ' must return a dict')
            # update the log buffer
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])

            self.outputs = outputs
            # post-iteration processing (no back-propagation here)
            self.call_hook('after_val_iter')
        # post-processing after one full epoch
        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation. A list that may contain both the training
                data loader and the validation data loader.

            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.

            max_epochs (int): Total training epochs.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        # validate the inputs and compute the total iteration count
        self._max_epochs = max_epochs
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        # print logger information
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        # keep training while self.epoch is below max_epochs
        while self.epoch < max_epochs:
            # for each phase, look up whether it is training or validation,
            # and how many epochs to run it for
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                # mode must be a string naming a runner method
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                # raise if mode is not a string
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))
                # run this phase for `epochs` epochs
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    # save the model checkpoint
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        elif isinstance(meta, dict):
            meta.update(epoch=self.epoch + 1, iter=self.iter)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))



class Runner(EpochBasedRunner):
    """Deprecated name of EpochBasedRunner."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            'Runner was deprecated, please use EpochBasedRunner instead')
        super().__init__(*args, **kwargs)
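To illustrate the dict contract that train() and val() enforce via the TypeError branch above, here is a pure-Python toy example (ToyModel and check_outputs are hypothetical names, not part of mmaction2): train_step must return a dict, and when it contains 'log_vars' the runner feeds those scalars to the log buffer, weighted by 'num_samples'.

```python
# Hypothetical stand-ins (not mmaction2's real classes) showing the dict
# contract that EpochBasedRunner.train() expects from train_step's output.
class ToyModel:
    def train_step(self, data_batch, optimizer, **kwargs):
        # pretend loss: mean of the batch values
        loss = sum(data_batch) / len(data_batch)
        return dict(loss=loss,
                    log_vars=dict(loss=loss),  # scalars for the log buffer
                    num_samples=len(data_batch))  # weight for averaging logs

def check_outputs(outputs):
    # mirrors the TypeError branch in train() / val()
    if not isinstance(outputs, dict):
        raise TypeError('"batch_processor()" or "model.train_step()" '
                        'must return a dict')
    return outputs

outputs = check_outputs(ToyModel().train_step([1.0, 2.0, 3.0], None))
print(outputs['num_samples'])  # → 3
```

Note that train_step itself returns the loss rather than calling backward(); in mmcv the actual optimizer step happens in the 'after_train_iter' hook (OptimizerHook), which is why the runner body above only stores self.outputs and calls the hook.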

Conclusion

By this point you should have a solid grasp of the overall structure of model training, as well as its details. Next we will start analyzing the data-preprocessing pipeline, since we need to know where the data comes from: loading paths, resizing, normalization, data augmentation, and so on.

