CVPR 2025 自适应矩形卷积模块,即插即用

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

来源:ai缝合大王

仅用于学术分享,若侵权请联系删除

论文地址: https://arxiv.org/pdf/2503.00467

图片

创新点

1. 提出“自适应矩形卷积(ARConv)”模块

  • 高度和宽度自适应学习:与传统卷积固定的正方形感受野不同,ARConv能够动态学习卷积核的高度和宽度,生成矩形卷积核,从而根据图像中不同物体的尺寸,灵活调整卷积窗口的形状。

  • 动态采样点数量调整:ARConv不仅能改变卷积核的形状,还能根据学习到的高度和宽度动态调整采样点数量,解决了以往卷积中采样点数量固定、难以适配多尺度目标的问题。

  • 计算复杂度友好:相较于传统可变形卷积需要学习大量偏移参数,ARConv仅需学习两个参数(高度、宽度),在小数据集(如遥感图像融合)任务中具有更好的收敛性。

2. 提出ARNet网络框架

  • 以U-Net为基础,使用ARConv替代U-Net中的标准卷积,设计出更具尺度自适应能力的网络ARNet,在遥感图像的pansharpening任务上实现更优的特征提取和融合效果。

3. 新颖的“仿射变换”机制

  • 在卷积操作完成后,ARConv还引入仿射变换模块,为特征图提供空间变换能力,进一步提升模型对不同形状和位置目标的适应性。

4. 泛化性和可移植性

  • ARConv设计为即插即用模块,论文通过实验证明将ARConv集成到现有的pansharpening网络(如FusionNet、LAGNet、CANNet)中后,均能有效提升这些网络的性能。

方法

整体架构

       ARNet模型整体采用U-Net风格的Encoder-Decoder结构,结合跳跃连接以保留高分辨率特征信息。在编码器和解码器的各个层中,ARNet用自适应矩形卷积(ARConv)替代了标准卷积,形成了自定义的AR-ResBlock模块,使网络具备根据图像中不同目标尺度自适应调整卷积核形状和采样点数量的能力。同时,ARNet在解码阶段通过转置卷积进行逐步上采样,最终输出高分辨率多光谱图像(HRMS)。整体结构实现了高效的空间-光谱信息融合与细节恢复,显著提升遥感图像pansharpening任务中的图像质量。

图片

1. 整体架构:Encoder-Decoder结构

  • Encoder部分(下采样)

    • 类似U-Net,ARNet的下采样部分由多个 Down-Block 组成。

    • 每个Down-Block内包含了带有ARConv的AR-ResBlock模块,即在ResBlock(残差块)中用ARConv替代了普通卷积,能够在下采样过程中自适应捕获不同尺度的特征。

  • Decoder部分(上采样)

    • 上采样阶段使用 Up-Block,配合转置卷积逐步恢复空间分辨率。

    • 同样,ARConv模块嵌入在解码器的残差块中,持续进行尺度自适应特征提取。

2. 跳跃连接(Skip Connections)

  • 依然保持U-Net的特点,将Encoder中的浅层特征通过跳跃连接传递到Decoder对应层,保证高分辨率的空间信息不丢失,提升细节恢复能力。

3. 输入与输出

  • 输入:将高分辨率全色图像(PAN)与上采样到相同尺寸的低分辨率多光谱图像(LRMS)进行通道拼接(concat),作为网络输入。

  • 输出:输出高分辨率多光谱图像(HRMS),即融合后的图像。

4. 模块特色

  • AR-ResBlock:标准ResBlock的替代版,核心为ARConv模块。

  • ARConv模块:在特征提取层代替普通卷积,实现卷积核尺寸自适应 + 采样点动态选择 + 仿射空间变换,提升多尺度特征的提取能力。

图片

  • Affine Transformation:在卷积操作之后,针对每个位置输出特征图引入仿射变换,增强空间适配能力。

即插即用模块作用

ARConv 作为一个即插即用模块主要适用于

遥感图像处理任务(尤其是pansharpening任务)
  • 特点:遥感图像中存在大量尺度差异巨大的目标,比如小尺度的车辆和大尺度的建筑物、森林、道路等。

  • 适用性:ARConv可以自适应调整卷积核的形状(矩形)和采样点数量,灵活适应不同尺度、不同形状的地物特征,特别适合处理遥感图像中的“多尺度、多形状”物体。

多尺度、多形状特征显著的图像处理场景
  • 医学影像(病灶区域大小不一)卫星影像超分辨率目标检测与分割任务等。

  • 适用性:ARConv在这些任务中也能作为卷积替代模块,提升模型对尺度变化显著或目标形状多样化场景的感知和提取能力。

ARConv 作为一个即插即用模块的作用

🌟 (1)增强多尺度特征提取能力
  • 自动学习每个位置的卷积核“高度”和“宽度”,自适应形成“矩形卷积”,不再局限于固定的正方形窗口。

  • 能根据图像内容动态调整“采样点数量”,在大目标区域密集采样,小目标区域稀疏采样,精细捕捉不同尺度的空间特征。

🌟 (2)减少参数、降低计算量
  • 相较于传统的可变形卷积(Deformable Conv)需要为每个采样点学习偏移量,ARConv只需学习两个参数(h 和 w),在小数据集(如遥感、医疗等)上更容易训练收敛,计算开销更低。

🌟 (3)提升空间适应性与灵活性
  • 结合仿射变换,ARConv不仅可变形、可调整采样密度,还能在输出阶段进一步提升特征图的空间适配能力,更好应对目标位移、扭曲等问题。

消融实验结果

图片

内容:测试了ARConv三个关键组件的作用,具体为:

  1. HWA(Height and Width Adaptation,高度和宽度自适应)

  2. NSPA(Number of Sampling Points Adaptation,采样点数量自适应)

  3. AT(Affine Transformation,仿射变换)

结果

    • 去掉HWA或NSPA后,模型性能有轻微下降,说明自适应卷积核的形状和采样点数量对特征提取是有效的。

    • 去掉仿射变换(No AT)后,性能下降更明显,说明仿射变换进一步提升了空间适应能力。

结论:这表明三个组件在ARConv中的设计均有助于性能提升,且仿射变换对模型的空间灵活性尤为重要。

图片

内容:探究ARConv中卷积核高度和宽度学习范围(kernel size range)对性能的影响,测试了以下范围:

    • 1-3

    • 1-9

    • 1-18(最佳)

    • 1-36

    • 1-63

结果

    • 1-18范围下,模型获得了最优的性能,既能避免采样点过于稠密带来的噪声,也避免采样点过于稀疏造成的信息丢失。

    • 当卷积核范围过大(如1-36或1-63),性能反而下降,说明ARConv需要合理的尺度区间,才能兼顾全局感知和细节捕捉

结论:适度的卷积核自适应范围能够平衡特征提取的“密度”和“尺度”,进一步验证了ARConv中动态控制核大小的必要性。

即插即用模块

import torch
import torch.nn as nn

class ARConv(nn.Module):
    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, l_max=9, w_max=9, flag=False, modulation=True):
        super(ARConv, self).__init__()
        self.lmax = l_max
        self.wmax = w_max
        self.inc = inc
        self.outc = outc
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        self.flag = flag
        self.modulation = modulation
        self.i_list = [33, 35, 53, 37, 73, 55, 57, 75, 77]
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(inc, outc, kernel_size=(i // 10, i % 10), stride=(i // 10, i % 10), padding=0)
                for i inself.i_list
            ]
        )
        self.m_conv = nn.Sequential(
            nn.Conv2d(inc, outc, kernel_size=3, padding=1, stride=stride),
            nn.LeakyReLU(),
            nn.Dropout2d(0.3),
            nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride),
            nn.LeakyReLU(),
            nn.Dropout2d(0.3),
            nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride),
            nn.Tanh()
        )
        self.b_conv = nn.Sequential(
            nn.Conv2d(inc, outc, kernel_size=3, padding=1, stride=stride),
            nn.LeakyReLU(),
            nn.Dropout2d(0.3),
            nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride),
            nn.LeakyReLU(),
            nn.Dropout2d(0.3),
            nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=stride)
        )
        self.p_conv = nn.Sequential(
            nn.Conv2d(inc, inc, kernel_size=3, padding=1, stride=stride),
            nn.BatchNorm2d(inc),
            nn.LeakyReLU(),
            nn.Dropout2d(0),
            nn.Conv2d(inc, inc, kernel_size=3, padding=1, stride=stride),
            nn.BatchNorm2d(inc),
            nn.LeakyReLU(),
        )
        self.l_conv = nn.Sequential(
            nn.Conv2d(inc, 1, kernel_size=3, padding=1, stride=stride),
            nn.BatchNorm2d(1),
            nn.LeakyReLU(),
            nn.Dropout2d(0),
            nn.Conv2d(1, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.w_conv = nn.Sequential(
            nn.Conv2d(inc, 1, kernel_size=3, padding=1, stride=stride),
            nn.BatchNorm2d(1),
            nn.LeakyReLU(),
            nn.Dropout2d(0),
            nn.Conv2d(1, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.dropout1 = nn.Dropout(0.3)
        self.dropout2 = nn.Dropout2d(0.3)
        self.hook_handles = []
        self.hook_handles.append(self.m_conv[0].register_full_backward_hook(self._set_lr))
        self.hook_handles.append(self.m_conv[1].register_full_backward_hook(self._set_lr))
        self.hook_handles.append(self.b_conv[0].register_full_backward_hook(self._set_lr))
        self.hook_handles.append(self.b_conv[1].register_full_backward_hook(self._set_lr))
        self.hook_handles.append(self.p_conv[0].register_full_backward_hook(self._set_lr))
        self.hook_handles.append(self.p_conv[1].register_full_backward_hook(self._set_lr))
        self.hook_handles.append(self.l_conv[0].register_full_backward_hook(self._set_lr))
        self.hook_handles.append(self.l_conv[1].register_full_backward_hook(self._set_lr))
        self.hook_handles.append(self.w_conv[0].register_full_backward_hook(self._set_lr))
        self.hook_handles.append(self.w_conv[1].register_full_backward_hook(self._set_lr))

        self.reserved_NXY = nn.Parameter(torch.tensor([3, 3], dtype=torch.int32), requires_grad=False)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        grad_input = tuple(g * 0.1if g is not None else None for g in grad_input)
        grad_output = tuple(g * 0.1if g is not None else None for g in grad_output)
        return grad_input

    def remove_hooks(self):
        for handle inself.hook_handles:
            handle.remove() # 移除钩子函数
        self.hook_handles.clear() # 清空句柄列表

    def forward(self, x, epoch, hw_range):
        assert isinstance(hw_range, list) and len(hw_range) == 2, "hw_range should be a list with 2 elements, represent the range of h w"
        scale = hw_range[1] // 9
        if hw_range[0] == 1 and hw_range[1] == 3:
            scale = 1
        m = self.m_conv(x)
        bias = self.b_conv(x)
        offset = self.p_conv(x * 100)
        l = self.l_conv(offset) * (hw_range[1] - 1) + 1# b, 1, h, w
        w = self.w_conv(offset) * (hw_range[1] - 1) + 1# b, 1, h, w
        if epoch <= 100:
            mean_l = l.mean(dim=0).mean(dim=1).mean(dim=1)
            mean_w = w.mean(dim=0).mean(dim=1).mean(dim=1)
            N_X = int(mean_l // scale)
            N_Y = int(mean_w // scale)
            def phi(x):
                if x % 2 == 0:
                    x -= 1
                return x
            N_X, N_Y = phi(N_X), phi(N_Y)
            N_X, N_Y = max(N_X, 3), max(N_Y, 3)
            N_X, N_Y = min(N_X, 7), min(N_Y, 7)
            if epoch == 100:
                self.reserved_NXY = self.reserved_NXY = nn.Parameter(
                    torch.tensor([N_X, N_Y], dtype=torch.int32, device=x.device),
                    requires_grad=False
                )
        else:
            N_X = self.reserved_NXY[0]
            N_Y = self.reserved_NXY[1]

        N = N_X * N_Y
        # print(N_X, N_Y)
        l = l.repeat([1, N, 1, 1])
        w = w.repeat([1, N, 1, 1])
        offset = torch.cat((l, w), dim=1)
        dtype = offset.data.type()
        ifself.padding:
            x = self.zero_padding(x)
        p = self._get_p(offset, dtype, N_X, N_Y) # (b, 2*N, h, w)
        p = p.contiguous().permute(0, 2, 3, 1) # (b, h, w, 2*N)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1
        q_lt = torch.cat(
            [
                torch.clamp(q_lt[..., :N], 0, x.size(2) - 1),
                torch.clamp(q_lt[..., N:], 0, x.size(3) - 1),
            ],
            dim=-1,
        ).long()
        q_rb = torch.cat(
            [
                torch.clamp(q_rb[..., :N], 0, x.size(2) - 1),
                torch.clamp(q_rb[..., N:], 0, x.size(3) - 1),
            ],
            dim=-1,
        ).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
        # clip p
        p = torch.cat(
            [
                torch.clamp(p[..., :N], 0, x.size(2) - 1),
                torch.clamp(p[..., N:], 0, x.size(3) - 1),
            ],
            dim=-1,
        )
        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (
                1 + (q_lt[..., N:].type_as(p) - p[..., N:])
        )
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (
                1 - (q_rb[..., N:].type_as(p) - p[..., N:])
        )
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (
                1 - (q_lb[..., N:].type_as(p) - p[..., N:])
        )
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (
                1 + (q_rt[..., N:].type_as(p) - p[..., N:])
        )
        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)
        # (b, c, h, w, N)
        x_offset = (
                g_lt.unsqueeze(dim=1) * x_q_lt
                + g_rb.unsqueeze(dim=1) * x_q_rb
                + g_lb.unsqueeze(dim=1) * x_q_lb
                + g_rt.unsqueeze(dim=1) * x_q_rt
        )
        x_offset = self._reshape_x_offset(x_offset, N_X, N_Y)
        x_offset = self.dropout2(x_offset)
        x_offset = self.convs[self.i_list.index(N_X * 10 + N_Y)](x_offset)
        out = x_offset * m + bias
        returnout

    def _get_p_n(self, N, dtype, n_x, n_y):
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(n_x - 1) // 2, (n_x - 1) // 2 + 1),
            torch.arange(-(n_y - 1) // 2, (n_y - 1) // 2 + 1),
        )
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
        return p_n

    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h * self.stride + 1, self.stride),
            torch.arange(1, w * self.stride + 1, self.stride),
        )
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
        return p_0

    def _get_p(self, offset, dtype, n_x, n_y):
        N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)
        L, W = offset.split([N, N], dim=1)
        L = L / n_x
        W = W / n_y
        offsett = torch.cat([L, W], dim=1)
        p_n = self._get_p_n(N, dtype, n_x, n_y)
        p_n = p_n.repeat([1, 1, h, w])
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + offsett * p_n
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        x = x.contiguous().view(b, c, -1)
        index = q[..., :N] * padded_w + q[..., N:]
        index = (
            index.contiguous()
            .unsqueeze(dim=1)
            .expand(-1, c, -1, -1, -1)
            .contiguous()
            .view(b, c, -1)
        )
        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, n_x, n_y):
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s + n_y].contiguous().view(b, c, h, w * n_y) for s in range(0, N, n_y)],
                             dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h * n_x, w * n_y)
        return x_offset

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。


下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。


下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。


交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
### 傅里叶自注意力模块即插即用实现介绍 #### 背景与优势 傅里叶自注意力模块是一种新颖的设计,在CVPR2023上被提出,旨在通过频域操作来增强计算机视觉模型的能力[^1]。该模块利用快速傅里叶变换(FFT)和逆快速傅里叶变换(IFFT),可以在对数线性时间复杂度内完成全局token混合,相比传统的基于空间域的方法具有显著的速度提升和更低的时间消耗[^2]。 #### 工作原理概述 此模块的核心在于它能在频率空间中有效地捕捉长程依赖关系,并且由于其高效的计算方式,可以轻松嵌入现有的神经网络架构而不增加过多负担。具体来说: - **输入转换**:原始图像数据先经过一次前向传播得到特征映射; - **频谱变换**:接着应用FFT将这些特征映射转化为对应的频谱表示形式; - **注意力建模**:在此基础上构建特殊的自注意力机制来进行信息交互; - **反变换重构**:最后再经由IFFT返回至原来的维度大小以便后续层继续处理。 这种设计不仅使得模型能够更加快捷地获取远距离像素间的关系,同时也因为减少了不必要的冗余运算而提高了整体性能表现。 #### Python代码示例 下面给出了一段简化版本的Python伪代码用于说明如何在一个典型的卷积神经网络(CNN)框架内部集成这样一个傅里叶自注意力单元: ```python import torch.nn as nn from fft_attention import FourierSelfAttentionLayer # 自定义库导入 class EnhancedCNN(nn.Module): def __init__(self, ...): super().__init__() self.conv_layers = ... self.fft_attn_layer = FourierSelfAttentionLayer() # 插入傅里叶自注意力层 def forward(self, x): out = self.conv_layers(x) enhanced_out = self.fft_attn_layer(out) # 应用傅里叶自注意力 return final_output # 继续其他层的操作... ``` 上述例子展示了怎样简单地把`FourierSelfAttentionLayer`作为一个组件加入到已有的网络结构当中去,从而获得更好的效果。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值