Comparing the CMX Code with the SA-Gate Code

These two modules are quite similar and both deliver strong results, with CMX taking first place on the NYU dataset; both rectify, exchange, and fuse features inside the encoder.
First, the overall model frameworks (architecture figures omitted here):
CMX
SA-Gate
This post focuses on the implementation of the novel modules rather than the surrounding framework code.
Starting with CMX:
The features first pass through the FRM module,

class FeatureRectifyModule(nn.Module):
    def __init__(self, dim, reduction=1, lambda_c=.5, lambda_s=.5):
        super(FeatureRectifyModule, self).__init__()
        self.lambda_c = lambda_c
        self.lambda_s = lambda_s
        self.channel_weights = ChannelWeights(dim=dim, reduction=reduction)
        self.spatial_weights = SpatialWeights(dim=dim, reduction=reduction)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
    
    def forward(self, x1, x2):
        # channel_weights: (2, B, C, 1, 1); spatial_weights: (2, B, 1, H, W)
        channel_weights = self.channel_weights(x1, x2)
        spatial_weights = self.spatial_weights(x1, x2)
        # Each stream is rectified by channel- and spatial-weighted features
        # from the other stream, scaled by lambda_c and lambda_s.
        out_x1 = x1 + self.lambda_c * channel_weights[1] * x2 + self.lambda_s * spatial_weights[1] * x2
        out_x2 = x2 + self.lambda_c * channel_weights[0] * x1 + self.lambda_s * spatial_weights[0] * x1
        return out_x1, out_x2

The input RGB and depth features first go through channel and spatial attention to generate weights. Jumping into the two attention classes:
First, ChannelWeights: it reads off the shape of x1, concatenates x1 and x2 along the channel dimension, applies adaptive average pooling and adaptive max pooling (each viewed as (B, 2C)), concatenates the two pooled vectors into (B, 4C), passes them through the MLP to bring the channels back to 2C, and finally reshapes to (2, B, C, 1, 1). The output is thus effectively two tensors: channel_weights[0] of shape (B, C, 1, 1) and channel_weights[1] of shape (B, C, 1, 1); a quick shape check follows the class below.

class ChannelWeights(nn.Module):
    def __init__(self, dim, reduction=1):
        super(ChannelWeights, self).__init__()
        self.dim = dim
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
                    nn.Linear(self.dim * 4, self.dim * 4 // reduction),
                    nn.ReLU(inplace=True),
                    nn.Linear(self.dim * 4 // reduction, self.dim * 2), 
                    nn.Sigmoid())

    def forward(self, x1, x2):
        B, _, H, W = x1.shape
        x = torch.cat((x1, x2), dim=1)
        avg = self.avg_pool(x).view(B, self.dim * 2)
        max_ = self.max_pool(x).view(B, self.dim * 2)  # renamed to avoid shadowing the builtin max
        y = torch.cat((avg, max_), dim=1) # B 4C
        y = self.mlp(y).view(B, self.dim * 2, 1)
        channel_weights = y.reshape(B, 2, self.dim, 1, 1).permute(1, 0, 2, 3, 4) # 2 B C 1 1
        return channel_weights  # channel_weights[0]=(B,C,1,1) channel_weights[1]=(B,C,1,1)
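
To make the shapes concrete, here is a minimal sanity check with dummy tensors (the batch size, channel count, and spatial size are arbitrary assumptions):

import torch

cw = ChannelWeights(dim=64, reduction=1)
rgb, depth = torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32)  # (B, C, H, W)
w = cw(rgb, depth)
print(w.shape)     # torch.Size([2, 2, 64, 1, 1]): (2, B, C, 1, 1)
print(w[0].shape)  # torch.Size([2, 64, 1, 1]):    (B, C, 1, 1)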

Next, SpatialWeights: the RGB and depth features are concatenated, passed through the convolutional MLP, then reshaped and permuted to (2, B, 1, H, W). Likewise spatial_weights[0] = (B, 1, H, W) and spatial_weights[1] = (B, 1, H, W); a shape check again follows the class.

class SpatialWeights(nn.Module):
    def __init__(self, dim, reduction=1):
        super(SpatialWeights, self).__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
                    nn.Conv2d(self.dim * 2, self.dim // reduction, kernel_size=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(self.dim // reduction, 2, kernel_size=1), 
                    nn.Sigmoid())

    def forward(self, x1, x2):
        B, _, H, W = x1.shape
        x = torch.cat((x1, x2), dim=1) # B 2C H W
        spatial_weights = self.mlp(x).reshape(B, 2, 1, H, W).permute(1, 0, 2, 3, 4) # 2 B 1 H W
        return spatial_weights
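
And the corresponding shape check, under the same assumed dummy sizes:

import torch

sw = SpatialWeights(dim=64, reduction=1)
w = sw(torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32))
print(w.shape)  # torch.Size([2, 2, 1, 32, 32]): (2, B, 1, H, W)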

With both weights in hand, back in the forward pass: out_x1 = x1 + self.lambda_c * channel_weights[1] * x2 + self.lambda_s * spatial_weights[1] * x2, i.e. x1 is rectified by adding channel- and spatial-weighted contributions from x2 (the data flow was annotated in the original figure). out_x2 is computed symmetrically.
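
Putting the pieces together, a minimal end-to-end run of FRM on dummy RGB/depth features (sizes are illustrative assumptions, not values from the paper):

import torch

frm = FeatureRectifyModule(dim=64)  # lambda_c = lambda_s = 0.5 by default
rgb, depth = torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32)
out_rgb, out_depth = frm(rgb, depth)
print(out_rgb.shape, out_depth.shape)  # torch.Size([2, 64, 32, 32]) each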
Next comes the FFM module, which consists of two stages, a cross-attention exchange stage (CrossPath) and a channel-embedding merge stage (ChannelEmbed):

class FeatureFusionModule(nn.Module):
    def __init__(self, dim, reduction=1, num_heads=None, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.cross = CrossPath(dim=dim, reduction=reduction, num_heads=num_heads)
        self.channel_emb = ChannelEmbed(in_channels=dim*2, out_channels=dim, reduction=reduction, norm_layer=norm_layer)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x1, x2):
        B, C, H, W = x1.shape
        x1 = x1.flatten(2).transpose(1, 2) #(B,C,N)-->(B,N,C)
        x2 = x2.flatten(2).transpose(1, 2) #(B,C,N)-->(B,N,C)
        x1, x2 = self.cross(x1, x2)  #(B,N,C)
        merge = torch.cat((x1, x2), dim=-1) #(B,N,2C)
        merge = self.channel_emb(merge, H, W)#(B,C,H,W)
        
        return merge

First both inputs are flattened and transposed into (B, N, C) form, then passed jointly through self.cross, which is the CrossPath class:

class CrossPath(nn.Module):
    def __init__(self, dim, reduction=1, num_heads=None, norm_layer=nn.LayerNorm):
        super().__init__()
        self.channel_proj1 = nn.Linear(dim, dim // reduction * 2)
        self.channel_proj2 = nn.Linear(dim, dim // reduction * 2)
        self.act1 = nn.ReLU(inplace=True)
        self.act2 = nn.ReLU(inplace=True)
        self.cross_attn = CrossAttention(dim // reduction, num_heads=num_heads)
        self.end_proj1 = nn.Linear(dim // reduction * 2, dim)
        self.end_proj2 = nn.Linear(dim // reduction * 2, dim)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)

    def forward(self, x1, x2):
        y1, u1 = self.act1(self.channel_proj1(x1)).chunk(2, dim=-1) #(B,N,C)->(B,N,2C)-->{(b,n,c),(b,n,c)}
        y2, u2 = self.act2(self.channel_proj2(x2)).chunk(2, dim=-1) #(B,N,C)->(B,N,2C)-->{(b,n,c),(b,n,c)}
        v1, v2 = self.cross_attn(u1, u2) #(B,N,C)
        y1 = torch.cat((y1, v1), dim=-1)#(b,n,2c)
        y2 = torch.cat((y2, v2), dim=-1)
        out_x1 = self.norm1(x1 + self.end_proj1(y1))#(B,N,2C)-->(B,N,C)
        out_x2 = self.norm2(x2 + self.end_proj2(y2))
        return out_x1, out_x2

The two flattened inputs are first projected, doubling the channel dimension, and then chunked into two (B, N, C) tensors each: x1 yields y1 and u1, and likewise x2 yields y2 and u2. u1 and u2 then go through a cross-attention, cross_attn; jump to CrossAttention.

class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):
        super(CrossAttention, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.kv1 = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.kv2 = nn.Linear(dim, dim * 2, bias=qkv_bias)

    def forward(self, x1, x2):
        B, N, C = x1.shape
        #(B,N,C)-->(B,N,8,C//8)-->(B,8,N,C//8)
        q1 = x1.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
        q2 = x2.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()
        #(B,N,C)-->(B,N,2C)-->(B,N,2,8,C//8)-->(2,B,8,N,C//8)={(B,8,N,C//8),(B,8,N,C//8)}
        k1, v1 = self.kv1(x1).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
        k2, v2 = self.kv2(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
        #((B,8,N,C//8))-->(B,8,C//8,N)@(B,8,N,C//8)=(B,8,C//8,C//8)
        ctx1 = (k1.transpose(-2, -1) @ v1) * self.scale
        ctx1 = ctx1.softmax(dim=-2)
        #(B,8,N,C//8)-->(B,8,C//8,N)@(B,8,N,C//8)=(B,8,C//8,C//8)
        ctx2 = (k2.transpose(-2, -1) @ v2) * self.scale
        ctx2 = ctx2.softmax(dim=-2)
        #(B,8,N,C//8)@(B,8,C//8,C//8)=(B,8,N,C//8)-->(B,N,C)
        x1 = (q1 @ ctx2).permute(0, 2, 1, 3).reshape(B, N, C).contiguous() 
        x2 = (q2 @ ctx1).permute(0, 2, 1, 3).reshape(B, N, C).contiguous() 

        return x1, x2

The attention is computed here, with the dimension changes annotated in the comments. The key point is the ordering Q @ (K^T V): the per-head context matrix K^T V is only (C/8 x C/8), so the cost is linear in the token count N rather than quadratic, and the result is still (B, N, C). Note also the cross-modal swap: x1 is updated with the context from modality 2 (q1 @ ctx2), and vice versa.
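
A quick runnable check of this, with an assumed token count N = 1024 (e.g. a flattened 32x32 map):

import torch

ca = CrossAttention(dim=64, num_heads=8)
a, b = torch.randn(2, 1024, 64), torch.randn(2, 1024, 64)  # (B, N, C)
o1, o2 = ca(a, b)
print(o1.shape, o2.shape)  # torch.Size([2, 1024, 64]) each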
Back in CrossPath, the generated v1 is concatenated with y1 and v2 with y2, so the dimension becomes (B, N, 2C); each result then goes through a final linear projection back to (B, N, C), is added residually to the input, and is normalized with LayerNorm.
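
A minimal CrossPath run under the same dummy assumptions:

import torch

cp = CrossPath(dim=64, num_heads=8)
a, b = torch.randn(2, 1024, 64), torch.randn(2, 1024, 64)
o1, o2 = cp(a, b)
print(o1.shape, o2.shape)  # torch.Size([2, 1024, 64]) each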
Back in FeatureFusionModule, the two results are concatenated and passed through channel_emb:

class ChannelEmbed(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=1, norm_layer=nn.BatchNorm2d):
        super(ChannelEmbed, self).__init__()
        self.out_channels = out_channels
        self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.channel_embed = nn.Sequential(
                        nn.Conv2d(in_channels, out_channels//reduction, kernel_size=1, bias=True),
                        nn.Conv2d(out_channels//reduction, out_channels//reduction, kernel_size=3, stride=1, padding=1, bias=True, groups=out_channels//reduction),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(out_channels//reduction, out_channels, kernel_size=1, bias=True),
                        norm_layer(out_channels) 
                        )
        self.norm = norm_layer(out_channels)
        
    def forward(self, x, H, W):
        B, N, _C = x.shape #(B,N,2C)
        x = x.permute(0, 2, 1).reshape(B, _C, H, W).contiguous() #(B,N,2C)-->(B,2C,N)-->(B,2C,H,W)
        residual = self.residual(x) #(B,2C,H,W)-->(B,C,H,W)
        x = self.channel_embed(x) #(B,2C,H,W)-->(B,C,H,W)
        out = self.norm(residual + x)
        return out


A 1x1 convolution provides the skip connection while the input also passes through the channel_embed branch (pointwise conv, depthwise 3x3 conv, pointwise conv, norm); the two paths are summed and normalized, giving the final output of FeatureFusionModule.
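
An end-to-end FFM sketch on dummy inputs; dim and num_heads are assumptions here (CMX configures them per backbone stage), and since __init__ applies _init_weights, math and a trunc_normal_ implementation such as timm's must be in scope:

import math
import torch
from timm.models.layers import trunc_normal_  # used by _init_weights, as in the CMX repo

ffm = FeatureFusionModule(dim=64, num_heads=8)
rgb, depth = torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32)
fused = ffm(rgb, depth)
print(fused.shape)  # torch.Size([2, 64, 32, 32]): a single fused (B, C, H, W) map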

------------------------------------------------------------------------ End of the CMX code ------------------------------------------------------------------------
In the same spirit, let's look at how the SA-Gate module is written:

class SAGate(nn.Module):
    def __init__(self, in_planes, out_planes, reduction=16, bn_momentum=0.0003):
        super(SAGate, self).__init__()
        self.in_planes = in_planes
        self.bn_momentum = bn_momentum

        self.fsp_rgb = FSP(in_planes, out_planes, reduction)
        self.fsp_hha = FSP(in_planes, out_planes, reduction)

        self.gate_rgb = nn.Conv2d(in_planes*2, 1, kernel_size=1, bias=True)
        self.gate_hha = nn.Conv2d(in_planes*2, 1, kernel_size=1, bias=True)

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        rgb, hha = x
        b, c, h, w = rgb.size()

        # Feature separation: each modality is recalibrated with channel
        # weights computed from the other modality (the guidePath).
        rec_rgb = self.fsp_rgb(hha, rgb)
        rec_hha = self.fsp_hha(rgb, hha)

        cat_fea = torch.cat([rec_rgb, rec_hha], dim=1)

        # Feature aggregation: two 1x1 convs each produce a one-channel gate;
        # softmax over the channel dim makes the two gates sum to 1 per pixel.
        attention_vector_l = self.gate_rgb(cat_fea)
        attention_vector_r = self.gate_hha(cat_fea)

        attention_vector = torch.cat([attention_vector_l, attention_vector_r], dim=1)
        attention_vector = self.softmax(attention_vector)
        attention_vector_l, attention_vector_r = attention_vector[:, 0:1, :, :], attention_vector[:, 1:2, :, :]
        merge_feature = rgb*attention_vector_l + hha*attention_vector_r

        # Each stream is averaged with the merged feature and re-activated.
        rgb_out = (rgb + merge_feature) / 2
        hha_out = (hha + merge_feature) / 2

        rgb_out = self.relu1(rgb_out)
        hha_out = self.relu2(hha_out)

        return [rgb_out, hha_out], merge_feature

The two inputs, rgb and hha (an HHA encoding of depth), first go through the FSP module:

class FSP(nn.Module):
    def __init__(self, in_planes, out_planes, reduction=16):
        super(FSP, self).__init__()
        self.filter = FilterLayer(2*in_planes, out_planes, reduction)

    def forward(self, guidePath, mainPath):
        combined = torch.cat((guidePath, mainPath), dim=1)
        channel_weight = self.filter(combined)
        out = mainPath + channel_weight * guidePath
        return out

In rec_rgb = self.fsp_rgb(hha, rgb), jumping into FSP: hha is the guidePath and rgb is the mainPath. The two inputs are concatenated and passed through the filter, i.e. FilterLayer:

class FilterLayer(nn.Module):
    def __init__(self, in_planes, out_planes, reduction=16):
        super(FilterLayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_planes, out_planes // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(out_planes // reduction, out_planes),
            nn.Sigmoid()
        )
        self.out_planes = out_planes

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, self.out_planes, 1, 1)
        return y

FilterLayer first applies global average pooling and views the result as (b, c), then runs it through the fully connected squeeze-and-excite stack and views it as (b, c, 1, 1). rec_hha is produced the same way. We thus obtain two channel weights; each weight multiplies its guidePath (so both hha and rgb receive channel attention from the other modality) and the product is added to the mainPath. After FSP finishes, we return to the main forward; a quick sanity check is sketched below.
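
A minimal FSP sanity check (in_planes and out_planes are chosen arbitrarily; note that out_planes // reduction must stay at least 1 for the fc stack):

import torch

fsp = FSP(in_planes=64, out_planes=64, reduction=16)
hha, rgb = torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32)
rec_rgb = fsp(hha, rgb)  # guidePath = hha, mainPath = rgb
print(rec_rgb.shape)     # torch.Size([2, 64, 32, 32])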

The rectified hha and rgb are concatenated, and we enter the feature aggregation stage: two 1x1 convolutions with a single output channel each produce two attention maps, which are concatenated and passed through a softmax over the channel dimension so that the two gates sum to 1 at every pixel. Slicing recovers A_rgb and A_hha, which multiply the original rgb and hha inputs before being summed into merge_feature. Finally, each stream is averaged element-wise with merge_feature and passed through a ReLU.
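
A full SA-Gate run on dummy inputs (sizes again illustrative):

import torch

gate = SAGate(in_planes=64, out_planes=64, reduction=16)
rgb, hha = torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32)
(rgb_out, hha_out), merge = gate([rgb, hha])
print(rgb_out.shape, hha_out.shape, merge.shape)  # all torch.Size([2, 64, 32, 32])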
