文章目录
相关的包
import mmcv
import numpy as np
import os.path as osp
from mmcv.parallel import DataContainer as DC
from torch.utils.data import Dataset
from .transforms import (GroupImageTransform)
from .utils import to_tensor
RawFramesRecord
这个类提供了一些简单的封装,用来返回关于数据的一些信息(比如帧路径、该视频包含多少帧、帧标签)
class RawFramesRecord(object):
    """Thin read-only wrapper around one parsed annotation row.

    The row is indexed positionally: field 0 is the frame path, field 1 the
    number of frames and field 2 the label. Fields 1 and 2 are converted to
    ``int`` on every access.
    """

    def __init__(self, row):
        # keep the raw row; the properties below index into it lazily
        self._data = row

    @property
    def path(self):
        """Frame path of the video (field 0, returned unchanged)."""
        fields = self._data
        return fields[0]

    @property
    def num_frames(self):
        """Number of frames of the video (field 1 as ``int``)."""
        raw_count = self._data[1]
        return int(raw_count)

    @property
    def label(self):
        """Label of the video (field 2 as ``int``)."""
        raw_label = self._data[2]
        return int(raw_label)
注意from torch.utils.data import Dataset
,RawFramesDataset
是继承自torch.utils.data.Dataset
的,这是因为自定义数据读取相关类的时候需要继承torch.utils.data.Dataset这个基类。
关于torch.utils.data.Dataset
的参考文献pytorch源码分析之torch.utils.data.Dataset类和torch.utils.data.DataLoader类
实现这个抽象类,有两个必要的函数:__len__
和__getitem__
__len__(self)
定义当被len()
函数调用时的行为(返回容器中元素的个数)。
__getitem__(self, key)
定义获取容器中指定元素的行为,相当于self[key]
,即允许类对象可以有索引操作。
class RawFramesDataset(Dataset):
    """Dataset that reads videos stored as extracted raw frames."""

    def __init__(self,
                 ann_file,
                 img_prefix,
                 img_norm_cfg,
                 num_segments=3,
                 new_length=1,
                 new_step=1,
                 random_shift=True,
                 temporal_jitter=False,
                 modality='RGB',
                 image_tmpl='img_{}.jpg',
                 img_scale=256,
                 img_scale_file=None,
                 input_size=224,
                 div_255=False,
                 size_divisor=None,
                 proposal_file=None,
                 num_max_proposals=1000,
                 flip_ratio=0.5,
                 resize_keep_ratio=True,
                 resize_ratio=[1, 0.875, 0.75, 0.66],
                 test_mode=False,
                 oversample=None,
                 random_crop=False,
                 more_fix_crop=False,
                 multiscale_crop=False,
                 resize_crop=False,
                 rescale_crop=False,
                 scales=None,
                 max_distort=1,
                 input_format='NCHW'):
        """Store the sampling and preprocessing configuration.

        :param ann_file: annotation file handed to ``self.load_annotations``
        :param img_prefix: prefix joined in front of every record path
        :param img_norm_cfg: normalization config (a mapping; stored as is)
        :param num_segments: number of segments each video is split into
        :param new_length: number of frames taken per segment
        :param new_step: stride between the frames taken inside a segment
        :param random_shift: whether to randomly shift the sampled position
            when training
        :param temporal_jitter: whether to jitter frames when new_step > 1
        :param modality: one modality name or a list/tuple of them
        :param image_tmpl: one file-name template per modality
        :param img_scale: an int ``s`` is expanded to ``(np.inf, s)``;
            anything else is stored unchanged
        :param img_scale_file: optional text file with one
            ``<key> <int> <int>`` entry per line (single-space separated),
            loaded into ``self.img_scale_dict``
        :param input_size: network input size; an int ``s`` becomes ``(s, s)``
        :param div_255: stored as ``self.div_255``
        """
        # prefix of images path
        self.img_prefix = img_prefix

        # load annotations
        self.video_infos = self.load_annotations(ann_file)
        # normalization config
        self.img_norm_cfg = img_norm_cfg

        # parameters for frame fetching
        # number of segments each video is divided into
        self.num_segments = num_segments
        # number of consecutive frames spanned by one sampled clip
        self.old_length = new_length * new_step
        self.new_length = new_length
        # number of steps (sparse sampling for efficiency of io)
        self.new_step = new_step
        # whether to temporally random shift when training
        self.random_shift = random_shift
        # whether to temporally jitter if new_step > 1
        self.temporal_jitter = temporal_jitter

        # parameters for modalities
        if isinstance(modality, (list, tuple)):
            self.modalities = modality
            num_modality = len(modality)
        else:
            self.modalities = [modality]
            num_modality = 1
        if isinstance(image_tmpl, (list, tuple)):
            self.image_tmpls = image_tmpl
        else:
            self.image_tmpls = [image_tmpl]
        # exactly one file-name template is required per modality
        assert len(self.image_tmpls) == num_modality

        # parameters for image preprocessing
        # img_scale
        if isinstance(img_scale, int):
            # np.inf (the np.Inf alias was removed in NumPy 2.0) leaves the
            # first entry of the scale unbounded
            img_scale = (np.inf, img_scale)
        self.img_scale = img_scale
        if img_scale_file is not None:
            self.img_scale_dict = {}
            # the context manager closes the file even if a line is malformed
            with open(img_scale_file) as scale_file:
                for line in scale_file:
                    if not line.strip():
                        # tolerate blank lines such as a trailing newline
                        continue
                    fields = line.split(' ')
                    self.img_scale_dict[fields[0]] = (int(fields[1]),
                                                      int(fields[2]))
        else:
            self.img_scale_dict = None
        # network input size
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        self.input_size = input_size

        # parameters for specification from pre-trained networks
        # (legacy issue)
        self.div_255 = div_255
关于数据增强
的常见操作:
翻转(flip_ratio)
以及缩放时是否保持宽高比(resize_keep_ratio)。
# parameters for data augmentation
# flip ratio
self.flip_ratio = flip_ratio # probability of randomly flipping the frames left-right
self.resize_keep_ratio = resize_keep_ratio # whether resizing keeps the original aspect ratio
# test mode or not
self.test_mode = test_mode
# set group flag for the sampler
# (the test-mode guard below is commented out, so the flag is always set)
# if not self.test_mode:
self._set_group_flag()
self._set_group_flag()
函数用来设置分组标志:按照docstring,图片宽高比(width / height)>1的设为group 1,否则为group 0;但这里相关判断被注释掉了,所以数据集中所有样本都被设为了1。
def _set_group_flag(self):
"""Set flag according to image aspect ratio.
Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
self.flag = np.zeros(len(self), dtype=np.uint8)
for i in range(len(self)):
# img_info = self.img_infos[i]
# if img_info['width'] / img_info['height'] > 1:
self.flag[i] = 1
关于three_crop
和ten_crop
。
具体实现源码在mmaction/datasets/transforms.py
。
根据
TPN
论文,three_crop
和ten_crop
是用在测试阶段的,是用来作空间完全卷积测试的近似值,并且在non-local
、slowfast
论文中也提到了这种方法。
three_crop
先把原始帧图片的短边调整到256,然后再在调整后的帧图片上沿长边取3个256x256的裁剪(位置固定而非随机,用来覆盖整个空间范围)。
在
TPN
中,作者在Kinetics-400上进行了three_crop
测试。
ten_crop
按照TSN
的方法,截取帧图片的4个corner和1个center,再进行水平翻转,所以总共是10个裁剪。
GroupImageTransform
这个文件来自于mmaction/datasets/transforms.py
,这个文件主要对数据集做一些如数据增强等的处理。关于这个文件的介绍mmaction中的tranforms.py
# transforms
# oversample selects the test-time multi-crop strategy; None disables it
assert oversample in [None, 'three_crop', 'ten_crop']
# keep the choice on the instance: __getitem__ reads self.oversample when it
# fills the `over_sample` entry of the returned dict
self.oversample = oversample
# augmentation / preprocessing pipeline applied to each group of frames
# NOTE(review): size_divisor is passed as None here regardless of the
# constructor's size_divisor argument — confirm this is intended
self.img_group_transform = GroupImageTransform(
    size_divisor=None, crop_size=self.input_size,
    oversample=oversample, random_crop=random_crop,
    more_fix_crop=more_fix_crop,
    multiscale_crop=multiscale_crop, scales=scales,
    max_distort=max_distort,
    resize_crop=resize_crop,
    rescale_crop=rescale_crop,
    **self.img_norm_cfg)
NCTHW
,
N为batch_size,C为通道数,T为帧数,H为高,W为宽。
# input format: only 'NCHW' and 'NCTHW' are accepted
# (__getitem__ reshapes the frames when the format is 'NCTHW')
assert input_format in ['NCHW', 'NCTHW']
self.input_format = input_format
# the bare string below is disabled code kept as a string literal;
# it is evaluated and discarded, so no bbox transform is created
'''
self.bbox_transform = Bbox_transform()
'''
__len__
返回数据集中视频的个数,也就是根据ann_file
解析得到的记录条数(每条记录包含帧路径、帧数和类别号)。
def __len__(self):
return len(self.video_infos)
__getitem__
真正的读取数据操作
# Build one sample: pick the frame indices of the idx-th video and start the
# output dict with the modality count and the ground-truth label.
def __getitem__(self, idx):
record = self.video_infos[idx] # record (raw frames) of the idx-th video
if self.test_mode:
segment_indices, skip_offsets = self._get_test_indices(record)
else:
# not in test mode: random sampling when random_shift is set,
# otherwise the indices come from _get_val_indices
segment_indices, skip_offsets = self._sample_indices(
record) if self.random_shift else self._get_val_indices(record)
# both entries are wrapped in DataContainer (DC)
data = dict(num_modalities=DC(to_tensor(len(self.modalities))),
gt_label=DC(to_tensor(record.label), stack=True,
pad_dims=None))
self.random_shift
默认是True
,所以一般是使用self._sample_indices
函数。来具体看一下这个函数。
def _sample_indices(self, record):
'''
:param record: VideoRawFramesRecord
:return: list, list
'''
# 把整个视频分成num_segement个片段
average_duration = (record.num_frames -
self.old_length + 1) // self.num_segments
# 只要视频总帧数大于num_segement,就成立
if average_duration > 0:
# 按照num_segment把视频分段得到offsets
offsets = np.multiply(list(range(self.num_segments)),
average_duration)
# 每个视频片段随机选一帧
offsets = offsets + np.random.randint(average_duration,
size=self.num_segments)
# 如果视频本身太短
# e.g. 视频长度为6,num_segment=8,那么6/8=0.75,以0.75为间隔,并取整,所以采样到的帧为[0,2,2,3,3,3,4,4](如果序号从0开始)
elif record.num_frames > max(self.num_segments, self.old_length):
offsets = np.sort(np.random.randint(
record.num_frames - self.old_length + 1,
size=self.num_segments))
else:# 否则,没采样到
offsets = np.zeros((self.num_segments,))
if self.temporal_jitter:
skip_offsets = np.random.randint(
self.new_step, size=self.old_length // self.new_step)
else:
skip_offsets = np.zeros(
self.old_length // self.new_step, dtype=int)
return offsets + 1, skip_offsets # frame index starts from 1
测试时,运行self._get_test_indices
。
def _get_test_indices(self, record):
if record.num_frames > self.old_length - 1:
# 把整个视频分成num_segment个片段
tick = (record.num_frames - self.old_length + 1) / \
float(self.num_segments)
# 选出每个片段的中间帧
offsets = np.array([int(tick / 2.0 + tick * x)
for x in range(self.num_segments)])
else:# 否则,没取到
offsets = np.zeros((self.num_segments,))
if self.temporal_jitter:
skip_offsets = np.random.randint(
self.new_step, size=self.old_length // self.new_step)
else:
skip_offsets = np.zeros(
self.old_length // self.new_step, dtype=int)
return offsets + 1, skip_offsets
回归__getitem__
函数,继续数据读取操作。
# handle the first modality: load its frames at the sampled indices
modality = self.modalities[0]
image_tmpl = self.image_tmpls[0]
img_group = self._get_frames(
record, image_tmpl, modality, segment_indices, skip_offsets)
关于self._get_frames
,重点是self._load_image
这个函数,已知rawframes有3种模式,RGB、RGBdiff和flow。RGB
和RGBdiff
都是一张图片,而flow
因为有x方向和y方向,所以有两张图片。
def _get_frames(self, record, image_tmpl, modality, indices, skip_offsets):
images = list()
for seg_ind in indices:# 即将整个视频进行分割以后,从每个片段中选取的帧
p = int(seg_ind)
for i, ind in enumerate(range(0, self.old_length, self.new_step)):
if p + skip_offsets[i] <= record.num_frames:
# self._load_image调用opencv来读取图像数据
seg_imgs = self._load_image(osp.join(
self.img_prefix, record.path),
image_tmpl, modality, p + skip_offsets[i])
else:
seg_imgs = self._load_image(
osp.join(self.img_prefix, record.path),
image_tmpl, modality, p)
images.extend(seg_imgs)
if p + self.new_step < record.num_frames:
p += self.new_step
return images
回归__getitem__
函数,继续数据读取操作。
# decide whether to flip horizontally, with probability self.flip_ratio
flip = True if np.random.rand() < self.flip_ratio else False
# a per-video entry in img_scale_dict (keyed by record.path) overrides the
# dataset-wide img_scale
if (self.img_scale_dict is not None
and record.path in self.img_scale_dict):
img_scale = self.img_scale_dict[record.path]
else:
img_scale = self.img_scale
# run the frames through the augmentation / preprocessing pipeline
(img_group, img_shape, pad_shape,
scale_factor, crop_quadruple) = self.img_group_transform(
img_group, img_scale,
crop_history=None,
flip=flip, keep_ratio=self.resize_keep_ratio,
div_255=self.div_255,
is_flow=True if modality == 'Flow' else False)
# NOTE(review): ori_shape is a hard-coded constant, not read from the loaded
# frames; presumably (height, width, channels) — confirm it fits the data
ori_shape = (256, 340, 3)
# meta information returned alongside the frames
img_meta = dict(
ori_shape=ori_shape,
img_shape=img_shape,
pad_shape=pad_shape,
scale_factor=scale_factor,
crop_quadruple=crop_quadruple,
flip=flip)
关于下面的几个参数。
[M x C x H x W]
M = 1 * N_oversample * N_seg * L
,
N_oversample
代表three_crop
、ten_crop
和center_crop
这些操作得到的裁剪图片张数,应该分别为3、10和1.
N_seg
代表整个视频被分为多少片段。
L
就是new_length
(下面reshape时用的正是self.new_length),即每个片段取的帧数;如果是RGB
的话一般是1,如果是flow
的话一般是5.
# [M x C x H x W]
# M = 1 * N_oversample * N_seg * L
if self.input_format == "NCTHW":
# split the leading M axis into (N_over, N_seg, L)
img_group = img_group.reshape(
(-1, self.num_segments, self.new_length) + img_group.shape[1:])
# N_over x N_seg x L x C x H x W
# swap the L and C axes
img_group = np.transpose(img_group, (0, 1, 3, 2, 4, 5))
# N_over x N_seg x C x L x H x W
# merge the two leading axes (N_over and N_seg) into one
img_group = img_group.reshape((-1,) + img_group.shape[2:])
# M' x C x L x H x W
# (author's open question: is the image data kept on the CPU here? unsure)
# img_meta, img_path and over_sample are wrapped with cpu_only=True
# NOTE(review): this reads self.oversample — verify __init__ assigns it
data.update(dict(
img_group_0=DC(to_tensor(img_group), stack=True, pad_dims=2),
img_meta=DC(img_meta, cpu_only=True),
img_path=DC(record.path, cpu_only=True),
over_sample=DC(self.oversample, cpu_only=True),
))
return data