TSM源码解析

在这里插入图片描述
论文链接
代码链接

1 源码准备

git clone https://github.com/mit-han-lab/temporal-shift-module

2 源码结构

文件名称 功能作用
main.py 训练脚本
test_model.py 测试脚本
dataset.py 数据的读取和处理
models.py 网络模型的构建
mypath.py 数据模型路径自定义
transforms.py 数据的预处理
temporal_shift.py 时序移位模块
dataset_config.py 数据集配置文件
opts.py 参数配置

3 源码分析

3.1 数据准备

dataset.py主要是对数据集进行读取,并且对其采用三种采用方式:稀疏采样(随机和固定)、密集采样和二次采样,返回采样得到的数据集。

它首先定义了一个类TSNDataSet,用来处理最原始的数据。该类返回的是torch.utils.data.Dataset类型,(:一般而言在pytorch中自定义的数据读取类都要继承torch.utils.DataSet这个基类),然后通过重写_init_和_getitem_方法来读取函数。

(1)init函数

def __init__(self, root_path, list_file,
             num_segments=3, new_length=1, modality='RGB',
             image_tmpl='img_{:05d}.jpg', transform=None,
             random_shift=True, test_mode=False,
             remove_missing=False, dense_sample=False, twice_sample=False)

TSNDataSet类的初始化方法_init_需要如下参数:

  • root_path : 项目的根目录地址,如果其他文件地址使用绝对地址,则可以写成" "
  • list_file : 训练或测试的列表文件(.txt文件)地址
  • num_segments : 视频分割的段数
  • new_length : 根据输入数据集类型的不同,new_length取不同的值。输入为RGB帧时等于1,光流时等于5。
  • modality : 输入数据集类型(RGB、光流、RGB差异)
  • image_tmpl:加载数据集时格式
  • transform:对数据进行预处理,这里默认为None
  • random_shift:布尔型,当设置为True时对训练集进行采样,设置为False时对验证集进行采样
  • test_mode:布尔型,默认为False,当设置为True时即即对测试集进行采样
  • remove_missing:布尔型,默认为False。与test_mode在同一个判断条件下,对数据进行读取。
  • dense_sample:布尔型,设置为True时进行密集采样
  • twice_sample:布尔型,设置为True时进行二次采样

(2)_parse_list函数

_parse_list函数功能在于读取list文件,储存在video_list中

def _parse_list(self):
    # check the frame number is large >3:
    tmp = [x.strip().split(' ') for x in open(self.list_file)]
    if not self.test_mode or self.remove_missing:
        tmp = [item for item in tmp if int(item[1]) >= 3]
    self.video_list = [VideoRecord(item) for item in tmp]

self.video_list是一个长度为训练数据数量的列表。每个值都是VIDEORecord对象,包含一个列表和3个属性,列表长度为3,用空格键分割,分别为帧路径、该视频含有多少帧和帧标签。其中有一个判断机制,保证帧数需大于3。

(3)_sample_indices函数

​ _sample_indices函数功能在于实现TSN的密集采样或者稀疏采样,返回的是采样的帧数列表

def _sample_indices(self, record):
    if self.dense_sample:  # i3d dense sample # 密集随机采样
        sample_pos = max(1, 1 + record.num_frames - 64)
        t_stride = 64 // self.num_segments
        start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
        offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
        return np.array(offsets) + 1
    else:  # normal sample #稀疏随机采样
        average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration,
                                                                                              size=self.num_segments)
        elif record.num_frames > self.num_segments:
            offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
        else:
            offsets = np.zeros
  • 10
    点赞
  • 53
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值