[2111] Adaptive Fourier Neural Operators: Efficient Token Mixers for Transformers

paper

Abstract

ViTs and ViMLPs

  • similarity: both capture long-term dependencies between spatial locations
  • difference: ViMLPs are simpler than ViTs, replacing self-attention (SA) with fully-connected layers (FCs)
  • problem: large computational complexity $\mathcal{O}(N^2)$ for the token tensor $X\in\Reals^{N\times C}$

existing solutions

  • Swin: replaces global SA with local SA via window partitioning
  • GFNet: proposes a depth-wise global convolution performed in the Fourier domain (a minimal sketch follows this list)
    • main steps
      • 2D fast Fourier transform: transfer input features from the spatial domain to the frequency domain
      • frequency gating: element-wise multiply the frequency features with a learnable global filter
      • 2D inverse fast Fourier transform: transfer the learnt features from the frequency domain back to the spatial domain
    • drawbacks
      • lacks adaptivity and expressiveness at high resolution $\impliedby$ complexity and parameter count grow with the sequence size
      • no channel mixing is involved in the frequency gating operation
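
As a concrete reference for the GFNet-style mixer above, here is a minimal sketch of the three steps (2D FFT, frequency gating, inverse FFT). The function name `gfnet_filter`, the tensor layout, and the use of `rfft2` are illustrative assumptions, not the official GFNet code.

```python
import torch

def gfnet_filter(x: torch.Tensor, global_filter: torch.Tensor) -> torch.Tensor:
    """Depth-wise global convolution in the Fourier domain (GFNet-style sketch).

    x:             real input tokens of shape (B, h, w, d)
    global_filter: complex learnable filter of shape (h, w // 2 + 1, d)
    """
    z = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")    # 2D FFT over the spatial dims
    z = z * global_filter                                # element-wise frequency gating (no channel mixing)
    return torch.fft.irfft2(z, s=x.shape[1:3], dim=(1, 2), norm="ortho")  # back to the spatial domain

# usage with random data
x = torch.randn(2, 14, 14, 64)
filt = torch.randn(14, 14 // 2 + 1, 64, dtype=torch.cfloat)
y = gfnet_filter(x, filt)   # same shape as x
```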

2111_afno_t1
Complexity, parameter count, and interpretation for FNO, AFNO, GFN, and Self-Attention. $N\coloneqq hw$, $d$, and $k$ refer to the sequence size, channel size, and block count, respectively.

contribution

  • propose the adaptive Fourier neural operator (AFNO), an efficient token mixer with quasi-linear complexity in the sequence length
  • improve the expressiveness and generalization of AFNO by imposing block-diagonal structure, adaptive weight sharing, and sparsity

Method

Fourier neural operator

kernel integration

denote $x_{n, m}\in\Reals^d$ as the $(n, m)$-th token in the input tensor $X\in\Reals^{N\times d}$, $N\coloneqq hw$
index the token sequence as $X[s]\coloneqq X[n_s, m_s]$ for $s, t\in[hw]$
definition 1 (self-attention) define the self-attention mixing as $Att: \Reals^{N\times d}\rightarrow\Reals^{N\times d}$
$$Att(X)\coloneqq\mathrm{softmax}\left(\frac{XW_q(XW_k)^T}{\sqrt{d}}\right)XW_v$$

where $W_q, W_k, W_v\in\Reals^{d\times d}$ are the query, key, and value matrices

write self-attention as a kernel integration
define $K\coloneqq\mathrm{softmax}\left(\frac{\langle XW_q, XW_k\rangle}{\sqrt{d}}\right)$ as the attention matrix
treat self-attention as an asymmetric matrix-valued kernel $\kappa: [N]\times[N]\rightarrow\Reals^{d\times d}$ parametrized as $\kappa[s, t]=K[s, t]\circ W_v^T$
view self-attention as a kernel summation
$$Att(X)[s]\coloneqq\sum_{t=1}^N\kappa[s, t]X[t], \quad \forall s\in[N]$$

A brief interpretation of this formula: $Att(X)[s]$ is the $s$-th row of the matrix $Att(X)\in\Reals^{N\times d}$, a $d$-dimensional vector; it is a weighted sum of the $d$-dimensional token vectors $X[t]$, where each $X[t]$ is weighted by the matrix $\kappa[s, t]=K[s, t]\circ W_v^T\in\Reals^{d\times d}$.
ref: zhihu
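
As a numerical sanity check on definition 1 and the kernel-summation view (my own illustration, not from the paper), the snippet below compares the matrix form of self-attention with the explicit sum $\sum_{t}\kappa[s, t]X[t]$, where $\kappa[s, t]=K[s, t]\circ W_v^T$ simply scales $W_v^T$ by the attention weight.

```python
import torch

torch.manual_seed(0)
N, d = 6, 4
X = torch.randn(N, d)
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))

# matrix form: Att(X) = softmax(X Wq (X Wk)^T / sqrt(d)) X Wv
K = torch.softmax((X @ Wq) @ (X @ Wk).T / d ** 0.5, dim=-1)
att = K @ X @ Wv

# kernel-summation form: Att(X)[s] = sum_t kappa[s, t] X[t], with kappa[s, t] = K[s, t] * Wv^T
att_sum = torch.stack([
    sum(K[s, t] * (Wv.T @ X[t]) for t in range(N)) for s in range(N)
])

print(torch.allclose(att, att_sum, atol=1e-5))  # True
```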

extend kernel summation into continuous kernel integrals
the input tensor $X$ is treated as a spatial function in the function space $X\in(D, \Reals^d)$, rather than a finite-dimensional vector in the Euclidean space $X\in\Reals^{N\times d}$
definition 2 (kernel integral) define the kernel integral operator $\mathcal{K}: (D, \Reals^d)\rightarrow(D, \Reals^d)$ as
$$\mathcal{K}(X)(s)=\int_D\kappa(s, t)X(t)\,\mathrm{d}t, \quad \forall s\in D$$

with a continuous kernel function $\kappa: D\times D\rightarrow\Reals^{d\times d}$

the integral leads to a global convolution
definition 3 (global convolution) given the special case of a Green's kernel, $\kappa(s, t)=\kappa(s-t)$, the kernel operator admits
$$\mathcal{K}(X)(s)=\int_D\kappa(s-t)X(t)\,\mathrm{d}t, \quad \forall s\in D$$

convolution has a smaller complexity than general kernel integration
the global convolution can be efficiently implemented with the FFT
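
The efficiency claim follows from the convolution theorem: under periodic (circular) boundary conditions, a global convolution can be evaluated with FFTs in $\mathcal{O}(N\log N)$ instead of the $\mathcal{O}(N^2)$ direct summation. A small 1D illustration (my own example):

```python
import torch

torch.manual_seed(0)
N = 8
x = torch.randn(N)
kappa = torch.randn(N)          # discretized Green's kernel kappa(s - t)

# direct circular convolution: O(N^2)
direct = torch.stack([
    sum(kappa[(s - t) % N] * x[t] for t in range(N)) for s in range(N)
])

# FFT implementation: O(N log N)
via_fft = torch.fft.ifft(torch.fft.fft(kappa) * torch.fft.fft(x)).real

print(torch.allclose(direct, via_fft, atol=1e-5))  # True
```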

Fourier neural operator (FNO)

define the Fourier neural operator as
definition 4 (Fourier neural operator) for a continuous input $X\in(D, \Reals^d)$ and kernel $\kappa$, the kernel integral at token $s$ is given by
$$\mathcal{K}(X)(s)=\mathcal{F}^{-1}(\mathcal{F}(\kappa)\cdot\mathcal{F}(X))(s), \quad \forall s\in D$$

where $\mathcal{F}, \mathcal{F}^{-1}$ are the Fourier transform and its inverse

discrete FNO
for images of finite dimension on a discrete grid, tokens are mixed using the discrete Fourier transform (DFT)
given the input token tensor $X\in\Reals^{h\times w\times d}$, apply the DFT per token $(m, n)\in[h]\times[w]$
step 1 token mixing: discrete $\mathcal{F}(X)$
$$z_{m, n}=[\mathrm{DFT}(X)]_{m, n}$$

step 2 channel mixing: discrete $\mathcal{F}(\kappa)$
$$\tilde{z}_{m, n}=W_{m, n}z_{m, n}$$

where $W_{m, n}\in\Complex^{d\times d}$ is the $(m, n)$-th slice of the complex-valued weight tensor $W\coloneqq\mathrm{DFT}(\kappa)\in\Complex^{h\times w\times d\times d}$ that parametrizes the kernel
step 3 token de-mixing: discrete $\mathcal{F}^{-1}(\tilde{Z})$
$$y_{m, n}=[\mathrm{IDFT}(\tilde{Z})]_{m, n}$$

step 4 add a residual term $x_{m, n}$ (parametrized as a convolution) to $y_{m, n}$
to compensate for local features and non-periodic boundaries
$\impliedby$ the DFT assumes a global convolution applied to periodic images, which is typically not true for real-world images
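
A minimal sketch of the four mixing steps above, with a full $d\times d$ matrix multiply per frequency mode (the point of comparison with AFNO later); the residual is kept as an identity here for simplicity, and the names and shapes are illustrative assumptions.

```python
import torch

def fno_mixer(x: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """Discrete FNO token mixer (sketch).

    x: real tokens of shape (h, w, d)
    W: complex kernel weights DFT(kappa) of shape (h, w, d, d)
    """
    z = torch.fft.fft2(x, dim=(0, 1))                 # step 1: token mixing, z_{m,n} = [DFT(X)]_{m,n}
    z_tilde = torch.einsum("mnij,mnj->mni", W, z)     # step 2: channel mixing, W_{m,n} z_{m,n} per mode
    y = torch.fft.ifft2(z_tilde, dim=(0, 1)).real     # step 3: token de-mixing
    return y + x                                      # step 4: residual (identity here; the paper parametrizes it as a convolution)

h, w, d = 8, 8, 16
x = torch.randn(h, w, d)
W = torch.randn(h, w, d, d, dtype=torch.cfloat)
out = fno_mixer(x, W)   # shape (h, w, d)
```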

conclusions of FNO

  • merits
    • after training at one resolution, it can be directly evaluated at another resolution
    • encode higher-frequency information in channel dimension
  • demerits
    • static weights $W_{m, n}$: not adaptive to different input resolutions

adaptive Fourier neural operator

model architecture

2111_afno_f2
The multi-layer transformer network with FNO, GFN, and AFNO mixers. GFNet performs element-wise matrix multiplication with separate weights across channels ($k$). FNO performs a full matrix multiplication that mixes all the channels. AFNO performs block-wise channel mixing using an MLP along with soft-thresholding. The symbols $h$, $w$, $d$, and $k$ refer to the height, width, channel size, and block count, respectively.

adaptive Fourier neural operator (AFNO)

AFNO mainly modifies step 2 of FNO
impose a block-diagonal structure on $W$, dividing it into $k$ weight blocks of size $\frac{d}{k}\times\frac{d}{k}$
the kernel operates independently on each block
$$\tilde{z}_{m, n}^{(\ell)}=W_{m, n}^{(\ell)}z_{m, n}^{(\ell)}, \quad \ell=1, \ldots, k$$

note that each block can be interpreted as a head, as in multi-head self-attention
the channel mixing is implemented by a 2-layer perceptron for the $(n, m)$-th token
$$\tilde{z}_{m, n}=\mathrm{MLP}(z_{m, n})=W_2\,\mathrm{ReLU}(W_1z_{m, n})+b$$

where $W_1, W_2, b$ are shared across all tokens
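
A sketch of the block-diagonal channel mixing described above: each frequency token is split into $k$ blocks of size $d/k$, and a shared two-layer MLP acts on every block. Applying the ReLU separately to the real and imaginary parts of the complex features is an assumption of this sketch, as are all names and shapes.

```python
import torch

def block_mlp(z: torch.Tensor, W1: torch.Tensor, W2: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Block-diagonal 2-layer MLP on frequency tokens (sketch).

    z:      complex frequency tokens of shape (h, w, d)
    W1, W2: complex weights of shape (k, d_k, d_k) with d_k = d // k, shared across tokens
    b:      complex bias of shape (k, d_k)
    """
    h, w, d = z.shape
    k, d_k = W1.shape[0], W1.shape[1]
    z = z.reshape(h, w, k, d_k)                                     # split channels into k blocks
    hidden = torch.einsum("lij,hwlj->hwli", W1, z)                  # W1 z, independently per block l
    hidden = torch.complex(torch.relu(hidden.real), torch.relu(hidden.imag))  # ReLU on real/imag parts
    out = torch.einsum("lij,hwlj->hwli", W2, hidden) + b            # W2 ReLU(W1 z) + b
    return out.reshape(h, w, d)

h, w, d, k = 8, 8, 16, 4
z = torch.randn(h, w, d, dtype=torch.cfloat)
W1 = torch.randn(k, d // k, d // k, dtype=torch.cfloat)
W2 = torch.randn(k, d // k, d // k, dtype=torch.cfloat)
b = torch.randn(k, d // k, dtype=torch.cfloat)
z_tilde = block_mlp(z, W1, W2, b)   # shape (h, w, d)
```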

images are inherently sparse in the Fourier domain
$\implies$ adaptively mask tokens according to their importance for the end task
use LASSO-style channel mixing to sparsify the frequency tokens
$$\min_{\tilde{z}_{m, n}}\Vert\tilde{z}_{m, n}-W_{m, n}z_{m, n}\Vert^2+\lambda\Vert\tilde{z}_{m, n}\Vert_1$$

this is implemented by the soft-thresholding (shrinkage) operation
$$\begin{aligned} \tilde{z}_{m, n}&=S_{\lambda}(W_{m, n}z_{m, n}) \\ S_{\lambda}(x)&=\mathrm{sign}(x)\max\{\vert x\vert-\lambda, 0\} \end{aligned}$$

where $\lambda$ is a tuning parameter controlling the sparsity
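
The soft-thresholding operator $S_\lambda$ is the closed-form minimizer of the LASSO objective above; a short sketch, with the shrinkage applied to the real and imaginary parts separately as an assumption for complex frequency tokens:

```python
import torch

def soft_threshold(x: torch.Tensor, lam: float) -> torch.Tensor:
    """Soft-thresholding / shrinkage: S_lambda(x) = sign(x) * max(|x| - lambda, 0)."""
    return torch.sign(x) * torch.clamp(x.abs() - lam, min=0.0)

# for complex frequency tokens, one option is to shrink real and imaginary parts separately (assumption)
def soft_threshold_complex(z: torch.Tensor, lam: float) -> torch.Tensor:
    return torch.complex(soft_threshold(z.real, lam), soft_threshold(z.imag, lam))

x = torch.tensor([-2.0, -0.3, 0.0, 0.4, 1.5])
print(soft_threshold(x, lam=0.5))   # [-1.5, 0.0, 0.0, 0.0, 1.0]
```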

Experiment

image classification

dataset ImageNet-1K
loss function cross-entropy
optimizer Adam: 300 epochs, weight decay=0.05, init lr=5e-4, linear warm-up for 5 epochs, cosine decay to 1e-5
max gradient norm clipped to 1.0
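
A configuration sketch matching the settings listed above; the stand-in model, training-loop skeleton, and the use of AdamW (as one way to realize Adam with weight decay) are assumptions of this sketch.

```python
import torch

model = torch.nn.Linear(10, 10)          # hypothetical stand-in for the AFNO-based classifier
epochs, warmup_epochs = 300, 5

# init lr 5e-4, weight decay 0.05
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)

# linear warm-up for 5 epochs, then cosine decay towards 1e-5
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs),
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs, eta_min=1e-5),
    ],
    milestones=[warmup_epochs],
)

for epoch in range(epochs):
    # ... forward pass, cross-entropy loss, loss.backward() ...
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip gradient norm to 1.0
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
```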

2111_afno_t5
ImageNet-1K classification efficiency-accuracy trade-off when the input resolution is $224\times 224$.

image inpainting

dataset ImageNet-1K
optimizer Adam: 100 epochs, weight decay=0.01, init lr=1e-4 for self-attention or 1e-3 for other mixers, cosine decay to 1e-5
max gradient norm clipped to 1.0

2111_afno_t2
Inpainting PSNR and SSIM for ImageNet-1K validation data. AFNO matches the performance of Self-Attention despite using significantly fewer FLOPs.

few-shot segmentation

dataset CelebA-Faces, ADE-Cars, LSUN-Cats
loss function cross-entropy
optimizer 2000 epochs, init lr=1e-4 for self-attention or 1e-3 for other mixers

2111_afno_t3
Few-shot segmentation mIoU for AFNO versus alternative mixers. AFNO surpasses Self-Attention for 2/3 datasets while using fewer FLOPs.

Cityscapes segmentation
  • pre-training
    dataset ImageNet-1K
    optimizer Adam: batch size=1024, 300 epochs, weight decay=0.05, init lr=1e-3, warm-up for 6250 iterations, cosine decay to 1e-5
    max gradient norm clipped to 1.0
  • fine-tuning
    dataset Cityscapes
    optimizer Adam: 450 epochs, weight decay=0.05

2111_afno_t4
mIoU and FLOPs for Cityscapes segmentation at $1024\times 1024$ resolution. Note, both the mixer and total FLOPs are included. For GFN and AFNO, the MLP layers are the bottleneck for the complexity. Also, AFNO-25% only keeps 25% of the low frequency modes, while AFNO-100% keeps all the modes. Results for self-attention cannot be obtained due to the long sequence length in the first few layers.

ablation studies
sparsity threshold

2111_afno_f4
Ablations for the sparsity thresholds and block count measured by inpainting validation PSNR. The results suggest that soft thresholding and blocks are effective.

block count
impact of adaptive weights

2111_afno_t6
Ablations for AFNO versus FNO, AFNO without adaptive weights, and hard thresholding. Results are on inpainting pretraining with 10% of ImageNet along with few-shot segmentation mIoU on CelebAFaces. Hard thresholding only keeps 35% of low frequency modes. AFNO demonstrates superior performance for the same parameter count in both tasks.

comparison to FNO