从代码的整体结构来说整个,整个模型相当复杂。每一个类对象和函数都有复杂的调用关系,定义的参数量越巨大,这对于详细定义模型模型的各个环节参数,精细化模型是有帮助形成成熟的商业化模型。
但是从科研角度上,代码的可读性和理解要求更高,很容易看着看着1000+的代码就发昏了。大量的参数的定义和反复传递,太容易忘了。
我们就抽丝剥茧一个函数一个函数的来看,整个网络的框架。
模型总体结构分析
然后就几个关键的模块,总体结构参照原文
这么看代码好像不多,但是几乎所有的模型都是已经定义好的函数和类。尤其是其中的ConvGRU函数,pytorch并没有现成的API可以用,所以需要自己编写。虽然没这个图看着不复杂,但需要完全掌握模型的每个细节还是很有难度的。
按照惯例先打印出关键网络结构,考虑是在太长,只展示出网络的部分结构。首先是MNetV2模块
UNISAL(
(cnn): MobileNetV2(
(features): Sequential(
(0): Sequential(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): InvertedResidual(
(conv): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
(3): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(18): Sequential(
(0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
)
post_cnn与RNN模块,可以看到不同的模块是对应不同参数组,并按照数据集名进行区分
(post_cnn): Sequential(
(inv_res): InvertedResidual(
(conv): Sequential(
(0): Conv2d(1296, 1296, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1296, bias=False)
(1): BatchNorm2d(1296, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
(3): Conv2d(1296, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
)
)
(rnn): ConvGRU(
(cell_list): ModuleList(
(0): ConvGRUCell(
(norm_r_x): DomainBatchNorm2d(
(bn_DHF1K): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(bn_Hollywood): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(bn_UCFSports): BatchNorm2d(256, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(bn_SALICON): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
_rnn): Sequential(
(0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1):