(即插即用模块-特征处理部分) 三十、(2024) BFAM & CBM & DFEM 特征聚合+特征提取+边界感知

在这里插入图片描述

paper:B2CNet: A Progressive Change Boundary-to-Center Refinement Network for Multitemporal Remote Sensing Images Change Detection

Code:https://github.com/bao11seven/B2CNet


1、Bitemporal Feature Aggregation Module

B2CNet 论文中提出的 双时相特征聚合模块(Bitemporal Feature Aggregation Module)改变边界感知模块(Change Boundary-Aware Module)深度特征提取模块(Deep Feature Extraction Module) 都是为了解决遥感图像变化检测中的关键问题,并提升检测精度。以下是每个模块的设计动机、原理和实现过程:

由于遥感图像变化检测需要综合考虑低级纹理信息和高级语义信息。而现有的方法往往只关注其中一种信息,导致检测结果不够完整。所以 BFAM 通过聚合来自不同感受野的低级纹理信息和高层次语义信息。使特征更加丰富。

对于输入特征,BFAM 的实现过程:

  1. 对输入的双时相特征进行通道拼接。
  2. 使用不同膨胀率的卷积操作提取多尺度特征。
  3. 使用 1x1 卷积进行通道降维。
  4. 使用 SimAM 注意力机制强化特征。
  5. 使用 SimAM 注意力机制分别提取不同时相特征的共性特征。
  6. 将共性特征与各自时相特征相乘,得到相似性特征。
  7. 将低级纹理特征、相似性特征和高级语义特征进行求和。
  8. 使用 SimAM 注意力机制进行特征聚合。

Bitemporal Feature Aggregation Module 结构图:
在这里插入图片描述


2、Change Boundary-Aware Module

在遥感图像变化检测中,边界信息对于准确识别变化区域至关重要。现有的方法往往忽略了边界信息与变化信息之间的关系,导致边界特征提取能力不足,影响检测精度。CBM 利用边界信息来辅助提取变化区域特征,其通过边缘增强操作,强化边界信息,使其在特征图中更加突出。此外,CBM 学习如何更好地利用边界信息来定位变化区域,并指导其他模块进行特征解耦。

对于输入特征,CBM 的实现过程:

  1. 使用 SimAM 注意力机制识别特征图中的显著区域。
  2. 通过池化、减法和卷积操作提取边缘特征。
  3. 使用 SimAM 注意力机制再次强化边缘特征。
  4. 对增强后的边缘特征进行特征差异操作,得到边缘增强的差异特征。
  5. 对 CBM 分支的预测结果进行监督学习,指导模型更好地定位变化区域。

Change Boundary-Aware Module 结构图:
在这里插入图片描述


3、Deep Feature Extraction Module

遥感图像变化检测需要提取更深层次的语义特征,以更好地理解复杂场景。现有的方法往往由于特征解耦过程的信息损失,导致高级语义信息不足。DFEM 聚合 CBM 和 BFAM 的特征,进行深度特征提取。然后利用残差操作融合特征信息,并提取更深层次的高层次语义特征。

对于输入特征,DFEM 的实现过程:

  1. 对 CBM 和 BFAM 的特征进行拼接和求和。
  2. 使用 1x1 卷积进行通道降维。
  3. 使用残差拼接操作保留信息完整性。
  4. 使用 3x3 卷积提取深度特征。
  5. 将上一层的 CBM 输出特征与深度特征进行求和。
  6. 使用 SimAM 注意力机制进行特征聚合。

Deep Feature Extraction Module 结构图:
在这里插入图片描述


4、代码实现

import torch
import torch.nn as nn
from einops import rearrange


class simam_module(torch.nn.Module):
    def __init__(self, channels=None, e_lambda=1e-4):
        super(simam_module, self).__init__()

        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()

        n = w * h - 1

        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activaton(y)


class BFAM(nn.Module):
    def __init__(self, inp, out):
        super(BFAM, self).__init__()

        self.pre_siam = simam_module()
        self.lat_siam = simam_module()

        out_1 = int(inp / 2)

        self.conv_1 = nn.Conv2d(inp*2, out_1, padding=1, kernel_size=3, groups=out_1,
                                dilation=1)
        self.conv_2 = nn.Conv2d(inp*2, out_1, padding=2, kernel_size=3, groups=out_1,
                                dilation=2)
        self.conv_3 = nn.Conv2d(inp*2, out_1, padding=3, kernel_size=3, groups=out_1,
                                dilation=3)
        self.conv_4 = nn.Conv2d(inp*2, out_1, padding=4, kernel_size=3, groups=out_1,
                                dilation=4)

        self.fuse = nn.Sequential(
            nn.Conv2d(out_1 * 4, out_1*2, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_1*2),
            nn.ReLU(inplace=True)
        )

        self.fuse_siam = simam_module()

        self.out = nn.Sequential(
            nn.Conv2d(out_1*2, out, kernel_size=3, padding=1),
            nn.BatchNorm2d(out),
            nn.ReLU(inplace=True)
        )

    def forward(self, inp1, inp2, last_feature=None):
        x = torch.cat([inp1, inp2], dim=1)
        c1 = self.conv_1(x)
        c2 = self.conv_2(x)
        c3 = self.conv_3(x)
        c4 = self.conv_4(x)
        cat = torch.cat([c1, c2, c3, c4], dim=1)
        fuse = self.fuse(cat)
        inp1_siam = self.pre_siam(inp1)
        inp2_siam = self.lat_siam(inp2)

        inp1_mul = torch.mul(inp1_siam, fuse)
        inp2_mul = torch.mul(inp2_siam, fuse)
        fuse = self.fuse_siam(fuse)
        if last_feature is None:
            out = self.out(fuse + inp1 + inp2 + inp2_mul + inp1_mul)
        else:
            out = self.out(fuse + inp2_mul + inp1_mul + last_feature + inp1 + inp2)
        out = self.fuse_siam(out)

        return out


class diff_moudel(nn.Module):
    def __init__(self, in_channel):
        super(diff_moudel, self).__init__()
        self.avg_pool = nn.AvgPool2d((3, 3), stride=1, padding=1)
        self.conv_1 = nn.Conv2d(in_channel, in_channel, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(in_channel)
        self.sigmoid = nn.Sigmoid()
        self.simam = simam_module()

    def forward(self, x):
        x = self.simam(x)
        edge = x - self.avg_pool(x)  # Xi=X-Avgpool(X)
        weight = self.sigmoid(self.bn1(self.conv_1(edge)))
        # weight = self.conv_1(edge)
        out = weight * x + x
        out = self.simam(out)
        return out


class CBM(nn.Module):
    def __init__(self, in_channel):
        super(CBM, self).__init__()
        self.diff_1 = diff_moudel(in_channel)
        self.diff_2 = diff_moudel(in_channel)
        self.simam = simam_module()

    def forward(self, x1, x2):
        d1 = self.diff_1(x1)
        d2 = self.diff_2(x2)
        d = torch.abs(d1 - d2)
        d = self.simam(d)
        return d


class DFEM(nn.Module):
    def __init__(self, inc, outc):
        super(DFEM, self).__init__()

        self.Conv_1 = nn.Sequential(nn.Conv2d(inc*2, outc, kernel_size=1),
                                    nn.BatchNorm2d(outc),
                                    nn.ReLU(inplace=True)
                                    )

        self.Conv = nn.Sequential(nn.Conv2d(outc, outc, kernel_size=3, stride=1, padding=1),
                                  nn.BatchNorm2d(outc),
                                  nn.ReLU(inplace=True),
                                  )
        self.relu = nn.ReLU(inplace=True)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, diff, accom):
        cat = torch.cat([accom, diff], dim=1)
        cat = self.Conv_1(cat) + diff + accom
        c = self.Conv(cat) + cat
        c = self.relu(c) + diff
        c = self.Up(c)
        return c


if __name__ == '__main__':
    x = torch.randn(4, 64, 128, 128).cuda()
    y = torch.randn(4, 64, 128, 128).cuda()

    # model = BFAM(64,64).cuda()
    # out = model(x,y)

    # model = CBM(64).cuda()
    # out = model(x, y)

    model = DFEM(64,64).cuda()
    out = model(x, y)

    print(out.shape)

### ViT(Vision Transformer)即插即用特征增强模块概述 视觉Transformer (ViT) 是一种基于Transformer架构的模型,在计算机视觉领域取得了显著成果。为了进一步提升其性能,许多研究人员提出了多种即插即用特征增强模块。这些模块可以轻松集成到现有的ViT框架中,从而提高模型的表现能力。 #### 完全注意力网络 (FANs) 完全注意力网络 (Fully Attentional Networks, FANs)[^1] 提出了通过强化自注意力机制来改善ViT的学习能力和鲁棒性。具体来说,FANs引入了一种新的注意力通道模块,该模块增强了自注意力在学习鲁棒特征表示方面的作用。这种方法不仅提升了模型对噪声数据的容忍度,还在面对不同类型的图像损坏时表现得更加稳定。 以下是实现FANs的一个简单代码片段: ```python import torch.nn as nn class FullyAttentionModule(nn.Module): def __init__(self, dim): super(FullyAttentionModule, self).__init__() self.attention = nn.MultiheadAttention(dim, num_heads=8) def forward(self, x): attn_output, _ = self.attention(x, x, x) return attn_output + x # Residual connection ``` #### 协调注意力 (CA - Coordinate Attention) 协调注意力(CA)[^2]是一种轻量级的注意力机制,旨在捕捉空间维度上的依赖关系。它通过对输入特征图的高度和宽度分别应用一维卷积操作,提取坐标级别的上下文信息。此方法能够有效减少计算开销并保持较高的精度增益。 下面展示如何将CA应用于PyTorch中的ViT结构: ```python import torch from torch import nn class CoordAtt(nn.Module): def __init__(self, inp, oup, reduction=32): super(CoordAtt, self).__init__() self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) self.pool_w = nn.AdaptiveAvgPool2d((1, None)) temp_c = max(oup//reduction, 4) self.conv1 = nn.Conv2d(inp, temp_c, kernel_size=1, stride=1, padding=0) self.bn1 = nn.BatchNorm2d(temp_c) self.act = nn.ReLU() self.conv_h = nn.Conv2d(temp_c, oup, kernel_size=1, stride=1, padding=0) self.conv_w = nn.Conv2d(temp_c, oup, kernel_size=1, stride=1, padding=0) def forward(self, x): identity = x n,c,h,w = x.size() x_h = self.pool_h(x) x_w = self.pool_w(x).permute(0, 1, 3, 2) y = torch.cat([x_h, x_w], dim=2) y = self.conv1(y) y = self.bn1(y) y = self.act(y) x_h, x_w = torch.split(y, [h, w], dim=2) x_w = x_w.permute(0, 1, 3, 2) a_h = self.conv_h(x_h).sigmoid() a_w = self.conv_w(x_w).sigmoid() out = identity * a_w * a_h return out ``` 以上两种技术只是众多可用于ViT的即插即用模块的一部分。每种模块都有各自的特点以及适用场景,因此可以根据实际需求选择合适的方案进行实验验证。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值