(即插即用模块-Attention部分) 三十五、(CVPR 2023) FSA 频域自注意力

在这里插入图片描述

paper:Efficient Frequency Domain-based Transformers for High-Quality Image Deblurring

Code:https://github.com/kkkls/FFTformer


1、Frequency domain-based Self-Attention

传统的 Transformer 模型在图像去模糊任务中表现出色,但其 scaled dot-product attention 计算过程存在效率问题。该计算需要将 query token 与所有 key token 进行矩阵乘法,导致时间复杂度和空间复杂度较为复杂。为了解决这个问题,研究者们提出了多种方法,例如降低图像分辨率、减少 patch 数量或基于特征深度域计算 attention。然而,这些方法都存在一些缺点,例如信息损失或忽略空间信息。所以这篇论文提出一种 频域自注意力(Frequency domain-based Self-Attention)

FSA的基本思想包含以下两点:1、scaled dot-product attention,该操作旨在计算 query token 与 key token 之间的相关性,从而获取图像中不同区域之间的联系信息。2、FFT 和 IFFT: 使用快速傅里叶变换 (FFT) 将特征图从空间域转换到频率域,然后进行点乘操作,最后使用逆快速傅里叶变换 (IFFT) 将结果转换回空间域。

对于输入X,FSA 的实现过程:

  1. 特征提取: 使用 1x1 卷积和 3x3 深度卷积分别提取 query 特征 Fq、key 特征 Fk 和 value 特征 Fv。
  2. FFT: 对 Fq 和 Fk 进行快速傅里叶变换,得到 F(Fq) 和 F(Fk)。
  3. 频率域相关性计算: 使用 FFT 结果进行点乘操作,得到特征 A,其中 A 表示 query token 与所有 key token 在频率域的相关性。
  4. 层归一化: 对 A 进行层归一化,得到 A’ 。
  5. IFFT: 对 A’ 进行逆快速傅里叶变换,得到特征 Vatt。
  6. 特征融合: 将 Vatt 与 Fv 进行点乘操作,得到 Fv’ = Vatt * Fv。
  7. 输出特征: 使用 1x1 卷积将 Fv’ 与原始特征 X 进行融合,得到输出特征。

Frequency domain-based Self-Attention 结构图:
在这里插入图片描述

2、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import numbers
from einops.einops import rearrange


def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


class FSAS(nn.Module):
    def __init__(self, dim, bias=False):
        super(FSAS, self).__init__()

        self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)
        self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, stride=1, padding=1, groups=dim * 6, bias=bias)

        self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)

        self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')

        self.patch_size = 8

    def forward(self, x):
        hidden = self.to_hidden(x)

        q, k, v = self.to_hidden_dw(hidden).chunk(3, dim=1)

        q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
        k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
        q_fft = torch.fft.rfft2(q_patch.float())
        k_fft = torch.fft.rfft2(k_patch.float())

        out = q_fft * k_fft
        out = torch.fft.irfft2(out, s=(self.patch_size, self.patch_size))
        out = rearrange(out, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size,
                        patch2=self.patch_size)

        out = self.norm(out)

        output = v * out
        output = self.project_out(output)

        return output


if __name__ == '__main__':
    x = torch.randn(4, 512, 32, 32)
    model = FSAS(512)
    output = model(x)
    print(output.shape)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值