模型介绍
Transformer 自 2017 年推出以来,其架构就开始在 NLP 领域占据主导地位。Transformer 应用的唯一限制之一,即 Transformer 关键组件的巨大计算开销–一种自注意力机制,这种机制可以根据序列长度以二次复杂度进行扩展。
基于此,来自谷歌的研究者建议用简单的线性变换替代自注意力子层,该线性变换混合输入 token,以较小的准确率成本损失显著的提高了 transformer 编码器速度。更令人惊讶的是,研究者发现采用标准的、非参数化的傅里叶变换替代自注意力子层,可以在 GLUE 基准测试中实现 92%~97%的 BERT 准确率,在标准的512输入长度的token中,在 GPU 上的训练时间快 80%,在 TPU 上的训练时间快 70%。在更长的输入长度上,FNet模型的训练时间更快。而且FNet 在Long Range Arena基准评估中,与所有的高效 transformer 具有竞争力,同时在所有序列长度上拥有更少的内存占用。
Transformer 自注意力机制使得输入可以用高阶单元表示,从而可以灵活地捕获自然语言中各种语法和语义关系。长期以来,研究人员一直认为,与 Transformer 相关的高复杂性和内存占用量是不可避免的提高性能的折衷方案。但是在本论文中,Google 团队用 FNet 挑战了这一思想,FNet 是一种新颖的模型,在速度、内存占用量和准确率之间取得了很好的平衡。
模型结构
离散傅里叶变换
傅里叶变换将函数分解为其组成频率。 给定一个 n ∈ [ 0 , N − 1 ] n ∈ [0, N − 1] n∈[0,N−1] 的序列 x n {x_n} xn,离散傅里叶变换 (DFT) 由以下公式定义:
对于每个 k,DFT 生成一个新的表示 X k X_k Xk作为所有原始输入标记 x n x_n xn 的总和,具有所谓的“旋转因子”。 计算 DFT 有两种主要方法:快速傅里叶变换 (FFT) 和矩阵乘法。 标准的 FFT 算法是 Cooley–Tukey 算法,它递归地重新表示长度为 N = N1,N2的序列的 DFT,用 N1 个大小为 N2 的较小 DFT 来减少计算时间到 O(N log N)。
另一种方法是简单地将 DFT 矩阵应用于输入序列。 DFT 矩阵 W 是一个 Vandermonde 矩阵,用于单位根直至归一化因子:
其中 n, k = 0, . . . , N − 1。这种矩阵乘法是 O ( N 2 ) O(N^2) O(N2) 运算,它比 FFT 具有更高的渐近复杂度,但在 TPU 上相对较短的序列更快。
FNet结构
FNet 是一种无注意力的 Transformer 架构,其中每一层由一个傅里叶混合子层和一个前馈子层组成。 架构如图所示。
本质上,我们将每个 Transformer 编码器层的自注意力子层替换为傅里叶子层,该子层将 2D DFT 应用于其(序列长度,隐藏维度)embedding 输入 - 一个沿序列的 1D DFT 维度
F
s
e
q
F_{seq}
Fseq 和一个沿隐藏维度
F
h
F_h
Fh 的一维 DFT:
如上式所示,只保留结果的实部; 因此,不需要修改(非线性)前馈子层或输出层来处理复数。 我们发现,仅在傅里叶子层末尾提取总变换的实部时,也就是说,在应用了
F
s
e
q
F_{seq}
Fseq 和
F
h
F_h
Fh之后,FNet 获得了最好的结果。
傅里叶变换最简单的解释是作为混合Token的一种特别有效的机制,这显然为前馈子层提供了对所有Token的充分访问。由于傅里叶变换的对偶性,我们还可以将每个交替编码器块视为应用交替傅里叶和傅里叶逆变换,在时域和频域之间来回变换输入。 因为在频域中乘以前馈子层系数相当于在时域中进行卷积(与一组相关的系数),所以 FNet 可以被认为是乘法和(大核)卷积之间的交替。
模型参考
论文地址:https://arxiv.org/abs/2105.03824
代码地址:https://github.com/google-research/google-research/tree/master/f_net