文章目录
- 1、Global Filter
- 2、代码实现
1、Global Filter
自注意力机制和纯 MLP 模型在视觉任务中展现出潜力,但计算复杂度高,难以扩展到高分辨率特征。而局部自注意力机制虽有效,但引入了人为选择和限制感受野。对此,论文首先分析了傅里叶变换,指出其是分析图像频谱信息的重要工具,具有对数线性复杂度,能够高效地处理全局信息。并基于此提出一种 全局滤波器(Global Filter)。
GlobalFilter 的基本思想是利用傅里叶变换将空间特征转换为频率域,学习空间位置的长期依赖关系。使用可学习的全局滤波器对频率域特征进行逐元素乘法,捕获不同位置之间的交互。最后通过傅里叶逆变换将特征映射回空间域。
对于输入X,Global Filter 的实现过程:
- 傅里叶变换:对特征图进行二维傅里叶变换,将空间信息转换为频率域。
- 滤波:使用可学习的全局滤波器对频率域特征进行逐元素乘法,模拟不同频率成分的交互。
- 逆变换:对滤波后的频率域特征进行二维傅里叶逆变换,将特征映射回空间域,即为最终输出。
与现有的滤波器相比,Global Filter 具有以下优势:
- 高效:傅里叶变换和逆变换具有对数线性复杂度,比自注意力和 MLP 更高效。
- 全局信息:能够有效地捕获全局空间信息,避免局部自注意力机制的局限性。
- 灵活性:可通过调整滤波器设计,控制模型对不同频率成分的关注程度。
Global Filter 结构图:
2、代码实现
import torch
import math
from torch import nn
from einops.einops import rearrange
class GlobalFilter(nn.Module):
def __init__(self, dim, h=14, w=8):
super().__init__()
self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
self.w = w
self.h = h
def forward(self, x, spatial_size=None):
B, N, C = x.shape
if spatial_size is None:
a = b = int(math.sqrt(N))
else:
a, b = spatial_size
x = x.view(B, a, b, C)
x = x.to(torch.float32)
x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
weight = torch.view_as_complex(self.complex_weight)
x = x * weight
x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')
x = x.reshape(B, N, C)
return x
if __name__ == '__main__':
H, W = 14, 14
x = torch.randn(4, 384, 14, 14).cuda()
x = rearrange(x, 'b c h w -> b (h w) c')
model = GlobalFilter(384, h=H, w=H//2 + 1).cuda()
out = model(x)
out = rearrange(out, 'b (h w) c -> b c h w', h=H, w=W)
print(out.shape)