函数
CANet由三部分组成:encoder、co-attention fusion module和decoder。首先来看最重要的co-attention fusion module的代码,该模块由PCAM(position co-attention module,位置协同注意力)和CCAM(channel co-attention module,通道协同注意力)两个子模块组成:
class PCAM_Module(Module):
    """Position co-attention module (self-attention layout referenced from SAGAN).

    Queries are computed from ``x``; keys and values are computed from ``y``.
    Every spatial position of ``x`` attends over all spatial positions of ``y``.
    The attended values are scaled by the learnable scalar ``gamma`` and added
    back onto ``x``. Because ``gamma`` starts at zero, the module initially
    returns ``x`` unchanged.
    """

    def __init__(self, in_dim):
        super(PCAM_Module, self).__init__()
        self.chanel_in = in_dim
        reduced = in_dim // 8  # query/key projection width
        self.query_conv = Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.key_conv = Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = Parameter(torch.zeros(1))
        self.softmax = Softmax(dim=-1)

    def forward(self, x, y):
        """Attend from every position of ``x`` to every position of ``y``.

        Args:
            x: query feature map, (B, C, H, W); also the residual input.
            y: key/value feature map; must flatten to (B, *, H*W) like ``x``.

        Returns:
            ``gamma * attended + x`` with shape (B, C, H, W).
        """
        batch, channels, height, width = x.size()
        positions = width * height

        # Q: (B, C/8, H, W) -> (B, C/8, N) -> (B, N, C/8), with N = H*W
        query = self.query_conv(x).view(batch, -1, positions).permute(0, 2, 1)
        # K: (B, C/8, H, W) -> (B, C/8, N)
        key = self.key_conv(y).view(batch, -1, positions)
        # Row-normalised affinity between positions of x and positions of y: (B, N, N)
        attention = self.softmax(torch.bmm(query, key))

        # V: (B, C, H, W) -> (B, C, N)
        value = self.value_conv(y).view(batch, -1, positions)
        # attended[:, :, i] = sum_j value[:, :, j] * attention[:, i, j] -> (B, C, N)
        attended = torch.bmm(value, attention.permute(0, 2, 1))
        attended = attended.view(batch, channels, height, width)
        return self.gamma * attended + x
class CCAM_Module(Module):
    """Channel co-attention module (DANet-style channel attention across two inputs).

    Channel affinities are computed between ``x`` (rows) and ``y`` (columns).
    Each row is turned into ``row_max - energy`` before the softmax, so
    less-similar channels receive larger weights. The weights are applied to
    the channels of ``y``. The result, scaled by the learnable ``gamma``
    (initialised to zero), is added onto ``x``.
    """

    def __init__(self, in_dim):
        super(CCAM_Module, self).__init__()
        self.chanel_in = in_dim
        self.gamma = Parameter(torch.zeros(1))
        self.softmax = Softmax(dim=-1)

    def forward(self, x, y):
        """Re-weight the channels of ``y`` by their affinity to the channels of ``x``.

        Args:
            x: feature map, (B, C, H, W); also the residual input.
            y: feature map with the same number of elements per sample as ``x``.

        Returns:
            ``gamma * attended + x`` with shape (B, C, H, W).
        """
        batch, channels, height, width = x.size()
        x_flat = x.view(batch, channels, -1)  # (B, C, N)
        y_flat = y.view(batch, channels, -1)  # (B, C, N)

        # Channel-to-channel energy: (B, C, N) @ (B, N, C) -> (B, C, C)
        energy = torch.bmm(x_flat, y_flat.permute(0, 2, 1))
        row_max = torch.max(energy, -1, keepdim=True)[0]
        attention = self.softmax(row_max.expand_as(energy) - energy)

        # (B, C, C) @ (B, C, N) -> (B, C, N) -> (B, C, H, W)
        attended = torch.bmm(attention, y_flat).view(batch, channels, height, width)
        return self.gamma * attended + x
PCAM和CCAM输出的两个特征图,与卷积融合分支(merged branch)输出的特征图一起输入到fusion layer:
class FusionLayer(Module):
    """Split-attention fusion of the convolutional branch with both co-attention outputs.

    ``x`` carries ``radix * in_channels`` channels and is split into ``radix``
    parts. The first part is paired with the position co-attention features
    ``y``; the second part is paired with the channel co-attention features ``z``.
    For each pair:

    - a globally pooled descriptor of (part + attention map) goes through a
      small 1x1-conv MLP;
    - ``rSoftMax`` turns the MLP output into one weight per member of the pair;
    - the pair is summed with those weights.

    The two fused maps are concatenated along the channel axis.

    ``rSoftMax`` is defined elsewhere. It is presumably the ResNeSt radix
    softmax that normalises across the radix branches; confirm against its
    definition. :meth:`forward` only supports ``radix == 2``.
    """

    def __init__(self, in_channels, groups=1, radix=2, reduction_factor=4, norm_layer=None):
        super(FusionLayer, self).__init__()
        # Hidden width of the attention MLP, never below 32 (e.g. 1024 // 4 = 256).
        inter_channels = max(in_channels // reduction_factor, 32)
        self.radix = radix
        self.cardinality = groups
        self.use_bn = norm_layer is not None
        self.relu = ReLU(inplace=True)
        # Squeeze: in_channels -> inter_channels (e.g. 1024 -> 256)
        self.fc1_p = Conv2d(in_channels, inter_channels, 1, groups=self.cardinality)
        self.fc1_c = Conv2d(in_channels, inter_channels, 1, groups=self.cardinality)
        if self.use_bn:
            self.bn1_p = norm_layer(inter_channels)
            self.bn1_c = norm_layer(inter_channels)
        # Excite: inter_channels -> in_channels * radix, i.e. one weight per pair member per channel
        self.fc2_p = Conv2d(inter_channels, in_channels * radix, 1, groups=self.cardinality)
        self.fc2_c = Conv2d(inter_channels, in_channels * radix, 1, groups=self.cardinality)
        self.rsoftmax = rSoftMax(radix, groups)

    def forward(self, x, y, z):
        """Fuse the convolutional features with the position and channel co-attention maps.

        :param x: convolution fusion features, (b, 2048, h, w) in ACNet
        :param y: position attention features, (b, 1024, h, w)
        :param z: channel attention features, (b, 1024, h, w)
        :return: fused features, (b, 2 * 1024, h, w) when radix == 2
        """
        assert self.radix == 2, "Error radix size!"
        batch, rchannel = x.shape[:2]
        if self.radix > 1:
            # First half of x pairs with y (position), second half with z (channel).
            splited = torch.split(x, rchannel // self.radix, dim=1)
            gap_1 = splited[0]
            gap_2 = splited[1]
        else:
            gap_1 = x
            gap_2 = x
        assert gap_1.shape[1] == y.shape[1], "Error!"
        assert gap_2.shape[1] == z.shape[1], "Error!"

        # Global descriptors of each (conv part, attention map) pair: (b, C, 1, 1)
        gap_p = F.adaptive_avg_pool2d(gap_1 + y, 1)
        gap_c = F.adaptive_avg_pool2d(gap_2 + z, 1)
        gap_p = self.fc1_p(gap_p)
        gap_c = self.fc1_c(gap_c)
        if self.use_bn:
            gap_p = self.bn1_p(gap_p)
            gap_c = self.bn1_c(gap_c)
        gap_p = self.relu(gap_p)
        gap_c = self.relu(gap_c)

        # Pair weights: (b, C * radix, 1, 1)
        atten_p = self.rsoftmax(self.fc2_p(gap_p)).view(batch, -1, 1, 1)
        atten_c = self.rsoftmax(self.fc2_c(gap_c)).view(batch, -1, 1, 1)

        if self.radix > 1:
            attens_p = torch.split(atten_p, rchannel // self.radix, dim=1)
            attens_c = torch.split(atten_c, rchannel // self.radix, dim=1)
            # Each set of weights must re-weight the pair it was computed from.
            # The channel branch uses (gap_2, z); pairing it with (gap_1, y)
            # would keep z and gap_2 out of the output entirely.
            splited_p = (gap_1, y)
            splited_c = (gap_2, z)
            out_p = sum(att * part for att, part in zip(attens_p, splited_p))  # (b, C, h, w)
            out_c = sum(att * part for att, part in zip(attens_c, splited_c))  # (b, C, h, w)
        else:
            out_p = atten_p * y
            out_c = atten_c * z

        if self.radix > 1:
            out = torch.cat([out_p, out_c], 1)
        else:
            out = out_p + out_c
        return out.contiguous()
CANet整体模块(代码中的类名为ACNet),首先需要明确以下几点:
1:backbone采用ResNet-50(注意:构造函数的默认值是backbone='ResNet-101',使用ResNet-50时需显式传入backbone='ResNet-50');
2:decoder中使用TransBasicBlock进行上采样。
首先定义网络的各个组件,然后对RGB和depth分别进行特征提取:
class ACNet(nn.Module):
    """CANet for RGB-D semantic segmentation.

    The model has three ResNet-style encoders:

    - an RGB branch;
    - a depth branch with a single-channel stem;
    - a merged branch that consumes the sum of the two stems.

    A transposed-convolution decoder maps the fused features to
    ``num_class`` score maps. When ``pcca5`` is true, the last encoder stage
    fuses RGB and depth through ``PCAM_Module`` / ``CCAM_Module`` followed by
    ``FusionLayer``.

    Args:
        num_class: number of output classes.
        backbone: ``'ResNet-50'`` selects the [3, 4, 6, 3] stage layout. Any
            other value builds the ResNet-101 layout [3, 4, 23, 3].
        pretrained: if true, ``self._load_resnet_pretrained()`` is called after
            weight initialisation.
        pcca5: enable co-attention fusion at the last encoder stage.
    """

    def __init__(self, num_class=37, backbone='ResNet-101', pretrained=False, pcca5=False):
        super(ACNet, self).__init__()
        self.pcca5 = pcca5
        self.backbone = backbone
        # Any backbone name other than exactly 'ResNet-50' falls back to the ResNet-101 layout.
        if self.backbone == 'ResNet-50':
            layers = [3, 4, 6, 3]
        else:
            layers = [3, 4, 23, 3]
        block = Bottleneck
        transblock = TransBasicBlock

        # RGB image branch: ResNet stem + four residual stages.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Depth image branch: same layout, but the stem takes a single-channel input.
        self.inplanes = 64
        self.conv1_d = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
                                 bias=False)
        self.bn1_d = nn.BatchNorm2d(64)
        self.relu_d = nn.ReLU(inplace=True)
        self.maxpool_d = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_d = self._make_layer(block, 64, layers[0])
        self.layer2_d = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3_d = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4_d = self._make_layer(block, 512, layers[3], stride=2)

        # Merged branch: residual stages only. Its input is built from the RGB
        # and depth stem outputs (summed in encoder()), so it has no stem of its own.
        self.inplanes = 64
        self.layer1_m = self._make_layer(block, 64, layers[0])
        self.layer2_m = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3_m = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4_m = self._make_layer(block, 512, layers[3], stride=2)

        # Agant (adapter) layers: 1x1 conv + BN + ReLU that bring each fused encoder
        # output down to the width of the matching decoder stage.
        self.agant0 = self._make_agant_layer(64, 64)
        self.agant1 = self._make_agant_layer(64*4, 64)
        self.agant2 = self._make_agant_layer(128*4, 128)
        self.agant3 = self._make_agant_layer(256*4, 256)
        self.agant4 = self._make_agant_layer(512*4, 512)

        # Decoder: stride-2 TransBasicBlock stages whose shortcut is a 2x2 ConvTranspose2d.
        self.inplanes = 512
        self.deconv1 = self._make_transpose(transblock, 256, 6, stride=2)
        self.deconv2 = self._make_transpose(transblock, 128, 4, stride=2)
        self.deconv3 = self._make_transpose(transblock, 64, 3, stride=2)
        self.deconv4 = self._make_transpose(transblock, 64, 3, stride=2)

        # Final block: stride-1 refinement, then a stride-2 transposed conv to num_class maps.
        self.inplanes = 64
        self.final_conv = self._make_transpose(transblock, 64, 3)
        self.final_deconv = nn.ConvTranspose2d(self.inplanes, num_class, kernel_size=2,
                                               stride=2, padding=0, bias=True)

        # 1x1 classifiers on intermediate decoder features; decoder() computes and
        # returns these side outputs only in training mode.
        self.out5_conv = nn.Conv2d(256, num_class, kernel_size=1, stride=1, bias=True)
        self.out4_conv = nn.Conv2d(128, num_class, kernel_size=1, stride=1, bias=True)
        self.out3_conv = nn.Conv2d(64, num_class, kernel_size=1, stride=1, bias=True)
        self.out2_conv = nn.Conv2d(64, num_class, kernel_size=1, stride=1, bias=True)

        if self.pcca5:
            # Last-stage co-attention fusion:
            # - reduce the 2048-channel RGB / depth features to 512 channels;
            # - run position (PCAM) and channel (CCAM) co-attention;
            # - lift each result to 1024 channels with a 1x1 conv;
            # - fuse with the merged-branch output in FusionLayer.
            self.conv_5a = nn.Sequential(nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1, bias=False),
                                         nn.BatchNorm2d(512),
                                         nn.ReLU())
            self.conv_5c = nn.Sequential(nn.Conv2d(2048, 512, kernel_size=3, stride=1, padding=1, bias=False),
                                         nn.BatchNorm2d(512),
                                         nn.ReLU())
            self.pca_5 = PCAM_Module(512)
            self.cca_5 = CCAM_Module(512)
            self.pconv_5 = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=1, stride=1, padding=0, bias=False),
                                         nn.BatchNorm2d(1024),
                                         nn.ReLU())
            self.cconv_5 = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=1, stride=1, padding=0, bias=False),
                                         nn.BatchNorm2d(1024),
                                         nn.ReLU())
            self.split_conv = FusionLayer(in_channels=1024, groups=1, radix=2, reduction_factor=4,
                                          norm_layer=nn.BatchNorm2d)

        # Weight initialisation.
        # - nn.Conv2d: He-style normal init with std = sqrt(2 / fan_out),
        #   where fan_out = k_h * k_w * out_channels. Conv biases are left as created.
        # - nn.ConvTranspose2d is not an nn.Conv2d subclass and keeps its default init.
        # - nn.BatchNorm2d: weight = 1, bias = 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        if pretrained:
            self._load_resnet_pretrained()
其中分别调用了_make_layer函数、block(即Bottleneck类)、_make_agant_layer函数和_make_transpose函数。
1:_make_layer函数:将输入通道数、输出通道数、步长以及下采样分支(downsample)传入block,返回一个nn.Sequential,其中包含blocks个block。
def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
    """Stack ``blocks`` residual blocks into one ResNet stage.

    The first block receives ``stride``. It also receives a 1x1 conv + BN
    projection shortcut when the stride is not 1 or ``self.inplanes`` differs
    from ``planes * block.expansion``. Afterwards ``self.inplanes`` is set to
    the expanded width, so the remaining blocks (and the next stage built on
    ``self``) see it. Only the non-first blocks are given ``dilation``.

    Returns:
        nn.Sequential of ``blocks`` block instances (at least one).
    """
    out_channels = planes * block.expansion
    needs_projection = stride != 1 or self.inplanes != out_channels
    downsample = (
        nn.Sequential(
            nn.Conv2d(self.inplanes, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        if needs_projection
        else None
    )

    stage = [block(self.inplanes, planes, stride, downsample)]
    self.inplanes = out_channels
    stage.extend(block(self.inplanes, planes, dilation=dilation) for _ in range(1, blocks))
    return nn.Sequential(*stage)
2:block即Bottleneck类,是标准的ResNet瓶颈残差块:输入通道数为inplanes,输出通道数为planes*4(expansion=4)。
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand (x4).

    Args:
        inplanes: number of input channels.
        planes: bottleneck width; the block outputs ``planes * 4`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to build the shortcut.
            It is needed whenever the input shape differs from the output shape.
        dilation: dilation (and matching padding) of the 3x3 convolution.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        widened = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride,
            dilation=dilation, padding=dilation, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``; ``x`` itself is not modified."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)
3:_make_agant_layer函数:用1×1卷积+BN+ReLU把通道数从inplanes变换为planes,即把Bottleneck扩展了4倍的通道数降回原来的维度,空间尺寸保持不变。
def _make_agant_layer(self, inplanes, planes):
    """Build a 1x1 Conv -> BN -> ReLU adapter mapping ``inplanes`` to ``planes`` channels.

    Stride 1 and no padding, so the spatial size is unchanged. ``self`` is not used.
    """
    projection = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(planes)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(projection, norm, activation)
4:_make_transpose函数:当stride不为1时,用nn.ConvTranspose2d构造上采样分支(stride为1但通道数不同时,则用1×1卷积对齐通道);然后将blocks个block串联成nn.Sequential,只有最后一个block改变通道数并执行上采样。这里的block是TransBasicBlock。
def _make_transpose(self, block, planes, blocks, stride=1):
    """Build a decoder stage of ``blocks`` blocks that ends at ``planes`` channels.

    The first ``blocks - 1`` blocks keep the width at ``self.inplanes``. The
    last block switches to ``planes`` and receives ``stride`` plus a shortcut
    module:

    - a 2x2 ConvTranspose2d + BN when ``stride != 1``;
    - a 1x1 Conv2d + BN when only the channel count changes;
    - ``None`` otherwise.

    ``self.inplanes`` is set to ``planes`` afterwards.

    Returns:
        nn.Sequential of the stage's blocks.
    """
    width = self.inplanes
    if stride != 1:
        shortcut = nn.Sequential(
            nn.ConvTranspose2d(width, planes, kernel_size=2, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(planes),
        )
    elif width != planes:
        shortcut = nn.Sequential(
            nn.Conv2d(width, planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes),
        )
    else:
        shortcut = None

    stage = [block(width, width) for _ in range(blocks - 1)]
    stage.append(block(width, planes, stride, shortcut))
    self.inplanes = planes
    return nn.Sequential(*stage)
接着对RGB和depth分别进行特征提取,并在每个阶段把两者与融合分支(merged branch)的输出相加:
def encoder(self, rgb, depth):
    """Run the RGB, depth and merged encoders and return the fused feature of every stage.

    Args:
        rgb: RGB input fed to ``self.conv1``.
        depth: depth input fed to ``self.conv1_d``.

    Returns:
        Tuple ``(m0, m1, m2, m3, m4)``:

        - ``m0``: the stem-level sum ``rgb + depth``;
        - ``m1``..``m3``: ``merged + rgb + depth`` after stages 1-3;
        - ``m4``: the last-stage fusion. When ``self.pcca5`` is true this is
          the co-attention ``split_conv`` output; otherwise it is the same
          plain sum. With the Bottleneck backbone ``m4`` has 2048 channels.
    """
    # Stems of the two modality branches.
    rgb = self.relu(self.bn1(self.conv1(rgb)))
    depth = self.relu_d(self.bn1_d(self.conv1_d(depth)))
    m0 = rgb + depth

    rgb = self.maxpool(rgb)
    depth = self.maxpool_d(depth)
    m = self.maxpool(m0)

    # Stages 1-3: each branch runs its own stage. The merged-branch output is
    # summed with the RGB and depth outputs, and that sum feeds the next merged stage.
    stages = (
        (self.layer1, self.layer1_d, self.layer1_m),
        (self.layer2, self.layer2_d, self.layer2_m),
        (self.layer3, self.layer3_d, self.layer3_m),
    )
    fused = []
    for rgb_layer, depth_layer, merge_layer in stages:
        rgb = rgb_layer(rgb)
        depth = depth_layer(depth)
        m = merge_layer(m) + rgb + depth
        fused.append(m)
    m1, m2, m3 = fused

    # Stage 4
    rgb = self.layer4(rgb)
    depth = self.layer4_d(depth)
    m = self.layer4_m(m3)
    if self.pcca5:
        # Co-attention between the RGB and depth features, fused with the merged branch.
        rgb_down = self.conv_5a(rgb)
        depth_down = self.conv_5c(depth)
        attention_position = self.pca_5(rgb_down, depth_down)
        attention_channel = self.cca_5(rgb_down, depth_down)
        p_out = self.pconv_5(attention_position)
        c_out = self.cconv_5(attention_channel)
        m4 = self.split_conv(m, p_out, c_out)
    else:
        m4 = m + rgb + depth
    return m0, m1, m2, m3, m4
最后将encoder输出的各阶段融合特征输入decoder:
def decoder(self, fuse0, fuse1, fuse2, fuse3, fuse4):
    """Upsample the deepest fused feature, adding adapted skip features at each stage.

    Args:
        fuse0..fuse4: encoder outputs, shallowest (``fuse0``) to deepest (``fuse4``).

    Returns:
        ``out`` in eval mode. In training mode, the tuple
        ``(out, out2, out3, out4, out5)``, where ``outK`` is the 1x1 side
        classifier applied right after each deconv stage.
    """
    x = self.agant4(fuse4)

    # (deconv stage, side classifier, skip adapter, skip feature), deepest first.
    steps = (
        (self.deconv1, self.out5_conv, self.agant3, fuse3),
        (self.deconv2, self.out4_conv, self.agant2, fuse2),
        (self.deconv3, self.out3_conv, self.agant1, fuse1),
        (self.deconv4, self.out2_conv, self.agant0, fuse0),
    )
    side_outputs = []
    for deconv, side_head, skip_adapter, skip in steps:
        x = deconv(x)
        if self.training:
            side_outputs.append(side_head(x))
        x = x + skip_adapter(skip)

    out = self.final_deconv(self.final_conv(x))
    if not self.training:
        return out
    out5, out4, out3, out2 = side_outputs
    return out, out2, out3, out4, out5
将encoder输出作为decoder输入,整个模型就搭建完毕了。
def forward(self, rgb, depth, phase_checkpoint=False):
    """Segment an RGB-D pair by running ``self.encoder`` and then ``self.decoder``.

    ``phase_checkpoint`` is accepted for interface compatibility and is not used.
    The return value is whatever ``self.decoder`` returns for the five encoder outputs.
    """
    encoded = self.encoder(rgb, depth)
    return self.decoder(*encoded)