一、插入模块
有如下代码:
class embed_net(nn.Module):
def __init__(self, class_num, no_local= 'on', gm_pool = 'on', arch='resnet50'):
super(embed_net, self).__init__()
self.thermal_module = thermal_module(arch=arch)
self.visible_module = visible_module(arch=arch)
self.base_resnet = base_resnet(arch=arch)
self.non_local = no_local
if self.non_local =='on':
layers=[3, 4, 6, 3]
non_layers=[0,2,3,0]
self.NL_1 = nn.ModuleList(
[Non_local(256) for i in range(non_layers[0])])
self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
self.NL_2 = nn.ModuleList(
[Non_local(512) for i in range(non_layers[1])])
self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
self.NL_3 = nn.ModuleList(
[Non_local(1024) for i in range(non_layers[2])])
self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
self.NL_4 = nn.ModuleList(
[Non_local(2048) for i in range(non_layers[3])])
self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
pool_dim = 2048
self.l2norm = Normalize(2)
self.bottleneck = nn.BatchNorm1d(pool_dim)
self.bottleneck.bias.requires_grad_(False) # no shift
self.classifier = nn.Linear(pool_dim, class_num, bias=False)
self.bottleneck.apply(weights_init_kaiming)
self.classifier.apply(weights_init_classifier)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.gm_pool = gm_pool
def forward(self, x1, x2, modal=0):
if modal == 0:
x1 = self.visible_module(x1)
x2 = self.thermal_module(x2)
x = torch.cat((x1, x2), 0)
elif modal == 1:
x = self.visible_module(x1)
elif modal == 2:
x = self.thermal_module(x2)
# shared block
if self.non_local == 'on':
# 在base_resnet.base.layer1后面插入模块
NL1_counter = 0
if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1]
for i in range(len(self.base_resnet.base.layer1)):
x = self.base_resnet.base.layer1[i](x)
if i == self.NL_1_idx[NL1_counter]:
_, C, H, W = x.shape
x = self.NL_1[NL1_counter](x)
NL1_counter += 1
# Layer 2
NL2_counter = 0
if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1]
for i in range(len(self.base_resnet.base.layer2)):
x = self.base_resnet.base.layer2[i](x)
if i == self.NL_2_idx[NL2_counter]:
_, C, H, W = x.shape
x = self.NL_2[NL2_counter](x)
NL2_counter += 1
# Layer 3
NL3_counter = 0
if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1]
for i in range(len(self.base_resnet.base.layer3)):
x = self.base_resnet.base.layer3[i](x)
if i == self.NL_3_idx[NL3_counter]:
_, C, H, W = x.shape
x = self.NL_3[NL3_counter](x)
NL3_counter += 1
# Layer 4
NL4_counter = 0
if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1]
for i in range(len(self.base_resnet.base.layer4)):
x = self.base_resnet.base.layer4[i](x)
if i == self.NL_4_idx[NL4_counter]:
_, C, H, W = x.shape
x = self.NL_4[NL4_counter](x)
NL4_counter += 1
else:
x = self.base_resnet(x)
if self.gm_pool == 'on':
b, c, h, w = x.shape
x = x.view(b, c, -1)
p = 3.0
x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p)
else:
x_pool = self.avgpool(x)
x_pool = x_pool.view(x_pool.size(0), x_pool.size(1))
feat = self.bottleneck(x_pool)
if self.training:
return x_pool, self.classifier(feat)
else:
return self.l2norm(x_pool), self.l2norm(feat)
注意代码第51-58行,它在base_resnet.base.layer1后面加入了一个NL1模块,此时,对于 self.base_resnet.base.layer2
,它的输入是 x
,而这个 x
是由 self.base_resnet.base.layer1
的循环处理过的结果。也就是说,x
是经过 layer1
处理的特征张量。
具体来说,x
最初是输入到整个 ResNet 网络的原始输入。在第一个 for
循环中,经过 self.base_resnet.base.layer1[i](x)
的处理,x
发生了变化,变成了 layer1
中第 i
个基本块的输出。
如果在这个循环中,if i == self.NL_1_idx[NL1_counter]:
的条件满足,那么就会在 x
上插入 NL_1[NL1_counter]
中的 Non_local
模块,进一步修改了 x
。
所以,对于 self.base_resnet.base.layer2
来说,它的输入是经过 self.base_resnet.base.layer1
处理过,并且在部分位置经过了 Non_local
模块的修改的特征张量 x
。
二、对self.base_resnet.base.layer1/2的解释
self.base_resnet.base.layer1
是指向 ResNet 网络中的第一个层的引用。在 PyTorch 中,ResNet 网络被组织成一系列的层,其中每个层包含多个基本块(basic block)。
具体而言,ResNet 网络的结构如下:
layer1
: 第一个阶段,包含多个基本块。layer2
: 第二个阶段,也包含多个基本块。layer3
: 第三个阶段,同样包含多个基本块。layer4
: 第四个阶段,仍然包含多个基本块。
每个基本块由两个卷积层组成,同时包含批量归一化和激活函数。这些基本块通过残差连接(Residual Connection)的方式连接在一起,以帮助解决梯度消失的问题,使得更深的网络可以更容易地训练。
因此,self.base_resnet.base.layer1
是指向 ResNet 网络中的第一个阶段(stage)的引用。这个阶段通常包含多个基本块,每个基本块又包含若干个卷积层、归一化层和激活函数。
以resnet18为例:
self.base_resnet.base.layer1
的内部结构通常包含四个基本块(basic block)。每个基本块由两个卷积层组成,具体结构如下:
-
Basic Block 1:
- 3x3 卷积层
- 批量归一化(Batch Normalization)
- ReLU 激活函数
- 3x3 卷积层
- 批量归一化
- 残差连接(Residual Connection):通过跳跃连接将输入直接添加到卷积输出上
- ReLU 激活函数
-
Basic Block 2:
- 同上
-
Basic Block 3:
- 同上
-
Basic Block 4:
- 同上
这是 ResNet-18 中第一个阶段(layer1
)的基本块结构。整个 layer1
就是这四个基本块的堆叠。这些基本块的设计目的是通过残差连接来缓解梯度消失问题,使得神经网络更容易训练。
而进一步的,self.base_resnet.base.layer1[2]
指的是 ResNet 网络中第一个阶段(layer1
)的第三个基本块(basic block)。在 ResNet 的 PyTorch 实现中,这是一个通过索引选择基本块的方式。
在 ResNet-18 中,layer1
由四个基本块组成,索引从 0 到 3。因此self.base_resnet.base.layer1[2]
就是 layer1
中的第三个基本块。