Action Recognition 0-09: mmaction2 (SlowFast) - Complete Source Code Walkthrough (5) - Data Loading & Preprocessing 2

The links below are my personal notes on mmaction2 (SlowFast action recognition). If you spot any mistakes, please point them out and I will correct them right away. Anyone interested can add me on WeChat (17575010159) to discuss. If this helped you, please remember to like the post, as that is the biggest encouragement for me. A WeChat official account with plenty of resources is attached at the end of the article.

Action Recognition 0-00: mmaction2 (SlowFast) - Table of Contents - The Latest and Most Complete Walkthrough

Highly recommended commercial-grade project: this is a behavior-analysis project I brought to production, consisting of three main modules (1. pedestrian detection, 2. pedestrian tracking, 3. action recognition): Behavior Analysis (Commercial Grade) 00 - Table of Contents - The Most Complete Walkthrough

Preface

From the previous post, we know that data loading mainly involves these classes:
class RawframeDataset(BaseDataset) in mmaction/datasets/rawframe_dataset.py
or
class VideoDataset(BaseDataset) in mmaction/datasets/video_dataset.py
BaseDataset was covered in detail in the previous post; in this one we analyze RawframeDataset and VideoDataset. The annotated code follows.
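Before diving in, it helps to see how the @DATASETS.register_module() decorator and the config's dataset_type string work together. Below is a minimal sketch using a toy registry built with mmcv's Registry and build_from_cfg (the real registry is DATASETS = Registry('dataset') in mmaction/datasets/registry.py); the TOY_DATASETS registry and ToyDataset class here are hypothetical, for illustration only.

from mmcv.utils import Registry, build_from_cfg

# Toy registry; mmaction2 creates its DATASETS registry the same way
TOY_DATASETS = Registry('toy_dataset')

# The decorator stores the class in the registry under its own name
@TOY_DATASETS.register_module()
class ToyDataset:
    def __init__(self, ann_file):
        self.ann_file = ann_file

# A config dict whose 'type' key names a registered class selects it from the
# registry; this is exactly how dataset_type in the cfg picks
# RawframeDataset or VideoDataset
ds = build_from_cfg(dict(type='ToyDataset', ann_file='train.txt'), TOY_DATASETS)
print(type(ds).__name__)  # -> ToyDataset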

RawframeDataset

class RawframeDataset(BaseDataset) in mmaction/datasets/rawframe_dataset.py
A summary analysis is given at the end of the article.

import copy
import os.path as osp

import torch
from mmcv.utils import print_log

from ..core import mean_average_precision, mean_class_accuracy, top_k_accuracy
from .base import BaseDataset
from .registry import DATASETS

# Register this class into the DATASETS registry
@DATASETS.register_module()
class RawframeDataset(BaseDataset):
    """Rawframe dataset for action recognition.

    The dataset loads raw frames and applies specified transforms to return a
    dict containing the frame tensors and other information.

    The ann_file is a text file with multiple lines, and each line indicates
    the directory to frames of a video, total frames of the video and
    the label of a video, which are split with a whitespace.

    Example of an annotation file:

    .. code-block:: txt

        some/directory-1 163 1
        some/directory-2 122 1
        some/directory-3 258 2
        some/directory-4 234 2
        some/directory-5 295 3
        some/directory-6 121 3

    Example of a multi-class annotation file:
    .. code-block:: txt

        some/directory-1 163 1 3 5
        some/directory-2 122 1 2
        some/directory-3 258 2
        some/directory-4 234 2 4 6 8
        some/directory-5 295 3
        some/directory-6 121 3


    Args:
        ann_file (str): Path to the annotation file.
        pipeline (list[dict | callable]): A sequence of data transforms.
        data_prefix (str): Path to a directory where videos are held.
            Default: None.
        test_mode (bool): Store True when building test or validation dataset.
            Default: False.
        filename_tmpl (str): Template for each filename.
            Default: 'img_{:05}.jpg'.
        multi_class (bool): Determines whether it is a multi-class
            recognition dataset. Default: False.
        num_classes (int): Number of classes in the dataset. Default: None.
        modality (str): Modality of data. Support 'RGB', 'Flow'.
            Default: 'RGB'.
    """

    def __init__(self,
                 ann_file,  # path to the annotation file
                 pipeline,  # sequence of data transforms
                 data_prefix=None,  # directory where the videos are held
                 test_mode=False,  # set to True when building the test or validation dataset
                 filename_tmpl='img_{:05}.jpg',  # template for each frame filename
                 multi_class=False,  # whether to train/test with multiple labels per video
                 num_classes=None,  # number of classes in the dataset
                 modality='RGB'):  # data modality, 'RGB' by default
        # Call the parent class's initializer
        super().__init__(ann_file, pipeline, data_prefix, test_mode,
                         multi_class, num_classes, modality)
        # Template for frame filenames
        self.filename_tmpl = filename_tmpl

    def load_annotations(self):
        """Load annotation file to get video information."""
        video_infos = []
        with open(self.ann_file, 'r') as fin:
            # Read the annotation file line by line
            for line in fin:
                # Strip leading/trailing whitespace, then split on whitespace
                line_split = line.strip().split()
                # Multi-label training
                if self.multi_class:
                    # self.num_classes must be set for multi-label annotations
                    assert self.num_classes is not None
                    # Directory holding the frames, total number of frames,
                    # and the label(s) of the video
                    frame_dir, total_frames, label = (line_split[0],
                                                      line_split[1],
                                                      line_split[2:])
                    # Convert each label to int
                    label = list(map(int, label))
                    # Convert the labels to one-hot format
                    onehot = torch.zeros(self.num_classes)
                    onehot[label] = 1.0
                # Single-label training
                else:
                    # Directory holding the frames, total number of frames,
                    # and the label of the video
                    frame_dir, total_frames, label = line_split
                    label = int(label)
                # Join with the prefix directory
                if self.data_prefix is not None:
                    frame_dir = osp.join(self.data_prefix, frame_dir)

                # Append the parsed annotation of this line to video_infos
                video_infos.append(
                    dict(
                        frame_dir=frame_dir,
                        total_frames=int(total_frames),
                        label=onehot if self.multi_class else label))
        return video_infos

    def prepare_train_frames(self, idx):
        """Prepare the frames for training given the index.
        Applies the training transform pipeline to the sample at idx."""
        results = copy.deepcopy(self.video_infos[idx])
        results['filename_tmpl'] = self.filename_tmpl
        results['modality'] = self.modality
        return self.pipeline(results)

    def prepare_test_frames(self, idx):
        """Prepare the frames for testing given the index.
        Applies the testing transform pipeline to the sample at idx."""
        results = copy.deepcopy(self.video_infos[idx])
        results['filename_tmpl'] = self.filename_tmpl
        results['modality'] = self.modality
        return self.pipeline(results)

    # Override the evaluation function
    def evaluate(self,
                 results,  # network inference results
                 metrics='top_k_accuracy',  # how accuracy is measured
                 topk=(1, 5),  # a prediction counts as correct if the ground truth is within the top-k scores
                 logger=None):
        """Evaluation in rawframe dataset.

        Args:
            results (list): Output results.
            metrics (str | sequence[str]): Metrics to be performed.
                Defaults: 'top_k_accuracy'.
            logger (obj): Training logger. Defaults: None.
            topk (int | tuple[int]): K value for top_k_accuracy metric.
                Defaults: (1, 5).
            logger (logging.Logger | None): Logger for recording.
                Default: None.

        Returns:
            dict: Evaluation results dict.
        """
        # Raise an error if results is not a list
        if not isinstance(results, list):
            raise TypeError(f'results must be a list, but got {type(results)}')
        assert len(results) == len(self), (
            f'The length of results is not equal to the dataset len: '
            f'{len(results)} != {len(self)}')

        # Raise an error if topk is neither an int nor a tuple
        if not isinstance(topk, (int, tuple)):
            raise TypeError(
                f'topk must be int or tuple of int, but got {type(topk)}')

        # Wrap a single int topk into a tuple
        if isinstance(topk, int):
            topk = (topk, )

        # Wrap a single metric into a list
        metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]

        allowed_metrics = [
            'top_k_accuracy', 'mean_class_accuracy', 'mean_average_precision'
        ]

        # Raise an error if a requested metric is not in allowed_metrics
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')

        # Store the evaluation results
        eval_results = {}
        # Ground-truth labels gt_labels read from the annotation file
        gt_labels = [ann['label'] for ann in self.video_infos]

        for metric in metrics:
            # Add log information
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            # Metric: top_k_accuracy
            if metric == 'top_k_accuracy':
                top_k_acc = top_k_accuracy(results, gt_labels, topk)
                log_msg = []
                for k, acc in zip(topk, top_k_acc):
                    eval_results[f'top{k}_acc'] = acc
                    log_msg.append(f'\ntop{k}_acc\t{acc:.4f}')
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            # Metric: mean_class_accuracy
            if metric == 'mean_class_accuracy':
                mean_acc = mean_class_accuracy(results, gt_labels)
                eval_results['mean_class_accuracy'] = mean_acc
                log_msg = f'\nmean_acc\t{mean_acc:.4f}'
                print_log(log_msg, logger=logger)
                continue

            # Metric: mean_average_precision
            if metric == 'mean_average_precision':
                gt_labels = [label.cpu().numpy() for label in gt_labels]
                mAP = mean_average_precision(results, gt_labels)
                eval_results['mean_average_precision'] = mAP
                log_msg = f'\nmean_average_precision\t{mAP:.4f}'
                print_log(log_msg, logger=logger)
                continue

        return eval_results
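To make the multi-class branch of load_annotations concrete, here is a small self-contained sketch of how the space-separated labels of one annotation line become a one-hot vector. The annotation line and num_classes=6 are made up for this example:

import torch

# A made-up multi-class annotation line: frame dir, total frames, labels 1 3 5
line = 'some/directory-1 163 1 3 5'
line_split = line.strip().split()

num_classes = 6  # assumed for this example
frame_dir, total_frames, label = line_split[0], line_split[1], line_split[2:]

label = list(map(int, label))   # ['1', '3', '5'] -> [1, 3, 5]
onehot = torch.zeros(num_classes)
onehot[label] = 1.0             # indices 1, 3 and 5 are set to 1

print(onehot)  # tensor([0., 1., 0., 1., 0., 1.])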

VideoDataset

class VideoDataset(BaseDataset) in mmaction/datasets/video_dataset.py

import os.path as osp

import torch
from mmcv.utils import print_log

from ..core import mean_class_accuracy, top_k_accuracy
from .base import BaseDataset
from .registry import DATASETS


@DATASETS.register_module()
class VideoDataset(BaseDataset):
    """Video dataset for action recognition.
    直接加载视频源数据,经过指定转换之后返回一个包含多个frame tensors的字典,其中还包含了一些其他的信息
    The dataset loads raw videos and apply specified transforms to return a
    dict containing the frame tensors and other information.

     注释文件文件存在多行,每行标识了视频存放的路径,以及视频对应的标签。他们都是使用空格隔开的,
    The ann_file is a text file with multiple lines, and each line indicates
    a sample video with the filepath and label, which are split with a
    whitespace. Example of a annotation file:

    .. code-block:: txt

        some/path/000.mp4 1
        some/path/001.mp4 1
        some/path/002.mp4 2
        some/path/003.mp4 2
        some/path/004.mp4 3
        some/path/005.mp4 3
    """

    def load_annotations(self):
        """Load annotation file to get video information."""
        video_infos = []
        with open(self.ann_file, 'r') as fin:
            # Process the annotation file line by line
            for line in fin:
                # Strip leading/trailing whitespace, then split on whitespace
                line_split = line.strip().split()
                # Multi-label training
                if self.multi_class:
                    # self.num_classes must be set for multi-label annotations
                    assert self.num_classes is not None
                    # Video file path and its class label(s)
                    filename, label = line_split[0], line_split[1:]
                    # Convert each label to int, then to one-hot format
                    label = list(map(int, label))
                    onehot = torch.zeros(self.num_classes)
                    onehot[label] = 1.0
                # Single-label training
                else:
                    # Video file path and its class label
                    filename, label = line_split
                    label = int(label)

                # Join the prefix directory with the video filename
                if self.data_prefix is not None:
                    filename = osp.join(self.data_prefix, filename)

                # Append the video information to video_infos
                video_infos.append(
                    dict(
                        filename=filename,
                        label=onehot if self.multi_class else label))
        return video_infos

    def evaluate(self,
                 results,  # network inference results
                 metrics='top_k_accuracy',  # how accuracy is measured
                 topk=(1, 5),  # a prediction counts as correct if the ground truth is within the top-k scores
                 logger=None):
        """Evaluation in rawframe dataset.

        Args:
            results (list): Output results.
            metrics (str | sequence[str]): Metrics to be performed.
                Defaults: 'top_k_accuracy'.
            logger (obj): Training logger. Defaults: None.
            topk (tuple[int]): K value for top_k_accuracy metric.
                Defaults: (1, 5).
            logger (logging.Logger | None): Logger for recording.
                Default: None.

        Return:
            dict: Evaluation results dict.
        """
        # Raise an error if results is not a list
        if not isinstance(results, list):
            raise TypeError(f'results must be a list, but got {type(results)}')
        assert len(results) == len(self), (
            f'The length of results is not equal to the dataset len: '
            f'{len(results)} != {len(self)}')

        # Raise an error if topk is neither an int nor a tuple
        if not isinstance(topk, (int, tuple)):
            raise TypeError(
                f'topk must be int or tuple of int, but got {type(topk)}')

        # Wrap a single metric into a list, then raise an error if a
        # requested metric is not in allowed_metrics
        metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]
        allowed_metrics = ['top_k_accuracy', 'mean_class_accuracy']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')

        # Store the evaluation results
        eval_results = {}
        # Ground-truth labels gt_labels read from the annotation file
        gt_labels = [ann['label'] for ann in self.video_infos]

        for metric in metrics:
            # Add log information
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            # Metric: top_k_accuracy
            if metric == 'top_k_accuracy':
                top_k_acc = top_k_accuracy(results, gt_labels, topk)
                log_msg = []
                for k, acc in zip(topk, top_k_acc):
                    eval_results[f'top{k}_acc'] = acc
                    log_msg.append(f'\ntop{k}_acc\t{acc:.4f}')
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            # Metric: mean_class_accuracy
            if metric == 'mean_class_accuracy':
                mean_acc = mean_class_accuracy(results, gt_labels)
                eval_results['mean_class_accuracy'] = mean_acc
                log_msg = f'\nmean_acc\t{mean_acc:.4f}'
                print_log(log_msg, logger=logger)
                continue

        return eval_results
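Both evaluate implementations are thin wrappers around the metric functions in mmaction.core. Below is a minimal sketch of what top_k_accuracy receives, assuming three samples and three classes; the score vectors are made up for illustration:

import numpy as np
from mmaction.core import top_k_accuracy

# One score vector per sample (what the network outputs), made-up values
results = [np.array([0.1, 0.7, 0.2]),
           np.array([0.5, 0.2, 0.3]),
           np.array([0.2, 0.35, 0.45])]
gt_labels = [1, 0, 1]  # ground-truth class indices

top1, top2 = top_k_accuracy(results, gt_labels, topk=(1, 2))
# The third sample is wrong at top-1 (argmax is class 2) but correct at top-2
print(top1, top2)  # 0.666..., 1.0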

Summary

No matter whether my_slowfast_r50_4x16x1_256e_ucf101_rgb.py is configured as follows (loading extracted raw frames):

dataset_type = 'RawframeDataset'
data_root = 'data/ucf101/rawframes'
data_root_val = 'data/ucf101/rawframes'
ann_file_train = 'data/ucf101/ucf101_train_split_1_rawframes.txt'
ann_file_val = 'data/ucf101/ucf101_val_split_1_rawframes.txt'
ann_file_test = 'data/ucf101/ucf101_val_split_1_rawframes.txt'

or configured like this (loading the raw videos directly):

dataset_type = 'VideoDataset'
data_root = 'data/ucf101/videos'
data_root_val = 'data/ucf101/videos'
ann_file_train = 'data/ucf101/ucf101_train_split_1_videos.txt'
ann_file_val = 'data/ucf101/ucf101_val_split_1_videos.txt'
ann_file_test = 'data/ucf101/ucf101_val_split_1_videos.txt'

train_pipeline = [
    dict(type='DecordInit'),
    # dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1),
    dict(type='SampleFrames', clip_len=16, frame_interval=2, num_clips=1),
    # dict(type='FrameSelector'),
    ...  # remaining transforms omitted
]

val_pipeline = [
    dict(type='DecordInit'),
    # dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1, test_mode=True),
    dict(type='SampleFrames', clip_len=16, frame_interval=2, num_clips=1, test_mode=True),
    # dict(type='FrameSelector'),
    ...  # remaining transforms omitted
]

test_pipeline = [
    dict(type='DecordInit'),
    # dict(type='SampleFrames', clip_len=32, frame_interval=2, num_clips=1, test_mode=True),
    dict(type='SampleFrames', clip_len=16, frame_interval=2, num_clips=1, test_mode=True),
    # dict(type='FrameSelector'),
    dict(type='DecordDecode'),
    ...  # remaining transforms omitted
]
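Given either config, mmaction2 turns these dicts into a dataset object through the registry via build_dataset. A minimal sketch of that step, reusing the variables from the config above (it assumes the UCF101 annotation files and data directories actually exist on disk):

from mmaction.datasets import build_dataset

# dataset_type ('VideoDataset' here) selects the registered class; the rest
# of the dict becomes the keyword arguments of that class's constructor
train_dataset = build_dataset(
    dict(
        type=dataset_type,
        ann_file=ann_file_train,
        data_prefix=data_root,
        pipeline=train_pipeline))
print(len(train_dataset))  # number of lines in the annotation file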

the data it yields is the return value of the following method of the BaseDataset class:

    def __getitem__(self, idx):
        """Get the sample for either training or testing given index.
        Applies the training or testing transform pipeline depending
        on the mode.
        """
        if self.test_mode:
            return self.prepare_test_frames(idx)
        else:
            return self.prepare_train_frames(idx)

So what exactly does it output? Inspecting a sample in a debugger shows a dict whose imgs entry has shape NCTHW = [1, 3, 16, 224, 224] and whose label has shape [1] (a verification sketch follows the list below).

1. The 224x224 is the spatial resolution (H, W) of each frame.
2. The 16 comes from clip_len=16 in the cfg file: 16 frames are sampled per clip (T).
3. That leaves the 3, which is the channel dimension C, i.e. the three RGB channels of each sampled frame.
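If you want to verify these shapes yourself, the sketch below indexes the dataset directly and then wraps it in a plain PyTorch DataLoader. It assumes train_dataset was built as in the earlier sketch and that the pipeline ends with FormatShape('NCTHW'), Collect(keys=['imgs', 'label'], meta_keys=[]) and ToTensor, as in the standard SlowFast configs; otherwise the returned dict may contain non-tensor entries:

import torch

# Index one sample; this goes through BaseDataset.__getitem__ and the pipeline
sample = train_dataset[0]
print(sample['imgs'].shape)   # torch.Size([1, 3, 16, 224, 224]) -> N, C, T, H, W
print(sample['label'])        # class index of this clip

# A plain DataLoader prepends the batch dimension on top of that
loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
batch = next(iter(loader))
print(batch['imgs'].shape)    # torch.Size([2, 1, 3, 16, 224, 224])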

Good. At this point, I believe everyone is quite clear about what data the data iterator yields.
