以下链接是个人关于mmaction2(SlowFast-动作识别)的所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。文末附带公众号,内含海量资源。
动作识别0-00:mmaction2(SlowFast)-目录-史上最新无死角讲解
极度推荐的商业级项目:这是本人落地的行为分析项目,主要包含三大模块(1.行人检测,2.行人追踪,3.行为识别):行为分析(商用级别)00-目录-史上最新无死角讲解
前言
我们在训练模型的时候,是执行如下指令:
python tools/train.py configs/recognition/slowfast/my_slowfast_r50_4x16x1_256e_ucf101_rgb.py --work-dir work_dirs/my_slowfast_r50_4x16x1_256e_ucf101_rgb --validate --seed 0 --deterministic
前面的博客我们对 tools/train.py 进行了分析,可以知道其会调用到项目根目录下的 mmaction/apis/train.py 中的 def train_model 函数,该函数会调用以下函数:
# Create the runner object that drives the whole training process.
runner = EpochBasedRunner(
    model,
    optimizer=optimizer,
    work_dir=cfg.work_dir,
    logger=logger,
    meta=meta)
# Record the launch timestamp (used e.g. to name log files).
runner.timestamp = timestamp
# Register training hooks: lr schedule, optimizer stepping, checkpointing,
# logging, and (optionally) a momentum schedule.
runner.register_training_hooks(cfg.lr_config, optimizer_config,
                               cfg.checkpoint_config, cfg.log_config,
                               cfg.get('momentum_config', None))
# When validation is requested, register an evaluation hook
# (distributed or single-process variant).
if validate:
    eval_hook = DistEvalHook if distributed else EvalHook
    runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
# Resume full training state (epoch/iter/optimizer) from a checkpoint,
# or only load model weights.
if cfg.resume_from:
    runner.resume(cfg.resume_from)
elif cfg.load_from:
    runner.load_checkpoint(cfg.load_from)
# Launch training according to the configured workflow.
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
从上面的总结,可以看到其核心都在于EpochBasedRunner这个类创建的对象runner。本人对EpochBasedRunner的注释如下。
EpochBasedRunner
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import time
import warnings
import torch
import mmcv
from .base_runner import BaseRunner
from .checkpoint import save_checkpoint
from .utils import get_host_info
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner trains models epoch by epoch: ``run`` walks the
    workflow, dispatching to ``train``/``val`` one epoch at a time and
    firing the registered hooks before/after each run, epoch and
    iteration.
    """

    def train(self, data_loader, **kwargs):
        """Run one training epoch over ``data_loader``.

        Args:
            data_loader: Iterable yielding training batches.
            **kwargs: Extra keyword arguments forwarded to
                ``model.train_step`` or ``batch_processor``.
        """
        # Switch the model to training mode (enables dropout, BN updates).
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        # Total iteration budget across the whole training run.
        self._max_iters = self._max_epochs * len(data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(data_loader):
            # Iteration index within the current epoch.
            self._inner_iter = i
            self.call_hook('before_train_iter')
            if self.batch_processor is None:
                # The model computes the forward pass and loss itself.
                outputs = self.model.train_step(data_batch, self.optimizer,
                                                **kwargs)
            else:
                # Legacy path: an external batch processor handles the batch.
                outputs = self.batch_processor(
                    self.model, data_batch, train_mode=True, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.train_step()"'
                                ' must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            # Hooks (e.g. the optimizer hook) read ``self.outputs`` to
            # perform the backward pass and parameter update.
            self.outputs = outputs
            self.call_hook('after_train_iter')
            self._iter += 1
        self.call_hook('after_train_epoch')
        self._epoch += 1

    def val(self, data_loader, **kwargs):
        """Run one validation epoch over ``data_loader`` without gradients.

        Args:
            data_loader: Iterable yielding validation batches.
            **kwargs: Extra keyword arguments forwarded to
                ``model.val_step`` or ``batch_processor``.
        """
        # Switch the model to evaluation mode.
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            # Disable gradient tracking for validation.
            with torch.no_grad():
                if self.batch_processor is None:
                    outputs = self.model.val_step(data_batch, self.optimizer,
                                                  **kwargs)
                else:
                    outputs = self.batch_processor(
                        self.model, data_batch, train_mode=False, **kwargs)
            if not isinstance(outputs, dict):
                raise TypeError('"batch_processor()" or "model.val_step()"'
                                ' must return a dict')
            if 'log_vars' in outputs:
                self.log_buffer.update(outputs['log_vars'],
                                       outputs['num_samples'])
            self.outputs = outputs
            self.call_hook('after_val_iter')
        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)]
                means running 2 epochs for training and 1 epoch for
                validation, iteratively.
            max_epochs (int): Total training epochs.
        """
        # Validate the inputs: one dataloader per workflow phase.
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        # Derive the global iteration budget from the training phase.
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        # Log host/work_dir/workflow information before starting.
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        # Keep cycling through the workflow until max_epochs is reached.
        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                # ``mode`` names a runner method, e.g. 'train' -> self.train.
                if isinstance(mode, str):
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))
                # Run this phase for the requested number of epochs.
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        elif isinstance(meta, dict):
            meta.update(epoch=self.epoch + 1, iter=self.iter)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)
        # Checkpoint filenames use 1-indexed epoch numbers.
        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            mmcv.symlink(filename, osp.join(out_dir, 'latest.pth'))
class Runner(EpochBasedRunner):
    """Deprecated name of EpochBasedRunner.

    Kept for backward compatibility; emits a deprecation warning and
    otherwise behaves exactly like :class:`EpochBasedRunner`.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            'Runner was deprecated, please use EpochBasedRunner instead')
        super().__init__(*args, **kwargs)
结语
相信到了这里,各位小哥哥,小姐姐对于模型训练整体结构,以及细节之处,应该都十分的了解,下面我们就对数据的预处理部分开始分析吧。因为我们总得知道数据的来源,如加载路径,缩放,归一化,数据增强等等操作。