Paper: Efficient Frequency Domain-based Transformers for High-Quality Image Deblurring
1. Frequency domain-based Self-Attention
Conventional Transformers perform well on image deblurring, but their scaled dot-product attention is expensive: every query token is multiplied against all key tokens, so time and space complexity grow quadratically with the number of tokens. Existing workarounds lower the image resolution, reduce the number of patches, or compute attention along the channel (feature-depth) dimension, but they either lose information or ignore spatial structure. This paper therefore proposes a Frequency domain-based Self-Attention solver (FSAS).
The idea behind FSAS has two parts: (1) scaled dot-product attention measures the correlation between query tokens and key tokens, capturing the relationships between different regions of the image; (2) FFT and IFFT: the fast Fourier transform (FFT) moves the feature maps from the spatial domain to the frequency domain, the correlation is computed there by an element-wise product, and the inverse fast Fourier transform (IFFT) brings the result back to the spatial domain.
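The efficiency gain rests on the convolution theorem: an element-wise product in the frequency domain corresponds to a circular convolution in the spatial domain, so the FFT/product/IFFT pipeline avoids forming the full query-key similarity matrix. The following toy 1-D check of that identity is my own sketch, not code from the paper:

import torch

# Toy check of the convolution theorem that FSAS relies on (not from the paper):
# multiplying two spectra element-wise and transforming back equals the
# circular convolution of the original signals.
N = 8
a = torch.randn(N)
b = torch.randn(N)

# FFT route: element-wise product of the spectra, then inverse FFT.
via_fft = torch.fft.irfft(torch.fft.rfft(a) * torch.fft.rfft(b), n=N)

# Direct route: circular convolution computed explicitly.
direct = torch.zeros(N)
for n in range(N):
    for m in range(N):
        direct[n] += a[m] * b[(n - m) % N]

print(torch.allclose(via_fft, direct, atol=1e-5))  # True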
For an input X, FSAS proceeds as follows (a shape walkthrough of the patch-wise FFT is sketched after this list):
- Feature extraction: a 1x1 convolution followed by a 3x3 depth-wise convolution produces the query feature Fq, the key feature Fk, and the value feature Fv.
- FFT: apply the fast Fourier transform to Fq and Fk, yielding F(Fq) and F(Fk).
- Frequency-domain correlation: multiply F(Fq) and F(Fk) element-wise to obtain A, which encodes the correlation between the query tokens and all key tokens in the frequency domain.
- IFFT: apply the inverse fast Fourier transform to A to return to the spatial domain, giving A'.
- Layer normalization: normalize A' to obtain Vatt (in the implementation below, the normalization is applied after the IFFT).
- Feature modulation: multiply Vatt element-wise with Fv, giving Fv' = Vatt * Fv.
- Output projection: a 1x1 convolution projects Fv' back to the input dimension to produce the output feature; the fusion with the original input X is handled by the residual connection of the enclosing Transformer block rather than inside FSAS.
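To make the tensor shapes concrete, here is a small sketch (toy sizes chosen by me for illustration) of how the feature map is split into 8x8 patches and what torch.fft.rfft2 returns per patch; it mirrors the rearrange/rfft2 calls in the implementation in Section 2:

import torch
from einops import rearrange

# Toy shape walkthrough of the patch-wise FFT in FSAS (sizes assumed for illustration)
patch = 8
feat = torch.randn(1, 16, 32, 32)                        # (b, c, H, W)
patches = rearrange(feat, 'b c (h p1) (w p2) -> b c h w p1 p2',
                    p1=patch, p2=patch)                   # (1, 16, 4, 4, 8, 8)
spec = torch.fft.rfft2(patches.float())                   # (1, 16, 4, 4, 8, 5), complex
print(patches.shape, spec.shape)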
Figure: structure of the Frequency domain-based Self-Attention (FSAS) module.
2. Code Implementation
import numbers

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def to_3d(x):
    # (b, c, h, w) -> (b, h*w, c) so LayerNorm can act on the channel dimension
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    # (b, h*w, c) -> (b, c, h, w)
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
    """LayerNorm variant without mean subtraction or bias (scale only)."""
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
    """Standard LayerNorm with learnable scale and bias."""
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
    """LayerNorm over the channel dimension of a (b, c, h, w) tensor."""
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        # flatten to (b, h*w, c), normalize over c, then reshape back
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)
class FSAS(nn.Module):
    """Frequency domain-based Self-Attention Solver (FSAS)."""
    def __init__(self, dim, bias=False):
        super(FSAS, self).__init__()

        # 1x1 conv + 3x3 depth-wise conv produce q, k, v (each with 2*dim channels)
        self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)
        self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, stride=1, padding=1, groups=dim * 6, bias=bias)

        self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)
        self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')

        # FFTs are computed per non-overlapping patch; H and W must be divisible by this
        self.patch_size = 8

    def forward(self, x):
        hidden = self.to_hidden(x)
        q, k, v = self.to_hidden_dw(hidden).chunk(3, dim=1)

        # split the query/key feature maps into patch_size x patch_size patches
        q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
        k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)

        # element-wise product in the frequency domain replaces the Q @ K^T matrix multiplication
        q_fft = torch.fft.rfft2(q_patch.float())
        k_fft = torch.fft.rfft2(k_patch.float())
        out = q_fft * k_fft

        # back to the spatial domain, then merge the patches again
        out = torch.fft.irfft2(out, s=(self.patch_size, self.patch_size))
        out = rearrange(out, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size,
                        patch2=self.patch_size)

        # layer-normalize the attention map and modulate the value features
        out = self.norm(out)
        output = v * out

        # 1x1 conv projects back to the input channel dimension
        output = self.project_out(output)

        return output
if __name__ == '__main__':
    # quick sanity check: H and W (32, 32) are divisible by patch_size (8)
    x = torch.randn(4, 512, 32, 32)
    model = FSAS(512)
    output = model(x)
    print(output.shape)  # torch.Size([4, 512, 32, 32])
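Note that the rearrange into 8x8 patches requires H and W to be divisible by patch_size. For arbitrary input sizes one could pad before the module and crop afterwards; the wrapper below is a minimal sketch of that idea (fsas_with_padding is a hypothetical helper I introduce here, not part of the paper's code):

import torch.nn.functional as F

def fsas_with_padding(model, x, patch=8):
    # Hypothetical wrapper: pad H/W up to a multiple of `patch`, run FSAS, crop back.
    _, _, h, w = x.shape
    pad_h = (patch - h % patch) % patch
    pad_w = (patch - w % patch) % patch
    x_pad = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')  # (left, right, top, bottom)
    out = model(x_pad)
    return out[:, :, :h, :w]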