1 源码准备
git clone https://github.com/mit-han-lab/temporal-shift-module
2 源码结构
文件名称 | 功能作用 |
---|---|
main.py | 训练脚本 |
test_model.py | 测试脚本 |
dataset.py | 数据的读取和处理 |
models.py | 网络模型的构建 |
mypath.py | 数据模型路径自定义 |
transforms.py | 数据的预处理 |
temporal_shift.py | 时序移位模块 |
dataset_config.py | 数据集配置文件 |
opts.py | 参数配置 |
3 源码分析
3.1 数据准备
dataset.py主要是对数据集进行读取,并且对其采用三种采用方式:稀疏采样(随机和固定)、密集采样和二次采样,返回采样得到的数据集。
它首先定义了一个类TSNDataSet,用来处理最原始的数据。该类返回的是torch.utils.data.Dataset类型,(注:一般而言在pytorch中自定义的数据读取类都要继承torch.utils.DataSet这个基类),然后通过重写_init_和_getitem_方法来读取函数。
(1)init函数
def __init__(self, root_path, list_file,
num_segments=3, new_length=1, modality='RGB',
image_tmpl='img_{:05d}.jpg', transform=None,
random_shift=True, test_mode=False,
remove_missing=False, dense_sample=False, twice_sample=False)
TSNDataSet类的初始化方法_init_需要如下参数:
- root_path : 项目的根目录地址,如果其他文件地址使用绝对地址,则可以写成" "
- list_file : 训练或测试的列表文件(.txt文件)地址
- num_segments : 视频分割的段数
- new_length : 根据输入数据集类型的不同,new_length取不同的值。输入为RGB帧时等于1,光流时等于5。
- modality : 输入数据集类型(RGB、光流、RGB差异)
- image_tmpl:加载数据集时格式
- transform:对数据进行预处理,这里默认为None
- random_shift:布尔型,当设置为True时对训练集进行采样,设置为False时对验证集进行采样
- test_mode:布尔型,默认为False,当设置为True时即即对测试集进行采样
- remove_missing:布尔型,默认为False。与test_mode在同一个判断条件下,对数据进行读取。
- dense_sample:布尔型,设置为True时进行密集采样
- twice_sample:布尔型,设置为True时进行二次采样
(2)_parse_list函数
_parse_list函数功能在于读取list文件,储存在video_list中
def _parse_list(self):
# check the frame number is large >3:
tmp = [x.strip().split(' ') for x in open(self.list_file)]
if not self.test_mode or self.remove_missing:
tmp = [item for item in tmp if int(item[1]) >= 3]
self.video_list = [VideoRecord(item) for item in tmp]
self.video_list是一个长度为训练数据数量的列表。每个值都是VIDEORecord对象,包含一个列表和3个属性,列表长度为3,用空格键分割,分别为帧路径、该视频含有多少帧和帧标签。其中有一个判断机制,保证帧数需大于3。
(3)_sample_indices函数
_sample_indices函数功能在于实现TSN的密集采样或者稀疏采样,返回的是采样的帧数列表
def _sample_indices(self, record):
if self.dense_sample: # i3d dense sample # 密集随机采样
sample_pos = max(1, 1 + record.num_frames - 64)
t_stride = 64 // self.num_segments
start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
return np.array(offsets) + 1
else: # normal sample #稀疏随机采样
average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
if average_duration > 0:
offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration,
size=self.num_segments)
elif record.num_frames > self.num_segments:
offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
else:
offsets = np.zeros