TSN算法的PyTorch代码解读(训练部分)

本文详述了TSN算法的PyTorch训练代码,从模型导入、数据准备到训练过程,深入剖析代码逻辑。文章介绍了TSN模型的结构、参数配置、数据读取以及训练流程,对关键函数进行了详细的解释,如模型初始化、数据加载和模型训练。此外,还讨论了针对不同输入类型(RGB、Flow)时网络结构的调整。
摘要由CSDN通过智能技术生成

这篇博客来读一读TSN算法的PyTorch代码,总体而言代码风格还是不错的,多读读优秀的代码对自身的提升还是有帮助的,另外因为代码内容较多,所以分训练和测试两篇介绍,这篇介绍训练代码,介绍顺序为代码运行顺序。TSN算法的介绍可以参考博客TSN(Temporal Segment Networks)算法笔记
论文:Temporal Segment Networks: Towards Good Practices for Deep Action Recognition
代码地址:https://github.com/yjxiong/tsn-pytorch

项目结构:
main.py是训练脚本
test_models.py是测试脚本
opts.py是参数配置脚本
dataset.py是数据读取脚本
models.py是网络结构构建脚本
transforms.py是数据预处理相关的脚本
tf_model_zoo文件夹关于导入模型结构的脚本

main.py是训练模型的入口。
首先是导入模块,其中比较重要的是导入模型:from models import TSN,导入配置的参数:from opts import parser。

import argparse
import os
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torch.nn.utils import clip_grad_norm

from dataset import TSNDataSet
from models import TSN
from transforms import *
from opts import parser

best_prec1 = 0

main函数主要包含导入模型、数据准备、训练三个部分,接下来将按顺序介绍。parser是在opts.py中定义的关于读取命令行参数的对象,然后通过from opts import parser导入的。model = TSN(num_class, args.num_segments, args.modality,...,partial_bn=not args.no_partialbn)这一行是导入模型操作,TSN类的定义在models.py脚本中。输入包含分类的类别数:num_class;args.num_segments表示把一个video分成多少份,对应论文中的K,默认K=3;采用哪种输入:args.modality,比如RGB表示常规图像,Flow表示optical flow等;采用哪种模型:args.arch,比如resnet101,BNInception等;不同输入snippet的融合方式:args.consensus_type,比如avg等;dropout参数:args.dropout。

def main():
    global args, best_prec1
    args = parser.parse_args()

    if args.dataset == 'ucf101':
        num_class = 101
    elif args.dataset == 'hmdb51':
        num_class = 51
    elif args.dataset == 'kinetics':
        num_class = 400
    else:
        raise ValueError('Unknown dataset '+args.dataset)

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type, dropout=args.dropout, partial_bn=not args.no_partialbn)

TSN类(定义在models.py中)的初始化操作:__init__,这里只列出主要的代码。new_length和输入数据类型相关。这里主要调用了该类的两个方法来完成初始化操作,一个是self._prepare_base_model(base_model),通过调用TSN类的_prepare_base_model方法来导入模型。另一个是feature_dim = self._prepare_tsn(num_class),通过调用TSN类的_prepare_tsn方法来得到。另外如果你的输入数据是optical flow或RGBDiff,那么还会对网络结构做修改,分别调用_construct_flow_model方法和_construct_diff_model方法来实现的,主要差别在第一个卷积层,因为该层的输入channel依据不同的输入类型而变化。接下来依次介绍这些方法。

class TSN(nn.Module):
    def __init__(self, num_class, num_segments, modality,
                 base_model='resnet101', new_length=None,
                 consensus_type='avg', before_softmax=True,
                 dropout=0.8,
                 crop_num=1, partial_bn=True):
        super(TSN, self).__init__()

        if new_length is None:
            self.new_length = 1 if modality == "RGB" else 5
        else:
            self.new_length = new_length

        self._prepare_base_model(base_model)

        feature_dim = self._prepare_tsn(num_class)

        if self.modality == 'Flow':
            print("Converting the ImageNet model to a flow init model")
            self.base_model = self._construct_flow_model(self.base_model)
            print("Done. Flow model ready...")
        elif self.modality == 'RGBDiff':
            print("Converting the ImageNet model to RGB+Diff init model")
            self.base_model = self._construct_diff_model(self.base_model)
            print("Done. RGBDiff model ready.")

        self.consensus = ConsensusModule(consensus_type)

        if not self.before_softmax:
            self.softmax = nn.Softmax()

        self._enable_pbn = partial_bn
        if partial_bn:
            self.partialBN(True)

_prepare_base_model方法的部分代码(以base_model为‘BNInception为例’)如下。getattr模块的使用:getattr(tf_model_zoo, base_model)()类似tf_model_zoo.BNInception(),因为要根据base_model的不同指定值来导入不同的网络,所以用getattr模块。导入模型之后就是一些常规的配置信息了。

elif base_model == 'BNInception':
            import tf_model_zoo
            self.base_model = getattr(tf_model_zoo, base_model)()
            self.base_model.last_layer_name = 'fc'
            self.input_size = 224
            self.input_mean = [104, 117, 128]
            self.input_std = [1]

            if self.modality == 'Flow':
                self.input_mean = [128]
            elif self.modality == 'RGBDiff':
                self.input_mean = self.input_mean * (1 + self.new_length)

BNInception类,定义在tf_model_zoo文件夹下的bninception文件夹下的pytorch_load.py中。前面当运行self.base_model = getattr(tf_model_zoo, base_model)(),且base_model是‘BNInception’的时候就会调用这个BNInception类的初始化函数__init__。manifest = yaml.load(open(model_path))是读进配置好的网络结构(.yml格式),返回的manifest是长度为3的字典,和.yml文件内容对应。其中manifest[‘layers’]是关于网络层的详细定义,其中的每个值表示一个层,每个层也是一个字典,包含数据流关系、名称和结构参数等信息。然后get_basic_layer函数是用来根据这些参数得到具体的网络层并保存相关信息。setattr(self, id, module)是将得到的层写入self的指定属性中,就是搭建层的过程。这样循环完所有层的配置信息后,就搭建好

评论 57
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值