(即插即用模块-Attention部分) 六十三、(2024 CVPR) MLKA 多尺度大核注意力

在这里插入图片描述

paper:MULTI-SCALE ATTENTION NETWORK FOR SINGLE IMAGE SUPER-RESOLUTION

Code:https://github.com/icandle/MAN


1、Multi-scale Large Kernel Attention

为了解决如何有效地建立不同区域之间的长距离相关性,并避免由于大卷积核带来的“块效应”问题。这篇论文在 LKA 的基础上提出了一种 多尺度大核注意力(Multi-scale Large Kernel Attention),MLKA 的设计动机是为了解决图像超分辨率任务中,MLKA 结合了 大卷积核分解 和 多尺度机制 来实现这一目标。

MLKA 的实现过程:

  1. 输入特征图 X: 输入特征图 X 被分解成多个组,每个组包含相同数量的通道。
  2. LKA 模块: 对每个组应用 LKA 模块,生成不同尺度上的注意力图 LKAi。
  3. 门控模块: 为了避免扩张卷积带来的“块效应”,对每个组生成的注意力图进行动态重校准。这样可以更好地保留局部纹理信息。通过对每个 LKAi 应用门控模块,生成门控注意力图 MLKAi。
  4. 聚合: 将所有 MLKAi 聚合,得到最终的注意力图。

MLKA 的优势:

  • 更全面的长距离相关性学习: 通过多尺度机制,MLKA 可以学习不同尺度上的长距离相关性,从而更好地恢复图像细节。
  • 避免“块效应”: 通过门控机制,MLKA 可以有效地避免扩张卷积带来的“块效应”,从而更好地保留图像的平滑性。
  • 计算效率高: MLKA 通过大卷积核分解和门控机制,实现了计算效率的提升。

Multi-scale Large Kernel Attention 结构图:
在这里插入图片描述


2、代码实现

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class MLKA(nn.Module):
    def __init__(self, n_feats, k=2, squeeze_factor=15):
        super().__init__()
        i_feats = 2 * n_feats

        self.norm = LayerNorm(n_feats, data_format='channels_first')
        self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)

        # Multiscale Large Kernel Attention
        self.LKA7 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 9, stride=1, padding=(9 // 2) * 4, groups=n_feats // 3, dilation=4),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
        self.LKA5 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 7, stride=1, padding=(7 // 2) * 3, groups=n_feats // 3, dilation=3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))
        self.LKA3 = nn.Sequential(
            nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3),
            nn.Conv2d(n_feats // 3, n_feats // 3, 5, stride=1, padding=(5 // 2) * 2, groups=n_feats // 3, dilation=2),
            nn.Conv2d(n_feats // 3, n_feats // 3, 1, 1, 0))

        self.X3 = nn.Conv2d(n_feats // 3, n_feats // 3, 3, 1, 1, groups=n_feats // 3)
        self.X5 = nn.Conv2d(n_feats // 3, n_feats // 3, 5, 1, 5 // 2, groups=n_feats // 3)
        self.X7 = nn.Conv2d(n_feats // 3, n_feats // 3, 7, 1, 7 // 2, groups=n_feats // 3)

        self.proj_first = nn.Sequential(
            nn.Conv2d(n_feats, i_feats, 1, 1, 0))

        self.proj_last = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 1, 1, 0))

    def forward(self, x, pre_attn=None, RAA=None):
        shortcut = x.clone()

        x = self.norm(x)

        x = self.proj_first(x)

        a, x = torch.chunk(x, 2, dim=1)

        a_1, a_2, a_3 = torch.chunk(a, 3, dim=1)

        a = torch.cat([self.LKA3(a_1) * self.X3(a_1), self.LKA5(a_2) * self.X5(a_2), self.LKA7(a_3) * self.X7(a_3)],
                      dim=1)

        x = self.proj_last(x * a) * self.scale + shortcut

        return x


if __name__ == '__main__':
    x = torch.randn(4, 360, 64, 64).cuda()
    model = MLKA(360).cuda()
    out = model(x)
    print(out.shape)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值