TSM源码分析

1,源码下载

在文件夹下右键打开Git Bash Here,输入如下代码:

git clone git@github.com:mit-han-lab/temporal-shift-module.git

2,代码结构

文件名称功能
main.py训练代码
opts.py参数配置代码
ops/dataset.py数据集的载入代码
ops/dataset_config.py用于配置不同的数据集
ops/model.py用于组装模型
ops/temporal_shift.py为核心的temporal shift 操作的实现
utils.py计算精度,损失值等更新代码

3,源码分析

3.1ops.py参数配置

ops.py是参数配置文件,除了一些超参数,路径信息意外特别需要注意以下这些。

parser.add_argument('--arch', type=str, default="resnet50") #BNInception
parser.add_argument('--num_segments', type=int, default=8)
parser.add_argument('--dataset', type=str, default="ucf101")
parser.add_argument('--modality', type=str, default="RGB")
parser.add_argument('--shift', default=True, action="store_true", help='use shift for models')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')

–arch 表示使用的模型。

–num_segments 表示每个视频采样frames数,一般是8或者16。

–dataset 表示使用的数据集。

–modality 表示输入的类型,可以是图片帧(rgb),也可以是光流(flow)。

–shift 表示是否加入TSM模块。

–shift_div 表示shift的特征的比例,一般是8。表示2*1/8比例的特征会移动,其中1/8的特征做shift left, 另1/8的特征做shift right。

3.2dataset_config.py 数据集的配置

每个数据集实现一个return_xxx(modality)
返回数据集支持的子类名,train_list路径,val_list路径,数据集的根路径等信息。该方法会在main.py中得到调用,得到数据集的信息,作为参数传递给别的方法。

3.3dataset.py 是数据集的载入部分

dataset.py的主要功能就是对数据集进行读取,并且对其稀疏采样,返回稀疏采样后得到的数据集。
dataset.py中实现了TSNDataSet类,此类继承于torch.utils.data.dataset类。

(1)__init__函数

        def __init__(self, root_path, list_file,
                 num_segments=3, new_length=1, modality='RGB',
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 random_shift=True, test_mode=False,
                 remove_missing=False, dense_sample=True, twice_sample=False):

        self.root_path = root_path
        self.list_file = list_file
        self.num_segments = num_segments
        self.new_length = new_length
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode
        self.remove_missing = remove_missing
        self.dense_sample = dense_sample  # using dense sample as I3D
        self.twice_sample = twice_sample  # twice sample for more validation
        if self.dense_sample:
            print('=> Using dense sample for the dataset...')
        if self.twice_sample:
            print('=> Using twice sample for the dataset...')

        if self.modality == 'RGBDiff':
            self.new_length += 1  # Diff needs one more image to calculate diff

        self._parse_list()

TSNDataSet类的初始化方法需要如下参数:

  • root_path:项目的根目录地址
  • list_file:训练或测试的列表文件(.txt文件)地址
  • num_segments:视频分割的段数
  • new_length:根据输入数据集类型的不同,new_length取不同的值
  • modality:输入数据集类型(RGB、光流、RGB差异)
  • image_tmpl:图片的名称
  • transform:数据集是否进行变换操作
  • random_shift:进行稀疏采样的时候是否添加一个随机数
  • test_mode:是否是测试时的数据集输入
  • dense_sample:稀疏采样方式中的稠密采样
  • twice_sample:稀疏采样方式中的二次采样

__init__函数中,_parse_list(),是对输入的train_val_list做解析,将内容写入到一个VideoRecord的list中去。

(2)_parse_list函数

    def _parse_list(self):
        # check the frame number is large >3:
        tmp = [x.strip().split(' ') for x in open(self.list_file)]
        if not self.test_mode or self.remove_missing:
            tmp = [item for item in tmp if int(item[1]) >= 3]
        self.video_list = [VideoRecord(item) for item in tmp]

        if self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
            for v in self.video_list:
                v._data[1] = int(v._data[1]) / 2
        print('video number:%d' % (len(self.video_list)))
        list = self.video_list[0]
        return list

该方法中,先检查帧数是否大于三,把小于三的剔除。之后再调用VideoRecord类,self.video_list是一个长度为训练数据数量的列表。每个值都是VideoRecord对象,包含一个列表和3个属性,列表长度为3,用空格键分割,分别为帧路径、该视频含有多少帧和帧标签。

(3)采样代码

(3.1)_sample_indices

该代码针对训练集进行采样,方法中定义了两种采样形式,dense_sample(稠密采样)和normal sample(普通稀疏采样)。

    def _sample_indices(self, record):
        """

        :param record: VideoRecord
        :return: list
        """
        if self.dense_sample:  # i3d dense sample
            sample_pos = max(1, 1 + record.num_frames - 64)
            t_stride = 64 // self.num_segments
            start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
            offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
            return np.array(offsets) + 1
        else:  # normal sample
            average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
            if average_duration > 0:
                offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration,
                                                                                   size=self.num_segments)
            elif record.num_frames > self.num_segments:
                offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
            else:
                offsets = np.zeros((self.num_segments,))
            return offsets + 1

假设um_segments = 3,num_frames = 128
稠密采样:
普通稀疏采样:

(3.2)_get_val_indices函数

该代码针对验证集及逆行采样,方法中定义了两种采样形式,dense_sample(稠密采样)和normal sample(普通稀疏采样)并且时进行等间隔的稀疏采样。

    def _get_val_indices(self, record):
        if self.dense_sample:  # i3d dense sample
            sample_pos = max(1, 1 + record.num_frames - 64)
            t_stride = 64 // self.num_segments
            start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
            offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
            return np.array(offsets) + 1
        else:
            if record.num_frames > self.num_segments + self.new_length - 1:
                tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
                offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
            else:
                offsets = np.zeros((self.num_segments,))
            return offsets + 1

假设um_segments = 3,num_frames = 128
稠密采样:
普通稀疏采样:

(3.3)_get_test_indices函数

该代码针对测试集进行采样,方法中定义了三种采样形式,dense_sample(稠密采样),twice_sample(二次采样)和normal sample(普通稀疏采样)并且时进行等间隔的稀疏采样。

    def _get_test_indices(self, record):
        if self.dense_sample:
            sample_pos = max(1, 1 + record.num_frames - 64)
            t_stride = 64 // self.num_segments
            start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int)
            offsets = []
            for start_idx in start_list.tolist():
                offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
            return np.array(offsets) + 1
        elif self.twice_sample:
            tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)

            offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] +
                               [int(tick * x) for x in range(self.num_segments)])

            return offsets + 1
        else:
            tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
            offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
            return offsets + 1

假设um_segments = 3,num_frames = 128
稠密采样:
二次采样:
普通稀疏采样:

(4)_getitem_函数

该函数会在TSNDataSet初始化之后执行,功能在于选择性的调用执行稀疏采样的函数,并且调用get方法,得到TSNDataSet的返回。

    def __getitem__(self, index):
        record = self.video_list[index]
        # check this is a legit video folder

        if self.image_tmpl == 'flow_{}_{:05d}.jpg':
            file_name = self.image_tmpl.format('x', 1)
            full_path = os.path.join(self.root_path, record.path, file_name)
        elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
            file_name = self.image_tmpl.format(int(record.path), 'x', 1)
            full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
        else:
            file_name = self.image_tmpl.format(1)
            full_path = os.path.join(self.root_path, record.path, file_name)

        while not os.path.exists(full_path):
            print('################## Not Found:', os.path.join(self.root_path, record.path, file_name))
            index = np.random.randint(len(self.video_list))
            record = self.video_list[index]
            if self.image_tmpl == 'flow_{}_{:05d}.jpg':
                file_name = self.image_tmpl.format('x', 1)
                full_path = os.path.join(self.root_path, record.path, file_name)
            elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
                file_name = self.image_tmpl.format(int(record.path), 'x', 1)
                full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
            else:
                file_name = self.image_tmpl.format(1)
                full_path = os.path.join(self.root_path, record.path, file_name)

        if not self.test_mode:
            segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
        else:
            segment_indices = self._get_test_indices(record)
        return self.get(record, segment_indices)

在进行采样之前,会对视频文件夹进行检查。record变量读取的是video_list的第index个数据,包含该视频所在的文件地址、视频包含的帧数和视频所属的分类。如果该TSNDataSet不是为测试部分运行的,则对_sample_indices(record)或_get_val_indices(record)运行,判断条件在于它是否为训练数据集,如果是,则执行前者,否则,执行后者。将稀疏采样获得的帧列表保存于segment_indices中,之后调用get()方法,作为其中的参数。

(5)_load_image函数

该方法就是,获得指定索引的帧。

        if self.modality == 'RGB' or self.modality == 'RGBDiff':
            try:
                return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
            # convert(‘RGB’)如果不使用.convert('RGB')进行转换的话,读出来的图像是RGBA四通道的,A通道为透明通道,该对深度学习模型训练来说暂时用不到,因此使用convert('RGB')进行通道转换。
            except Exception:
                print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
                return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
        elif self.modality == 'Flow':
            if self.image_tmpl == 'flow_{}_{:05d}.jpg':  # ucf
                x_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('x', idx))).convert(
                    'L')
                y_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('y', idx))).convert(
                    'L')
            elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':  # something v1 flow
                x_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
                                                format(int(directory), 'x', idx))).convert('L')
                y_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
                                                format(int(directory), 'y', idx))).convert('L')
            else:
                try:
                    # idx_skip = 1 + (idx-1)*5
                    flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert(
                        'RGB')
                except Exception:
                    print('error loading flow file:',
                          os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
                    flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')
                # the input flow file is RGB image with (flow_x, flow_y, blank) for each channel
                flow_x, flow_y, _ = flow.split()
                x_img = flow_x.convert('L')
                y_img = flow_y.convert('L')

            return [x_img, y_img]

关键函数时Image.open(),他得到的就是指定索引的帧图片。
(6)get函数
get函数是对上面所说的_load_image方法的调用,并且会对帧图片进行变形操作。

    def get(self, record, indices):

        images = list()
        for seg_ind in indices:
            p = int(seg_ind)
            for i in range(self.new_length):
                seg_imgs = self._load_image(record.path, p)
                images.extend(seg_imgs)
                if p < record.num_frames:
                    p += 1

        process_data = self.transform(images)
        return process_data, record.label

对需要提取的帧索引进循环遍历,通过循环的调用_load_image方法提取帧图片,并且将图片存放在image列表中。最终返回的时指定的tensor张量和对应的标签。

3.4,model.py 模型组装

(1)__init__函数

TSN类的初始化方法,主要用于初始化参数,以及调用函数修改模型。

    def __init__(self, num_class, num_segments, modality,
                 base_model='resnet50', new_length=None,
                 consensus_type='avg', before_softmax=True,
                 dropout=0.8, img_feature_dim=256,
                 crop_num=1, partial_bn=True, print_spec=True, pretrain='imagenet',
                 is_shift=True, shift_div=8, shift_place='blockres', fc_lr5=False,
                 temporal_pool=False, non_local=False):
        super(TSN, self).__init__()
        self.modality = modality
        self.num_segments = num_segments
        self.reshape = True
        self.before_softmax = before_softmax
        self.dropout = dropout
        self.crop_num = crop_num
        self.consensus_type = consensus_type
        self.img_feature_dim = img_feature_dim  # the dimension of the CNN feature to represent each frame
        self.pretrain = pretrain

        self.is_shift = is_shift
        self.shift_div = shift_div
        self.shift_place = shift_place
        self.base_model_name = base_model
        self.fc_lr5 = fc_lr5
        self.temporal_pool = temporal_pool
        self.non_local = non_local

        if not before_softmax and consensus_type != 'avg':
            raise ValueError("Only avg consensus can be used after Softmax")

        if new_length is None:
            self.new_length = 1 if modality == "RGB" else 5
        else:
            self.new_length = new_length
        if print_spec:
            print(("""
    Initializing TSN with base model: {}.
    TSN Configurations:
        input_modality:     {}
        num_segments:       {}
        new_length:         {}
        consensus_module:   {}
        dropout_ratio:      {}
        img_feature_dim:    {}
            """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim)))

        self._prepare_base_model(base_model)

        feature_dim = self._prepare_tsn(num_class)

初始化方法初始了如下参数:

  • num_class:分类数
  • num_segments:视频分割的段数,也就是一个视频采样的帧数。
  • modality:输入数据的数据类型。
  • base_model:基础模型,之后的TSN模型以此基础来修改
  • new_length : 视频取帧的起点,rgb为1,光流为5

    之后调用修改模型的方法:1.调用 _prepare_base_model(base_model)构建出基础的模型,2. 调用_prepare_tsn(num_class),用于根据不同数据集的子类数,适配fc层大小。3. 对于flow和rgbdiff的输入,调用_construct_flow_model和_construct_diff_model更改第一个卷积核的大小。

(2)_prepare_base_model函数

主要是对基础模型的下载操作,在该函数中还调用了make_temporal_shit()函数加入tsm模块。

        def _prepare_base_model(self, base_model):
        print('=> base model: {}'.format(base_model))

        if 'resnet' in base_model:
            self.base_model = getattr(torchvision.models, base_model)(True if self.pretrain == 'imagenet' else False)
            if self.is_shift:
                print('Adding temporal shift...')
                from ops.temporal_shift import make_temporal_shift
                make_temporal_shift(self.base_model, self.num_segments,
                                    n_div=self.shift_div, place=self.shift_place, temporal_pool=self.temporal_pool)

            if self.non_local:
                print('Adding non-local module...')
                from ops.non_local import make_non_local
                make_non_local(self.base_model, self.num_segments)

            self.base_model.last_layer_name = 'fc'
            self.input_size = 224
            self.input_mean = [0.485, 0.456, 0.406]
            self.input_std = [0.229, 0.224, 0.225]

            self.base_model.avgpool = nn.AdaptiveAvgPool2d(1)

            if self.modality == 'Flow':
                self.input_mean = [0.5]
                self.input_std = [np.mean(self.input_std)]
            elif self.modality == 'RGBDiff':
                self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length
                self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length

关键的函数就是

self.base_model = getattr(torchvision.models, base_model)
make_temporal_shift(self.base_model, self.num_segments,n_div=self.shift_div, place=self.shift_place, temporal_pool=self.temporal_pool)

他们的作用分别是加载模型,以及加入TSM模块。

(3)_prepare_tsn函数

该函数的功能:修改模型的最后一层,微调最后的全连接层,改为适合使用数据集的形式。

    def _prepare_tsn(self, num_class):
        feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features
        if self.dropout == 0:
            setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class))
            self.new_fc = None
        else:
            setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout))
            self.new_fc = nn.Linear(feature_dim, num_class)

        std = 0.001
        if self.new_fc is None:
            normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
            constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)
        else:
            if hasattr(self.new_fc, 'weight'):  #用于检查new_fc是否含有'weight'这个属性名
                normal_(self.new_fc.weight, 0, std)
                constant_(self.new_fc.bias, 0)
        return feature_dim

关键函数是getattr和setattr函数,他们都是torch.nn.Module中的方法,可以获得网络某层的信息和修改网络某层的信息。

  1. 获取最后一层的输入通道数。
  2. 通过判断是否需要添加dropout层,添加的话就把模型的最后一层全连接层改为指定参数的dropout层。并且根据上面获得的最后一一层输入通道数和数据集的类别定义一个新的全连接层。
  3. 对新的全连接层的权重进行均值为0方差为0.0001的归一化操作。
  4. 对偏置进行置0操作。

3.5temporal_shift.py

该文件下主要定义了两个类,一个是进行shift操作的TemporalShift类和进行时间池化的TemporalPool类。
下面会依次进行讲解。

(1)TemporalShift类

(1.1)__init__函数

初始化参数

    def __init__(self, net, n_segment=3, n_div=8, inplace=False):
        super(TemporalShift, self).__init__()
        self.net = net
        self.n_segment = n_segment
        self.fold_div = n_div
        self.inplace = inplace

参数有:

  • net:需要进行shift操作模型
  • n_segment:视频分割的段数
  • fold_div:shift的比例
  • inplace:决定是否进行inplace_shift
(1.2)shift函数
    def shift(x, n_segment, fold_div=3, inplace=False):
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w)
        fold = c // fold_div
        if inplace:
            print("inplace:",inplace)
            raise NotImplementedError  
        else:
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]  # shift right
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
        return out.view(nt, c, h, w)

TSM模块的核心代码,根据不同的flod_div会有不同比例的左移和右移操作
举个例子:
当c = 8, num_segment=4, 2维度的特征表示如下:
0_xx代表第一帧的特征。1_xx代码第二帧的特征,每个特征有8个通道。原始特征如下:
在这里插入图片描述
当fold_div = 8的时候,移动后如下:
在这里插入图片描述
可见第一帧中融入了第二帧的特征,第二帧中融入了第三帧和第二帧的特征。

当fold_div=4的时候,移动的部分会更多,即当前帧的特征中会包含更多前一帧和后一帧的信息。

在这里插入图片描述

(2)TemporalPool类

该类是对模型的layer2的时间池化操作

(2.1)__init__函数

初始化参数

    def __init__(self, net, n_segment):
        super(TemporalPool, self).__init__()
        self.net = net
        self.n_segment = n_segment

参数有:

  • net:进行时间池化的模型
  • n_segment:视频的分割数
(2.2)temporal_pool函数

时间池化的的核心代码,将时间维度的通道数减半

    def temporal_pool(x, n_segment):
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2)  # n, c, t, h, w  
        x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
        x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w)
        return x

首先将时间维度给分离出来,时间的通道数也就是对应的 n_segment,再交换1,2维度,之后进行max_pool3d操作,对第二个维度(交换之后就是时间维度)进行降维,通道数减半。之后将维度交换回来。

3.6utils.py

(1)AverageMeter(object)类

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

首先定义了一个类AverageMeter来管理一些变量的更新,比如loss损失、top1准确率等。在初始化的时候,调用重置方法reset。当调用该类对象的update方法的时候就会进行变量的更新,当要读取某个变量的时候,可以通过对象.属性的方式来获取,比如在train函数中的top1.val读取top1准确率。

(2)accuracy函数

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

accuracy函数是准确率计算函数,输入output是模型的预测结果,尺寸为batch_size*num_class;target是真实标签,长度为batch_size。maxk为之后要进行排序的前k个数,比如topk=(1,3),则maxk=3,之后找到output中值最大的前三个。
_,pred=output.topk(maxk,1,True,True),这一语句调用了pytorch中的topk方法。各参数含义如下:

  • 第一个参数(maxk):需要返回的排序前k个的个数,如maxk=3,则返回前三个较大值
  • 第二个参数(1):表示dim,即按行计算
  • 第三个参数(True):表示largest=True,表示返回的是maxk个最大值
  • 第四个参数(True):表示sorted=True,表示返回排序的结果。

target.view(1,-1).expand_as(pred)将Target的尺寸规范到1batch_size,然后将维度扩充为pred相同的维度,也就是maxkbatchsize,然后调用equal方法计算两个tensor 矩阵相同元素的情况,得到correct是同等维度的矩阵。
correct_k = correct[:k].view(-1).float().sum(0)通过k值来决定计算topk的准确率,sum(0)表示按照列的维度计算和,最后都添加到res列表中返回。

3.7main.py

3.7.1train函数

def train(train_loader, model, criterion, optimizer, epoch, log, tf_writer):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    if args.no_partialbn:
        model.module.partialBN(False)
    else:
        model.module.partialBN(True)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda()
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 3))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))
        top5.update(prec5.item(), input.size(0))

        # compute gradient and do SGD step

        optimizer.zero_grad()
        loss.backward()

        if args.clip_gradient is not None:
            total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'] * 0.1))  # TODO
            print(output)
            log.write(output + '\n')
            log.flush()

  • 首先定义了一些变量,这些变量是通过AverageMeter类来管理的。之后判断是否需要部分bn
  • 再调用model.py中的train方法来对模型的参数进行预训练,并且冻结除第一层之外的所有批处理规范化层的均值和方差参数,对全连接层的参数进行训练,达到微调的目的。
  • 然后对train_loader中的数据集进行遍历,存储其中的数据集输入和真实标签。执行output=model(input_var)得到模型的输入结果,调用损失计算函数得到loss,调用之前的accuracy函数来更新top1和top3的准确率。
  • 对于梯度清零、回传和先前学习的过程一样,最后可以得到训练之后的模型和训练的loss、accuracy等。

3.7.2main函数

main函数主要包含导入模型、数据准备、训练三个部分,接下来将按顺序介绍。

(1)导入模型
model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

TSN类的定义在models.py脚本中,之前已经详细介绍过。输入包含分类的类别数:num_class;args.num_segments表示把一个video分成多少份,采用何种输入:args.modality,比如RGB表示常规图像,Flow表示光流;采用哪种模型:args.arch,比如resnet101等;dropout参数:args.dropout等。

(2)数据导入
train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

首先是自定义的TSNDataSet类用来处理最原始的数据,返回的是torch.utils.data.Dataset类型,然后通过重写初始化函数_init_和_getitem_方法来读取数据。
torch.utils.data.Dataset类型的数据并不能作为模型的输入,还要通过torch.utils.data.DataLoader类进一步封装,将batch_size个数据和标签分别封装成一个Tensor,从而组成一个长度为2的list。最重要的输入就是TSNDataSet类的初始化结果,其他如batch size和shuffle参数是常用的。通过这两个类读取和封装数据,后续再转为Variable就能作为模型的输入了。在TSNDataSet类中,调用了_parse_list()方法、_sample_indices()方法等,最终返回稀疏采样之后的数据集。

(3)训练模型

通过for循环,循环的调用train函数和验证函数进行模型的训练和效果的验证

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)

        # evaluate on validation set
        if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
            prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

            output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()

            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
  • 3
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值