这篇博客来读一读TSN算法的PyTorch代码,总体而言代码风格还是不错的,多读读优秀的代码对自身的提升还是有帮助的,另外因为代码内容较多,所以分训练和测试两篇介绍,这篇介绍训练代码,介绍顺序为代码运行顺序。TSN算法的介绍可以参考博客TSN(Temporal Segment Networks)算法笔记。
论文:Temporal Segment Networks: Towards Good Practices for Deep Action Recognition
代码地址:https://github.com/yjxiong/tsn-pytorch
项目结构:
main.py是训练脚本
test_models.py是测试脚本
opts.py是参数配置脚本
dataset.py是数据读取脚本
models.py是网络结构构建脚本
transforms.py是数据预处理相关的脚本
tf_model_zoo文件夹关于导入模型结构的脚本
main.py是训练模型的入口。
首先是导入模块,其中比较重要的是导入模型:from models import TSN,导入配置的参数:from opts import parser。
import argparse
import os
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torch.nn.utils import clip_grad_norm
from dataset import TSNDataSet
from models import TSN
from transforms import *
from opts import parser
best_prec1 = 0
main函数主要包含导入模型、数据准备、训练三个部分,接下来将按顺序介绍。parser是在opts.py中定义的关于读取命令行参数的对象,然后通过from opts import parser导入的。model = TSN(num_class, args.num_segments, args.modality,...,partial_bn=not args.no_partialbn)
这一行是导入模型操作,TSN类的定义在models.py脚本中。输入包含分类的类别数:num_class;args.num_segments表示把一个video分成多少份,对应论文中的K,默认K=3;采用哪种输入:args.modality,比如RGB表示常规图像,Flow表示optical flow等;采用哪种模型:args.arch,比如resnet101,BNInception等;不同输入snippet的融合方式:args.consensus_type,比如avg等;dropout参数:args.dropout。
def main():
global args, best_prec1
args = parser.parse_args()
if args.dataset == 'ucf101':
num_class = 101
elif args.dataset == 'hmdb51':
num_class = 51
elif args.dataset == 'kinetics':
num_class = 400
else:
raise ValueError('Unknown dataset '+args.dataset)
model = TSN(num_class, args.num_segments, args.modality,
base_model=args.arch,
consensus_type=args.consensus_type, dropout=args.dropout, partial_bn=not args.no_partialbn)
TSN类(定义在models.py中)的初始化操作:__init__
,这里只列出主要的代码。new_length和输入数据类型相关。这里主要调用了该类的两个方法来完成初始化操作,一个是self._prepare_base_model(base_model),通过调用TSN类的_prepare_base_model方法来导入模型。另一个是feature_dim = self._prepare_tsn(num_class),通过调用TSN类的_prepare_tsn方法来得到。另外如果你的输入数据是optical flow或RGBDiff,那么还会对网络结构做修改,分别调用_construct_flow_model方法和_construct_diff_model方法来实现的,主要差别在第一个卷积层,因为该层的输入channel依据不同的输入类型而变化。接下来依次介绍这些方法。
class TSN(nn.Module):
def __init__(self, num_class, num_segments, modality,
base_model='resnet101', new_length=None,
consensus_type='avg', before_softmax=True,
dropout=0.8,
crop_num=1, partial_bn=True):
super(TSN, self).__init__()
if new_length is None:
self.new_length = 1 if modality == "RGB" else 5
else:
self.new_length = new_length
self._prepare_base_model(base_model)
feature_dim = self._prepare_tsn(num_class)
if self.modality == 'Flow':
print("Converting the ImageNet model to a flow init model")
self.base_model = self._construct_flow_model(self.base_model)
print("Done. Flow model ready...")
elif self.modality == 'RGBDiff':
print("Converting the ImageNet model to RGB+Diff init model")
self.base_model = self._construct_diff_model(self.base_model)
print("Done. RGBDiff model ready.")
self.consensus = ConsensusModule(consensus_type)
if not self.before_softmax:
self.softmax = nn.Softmax()
self._enable_pbn = partial_bn
if partial_bn:
self.partialBN(True)
_prepare_base_model方法的部分代码(以base_model为‘BNInception为例’)如下。getattr模块的使用:getattr(tf_model_zoo, base_model)()类似tf_model_zoo.BNInception(),因为要根据base_model的不同指定值来导入不同的网络,所以用getattr模块。导入模型之后就是一些常规的配置信息了。
elif base_model == 'BNInception':
import tf_model_zoo
self.base_model = getattr(tf_model_zoo, base_model)()
self.base_model.last_layer_name = 'fc'
self.input_size = 224
self.input_mean = [104, 117, 128]
self.input_std = [1]
if self.modality == 'Flow':
self.input_mean = [128]
elif self.modality == 'RGBDiff':
self.input_mean = self.input_mean * (1 + self.new_length)
BNInception类,定义在tf_model_zoo文件夹下的bninception文件夹下的pytorch_load.py中。前面当运行self.base_model = getattr(tf_model_zoo, base_model)(),且base_model是‘BNInception’的时候就会调用这个BNInception类的初始化函数__init__
。manifest = yaml.load(open(model_path))是读进配置好的网络结构(.yml格式),返回的manifest是长度为3的字典,和.yml文件内容对应。其中manifest[‘layers’]是关于网络层的详细定义,其中的每个值表示一个层,每个层也是一个字典,包含数据流关系、名称和结构参数等信息。然后get_basic_layer函数是用来根据这些参数得到具体的网络层并保存相关信息。setattr(self, id, module)是将得到的层写入self的指定属性中,就是搭建层的过程。这样循环完所有层的配置信息后,就搭建好