mmaction中的rawframes_dataset.py

文章目录


相关的包

import mmcv
import numpy as np
import os.path as osp
from mmcv.parallel import DataContainer as DC
from torch.utils.data import Dataset

from .transforms import (GroupImageTransform)
from .utils import to_tensor

RawFramesRecord这个类提供了一些简单的封装,用来返回关于数据的一些信息(比如帧路径、该视频包含多少帧、帧标签)

class RawFramesRecord(object):
    """Thin read-only view over one parsed annotation row.

    The row is accessed positionally: item 0 is the path of the video's
    frames, item 1 the number of frames and item 2 the class label.
    """

    def __init__(self, row):
        # Keep the raw row; every property below reads from it on demand.
        self._data = row

    @property
    def path(self):
        """Path of the video's frames, exactly as stored in the row."""
        frame_path = self._data[0]
        return frame_path

    @property
    def num_frames(self):
        """Number of frames of the video, converted to ``int``."""
        raw_count = self._data[1]
        return int(raw_count)

    @property
    def label(self):
        """Class label of the video, converted to ``int``."""
        raw_label = self._data[2]
        return int(raw_label)

注意from torch.utils.data import Dataset:RawFramesDataset是继承自torch.utils.data.Dataset的,这是因为自定义数据读取相关类的时候需要继承torch.utils.data.Dataset这个基类。
关于torch.utils.data.Dataset的参考文献:《pytorch源码分析之torch.utils.data.Dataset类和torch.utils.data.DataLoader类》。
实现这个抽象类,有两个必要的函数:__len__和__getitem__。

  • __len__(self)定义当被len()函数调用时的行为(返回容器中元素的个数)
  • __getitem__(self)定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。
class RawFramesDataset(Dataset):
    """Video dataset whose samples are read from extracted raw frames.

    Subclasses ``torch.utils.data.Dataset``; annotations are loaded once in
    the constructor and kept in ``self.video_infos``.
    """

    def __init__(self,
                 ann_file,
                 img_prefix,
                 img_norm_cfg,
                 num_segments=3,
                 new_length=1,
                 new_step=1,
                 random_shift=True,
                 temporal_jitter=False,
                 modality='RGB',
                 image_tmpl='img_{}.jpg',
                 img_scale=256,
                 img_scale_file=None,
                 input_size=224,
                 div_255=False,
                 size_divisor=None,
                 proposal_file=None,
                 num_max_proposals=1000,
                 flip_ratio=0.5,
                 resize_keep_ratio=True,
                 resize_ratio=[1, 0.875, 0.75, 0.66],
                 test_mode=False,
                 oversample=None,
                 random_crop=False,
                 more_fix_crop=False,
                 multiscale_crop=False,
                 resize_crop=False,
                 rescale_crop=False,
                 scales=None,
                 max_distort=1,
                 input_format='NCHW'):
        """Store the sampling / preprocessing configuration of the dataset.

        Args:
            ann_file: Annotation file, handed to ``self.load_annotations``.
            img_prefix: Prefix joined in front of every record's path.
            img_norm_cfg: Normalization config (a mapping; it is later
                unpacked as keyword arguments of the group transform).
            num_segments: Number of segments each video is divided into.
            new_length: Number of frames taken per segment.
            new_step: Step between the frames taken within a segment.
            random_shift: Whether to temporally random shift when training.
            temporal_jitter: Whether to temporally jitter if ``new_step > 1``.
            modality: One modality name or a list/tuple of them.
            image_tmpl: One filename template or a list/tuple of them; there
                must be exactly one template per modality.
            img_scale: Target scale; an ``int`` is expanded to
                ``(np.Inf, img_scale)``.
            img_scale_file: Optional text file with space-separated lines
                ``<key> <int> <int>`` giving a per-key scale override.
            input_size: Network input size; an ``int`` is expanded to a
                square ``(input_size, input_size)``.
            div_255: Stored as-is and forwarded to the group transform.

        The remaining parameters are not used in this part of the
        constructor.
        """
        # prefix of images path
        self.img_prefix = img_prefix

        # load annotations
        self.video_infos = self.load_annotations(ann_file)
        # normalization config
        self.img_norm_cfg = img_norm_cfg

        # parameters for frame fetching
        # number of segments
        # each video is divided into num_segments segments
        self.num_segments = num_segments
        # number of consecutive frames
        self.old_length = new_length * new_step
        self.new_length = new_length
        # number of steps (sparse sampling for efficiency of io)
        self.new_step = new_step
        # whether to temporally random shift when training
        self.random_shift = random_shift
        # whether to temporally jitter if new_step > 1
        self.temporal_jitter = temporal_jitter

        # parameters for modalities
        # a single modality / template is wrapped into a one-element list so
        # that the rest of the class can always iterate over lists
        if isinstance(modality, (list, tuple)):
            self.modalities = modality
            num_modality = len(modality)
        else:
            self.modalities = [modality]
            num_modality = 1
        if isinstance(image_tmpl, (list, tuple)):
            self.image_tmpls = image_tmpl
        else:
            self.image_tmpls = [image_tmpl]
        assert len(self.image_tmpls) == num_modality

        # parameters for image preprocessing
        # img_scale
        # NOTE(review): the `np.Inf` alias was removed in NumPy 2.0; `np.inf`
        # is the portable spelling.
        if isinstance(img_scale, int):
            img_scale = (np.Inf, img_scale) # np.Inf is positive infinity
        self.img_scale = img_scale
        # optional per-key scale table: first token of each line is the key,
        # the next two tokens are parsed as ints
        # NOTE(review): the file opened below is never closed explicitly.
        if img_scale_file is not None:
            self.img_scale_dict = {line.split(' ')[0]:
                                   (int(line.split(' ')[1]),
                                    int(line.split(' ')[2]))
                                   for line in open(img_scale_file)}
        else:
            self.img_scale_dict = None
        # network input size
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        self.input_size = input_size

        # parameters for specification from pre-trained networks (legacy issue)
        self.div_255 = div_255

关于数据增强的常见操作,
翻转
resize_keep_ratio

        # parameters for data augmentation
        # flip ratio
        self.flip_ratio = flip_ratio # probability of a random horizontal flip
        self.resize_keep_ratio = resize_keep_ratio # whether resizing keeps the original aspect ratio

        # test mode or not
        self.test_mode = test_mode

        # set group flag for the sampler
        # (done unconditionally; the test-mode guard below is disabled)
        # if not self.test_mode:
        self._set_group_flag()

self._set_group_flag()函数用于设置采样分组标志(flag):按照原本的设计,宽高比>1的图片设为group 1,否则设为group 0;但这里相关判断被注释掉了,所以数据集中所有样本的flag都被设为了1。

 def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i in range(len(self)):
            # img_info = self.img_infos[i]
            # if img_info['width'] / img_info['height'] > 1:
            self.flag[i] = 1

关于three_crop和ten_crop
具体实现源码在mmaction/datasets/transforms.py

根据TPN论文,three_crop和ten_crop是用在测试阶段的,是用来作空间完全卷积测试的近似值,并且在non-local和slowfast论文中也提到了这种方法。

three_crop先把原始帧图片的短边调整到256,然后再在调整后的帧图片上随机裁剪出3个256x256的部分。

TPN中,作者在Kinetics-400上进行了three_crop测试。

ten_crop按照TSN的方法,截取帧图片的4个corner和1个center,再进行水平翻转,所以总共是10个裁剪。
GroupImageTransform这个类来自于mmaction/datasets/transforms.py,这个文件主要对数据集做一些如数据增强等的处理。关于这个文件的介绍见《mmaction中的transforms.py》。

        # transforms
        # only no oversampling, 'three_crop' or 'ten_crop' are accepted
        assert oversample in [None, 'three_crop', 'ten_crop']
        # build the group transform that applies the data augmentation
        # NOTE(review): size_divisor is passed as a literal None here, so the
        # constructor's own `size_divisor` argument is not forwarded.
        self.img_group_transform = GroupImageTransform(
            size_divisor=None, crop_size=self.input_size,
            oversample=oversample, random_crop=random_crop,
            more_fix_crop=more_fix_crop,
            multiscale_crop=multiscale_crop, scales=scales,
            max_distort=max_distort,
            resize_crop=resize_crop,
            rescale_crop=rescale_crop,
            **self.img_norm_cfg)

NCTHW
N为batch_size,C为通道数,T为帧数,H为高,W为宽。

        # input format
        # 'NCTHW': N = batch size, C = channels, T = frames, H = height,
        # W = width; 'NCHW' is the same layout without the frame axis
        assert input_format in ['NCHW', 'NCTHW']
        self.input_format = input_format
        # the bare string literal below holds disabled code and has no effect
        '''
        self.bbox_transform = Bbox_transform()
        '''

__len__
返回数据集中视频的个数,即self.video_infos的长度(self.video_infos由load_annotations根据ann_file读取得到,每条记录包含帧路径、帧数和类别号)。

    def __len__(self):
        return len(self.video_infos)

__getitem__
真正的读取数据操作

    def __getitem__(self, idx):
        """Build the sample dict for the video at position ``idx``.

        Segment indices come from ``_get_test_indices`` in test mode;
        otherwise from ``_sample_indices`` when ``random_shift`` is set and
        from ``_get_val_indices`` when it is not.
        """
        record = self.video_infos[idx] # record of one video's raw frames
        if self.test_mode:
            segment_indices, skip_offsets = self._get_test_indices(record)
        else:
            segment_indices, skip_offsets = self._sample_indices(
                record) if self.random_shift else self._get_val_indices(record)

        # start the output dict with the modality count and the label
        data = dict(num_modalities=DC(to_tensor(len(self.modalities))),
                    gt_label=DC(to_tensor(record.label), stack=True,
                                pad_dims=None))

self.random_shift默认是True,所以一般是使用self._sample_indices函数。来具体看一下这个函数。

    def _sample_indices(self, record):
        '''

        :param record: VideoRawFramesRecord
        :return: list, list
        '''
        # 把整个视频分成num_segement个片段
        average_duration = (record.num_frames -
                            self.old_length + 1) // self.num_segments
        # 只要视频总帧数大于num_segement,就成立
        if average_duration > 0:
        	# 按照num_segment把视频分段得到offsets
            offsets = np.multiply(list(range(self.num_segments)),
                                  average_duration)
            # 每个视频片段随机选一帧
            offsets = offsets + np.random.randint(average_duration,
                                                  size=self.num_segments)
        # 如果视频本身太短
        # e.g. 视频长度为6,num_segment=8,那么6/8=0.75,以0.75为间隔,并取整,所以采样到的帧为[0,2,2,3,3,3,4,4](如果序号从0开始)
        elif record.num_frames > max(self.num_segments, self.old_length):
            offsets = np.sort(np.random.randint(
                record.num_frames - self.old_length + 1,
                size=self.num_segments))
        else:# 否则,没采样到
            offsets = np.zeros((self.num_segments,))
        if self.temporal_jitter:
            skip_offsets = np.random.randint(
                self.new_step, size=self.old_length // self.new_step)
        else:
            skip_offsets = np.zeros(
                self.old_length // self.new_step, dtype=int)
        return offsets + 1, skip_offsets  # frame index starts from 1

测试时,运行self._get_test_indices

    def _get_test_indices(self, record):
        if record.num_frames > self.old_length - 1:
        # 把整个视频分成num_segment个片段
            tick = (record.num_frames - self.old_length + 1) / \
                   float(self.num_segments)
            # 选出每个片段的中间帧
            offsets = np.array([int(tick / 2.0 + tick * x)
                                for x in range(self.num_segments)])
        else:# 否则,没取到
            offsets = np.zeros((self.num_segments,))
        if self.temporal_jitter:
            skip_offsets = np.random.randint(
                self.new_step, size=self.old_length // self.new_step)
        else:
            skip_offsets = np.zeros(
                self.old_length // self.new_step, dtype=int)
        return offsets + 1, skip_offsets

回归__getitem__函数,继续数据读取操作。

        # handle the first modality
        # (only index 0 of the modality / template lists is read here)
        modality = self.modalities[0]
        image_tmpl = self.image_tmpls[0]
        # load the frames selected by the sampled segment indices
        img_group = self._get_frames(
            record, image_tmpl, modality, segment_indices, skip_offsets)

关于self._get_frames,重点是self._load_image这个函数,已知rawframes有3种模式:RGB、RGBdiff和flow。RGB和RGBdiff都是一张图片,而flow因为有x方向和y方向,所以有两张图片。

    def _get_frames(self, record, image_tmpl, modality, indices, skip_offsets):
        images = list()
        for seg_ind in indices:# 即将整个视频进行分割以后,从每个片段中选取的帧
            p = int(seg_ind)
            for i, ind in enumerate(range(0, self.old_length, self.new_step)):
                if p + skip_offsets[i] <= record.num_frames:
                # self._load_image调用opencv来读取图像数据
                    seg_imgs = self._load_image(osp.join(
                        self.img_prefix, record.path),
                        image_tmpl, modality, p + skip_offsets[i])
                else:
                    seg_imgs = self._load_image(
                        osp.join(self.img_prefix, record.path),
                        image_tmpl, modality, p)
                images.extend(seg_imgs)
                if p + self.new_step < record.num_frames:
                    p += self.new_step
        return images

回归__getitem__函数,继续数据读取操作。

		# decide whether to flip horizontally, with probability flip_ratio
        flip = True if np.random.rand() < self.flip_ratio else False
        # a per-video scale from the scale table overrides the default scale
        if (self.img_scale_dict is not None
                and record.path in self.img_scale_dict):
            img_scale = self.img_scale_dict[record.path]
        else:
            img_scale = self.img_scale
         # run the group transform to get the augmented frames
        (img_group, img_shape, pad_shape,
         scale_factor, crop_quadruple) = self.img_group_transform(
            img_group, img_scale,
            crop_history=None,
            flip=flip, keep_ratio=self.resize_keep_ratio,
            div_255=self.div_255,
            is_flow=True if modality == 'Flow' else False)
        # NOTE(review): ori_shape is a fixed constant here, not read from the
        # loaded frames.
        ori_shape = (256, 340, 3)
        img_meta = dict(
            ori_shape=ori_shape,
            img_shape=img_shape,
            pad_shape=pad_shape,
            scale_factor=scale_factor,
            crop_quadruple=crop_quadruple,
            flip=flip)

关于下面的几个参数。
[M x C x H x W]
M = 1 * N_oversample * N_seg * L
N_oversample代表three_crop、ten_crop、center_crop这些操作得到的裁剪图片张数,应该分别为3、10和1。
N_seg代表整个视频被分为多少片段。
L应该是指new_length?如果是RGB的话就是1,如果是flow的话就是5.

        # [M x C x H x W]
        # M = 1 * N_oversample * N_seg * L
        if self.input_format == "NCTHW":
            # split the flat M axis into (oversample, segment, frame) axes
            img_group = img_group.reshape(
                (-1, self.num_segments, self.new_length) + img_group.shape[1:])
            # N_over x N_seg x L x C x H x W
            # swap the frame axis and the channel axis
            img_group = np.transpose(img_group, (0, 1, 3, 2, 4, 5))
            # N_over x N_seg x C x L x H x W
            # merge the oversample and segment axes into one leading axis
            img_group = img_group.reshape((-1,) + img_group.shape[2:])
            # M' x C x L x H x W

        # open question from the original author: are the dataset images all
        # kept on the CPU here? (TODO confirm the DataContainer semantics)
        # NOTE(review): `self.oversample` is read below, but no assignment to
        # it is visible in the constructor code - confirm it is set.
        data.update(dict(
            img_group_0=DC(to_tensor(img_group), stack=True, pad_dims=2),
            img_meta=DC(img_meta, cpu_only=True),
            img_path=DC(record.path, cpu_only=True),
            over_sample=DC(self.oversample, cpu_only=True),
        ))

        return data
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值