论文阅读:Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction

这篇文章是 NeruIPS 2024 best paper,是北京大学与字节跳动联合做的一项工作,有关自回归模型的。

Abstract

本文提出了视觉自回归建模(VAR),这是一种新的生成范式,它将图像上的自回归学习重新定义为从粗到细的 “下一级尺度预测” 或 “下一级分辨率预测”,有别于标准的光栅扫描 “下一个标记预测”。这种简单、直观的方法使自回归(AR)变换器能够快速学习视觉分布并能很好地泛化:VAR 首次让类似 GPT 风格的自回归模型在图像生成方面超越了扩散变换器。在 ImageNet 256×256 基准测试中,VAR 通过将弗雷歇初始距离(FID)从 18.65 显著提升到 1.73,初始分数(IS)从 80.4 提升到 350.2,大幅改进了自回归基线,并且推理速度快了 20 倍。同时,经验证 VAR 在包括图像质量、推理速度、数据效率和可扩展性等多个维度上都优于扩散变换器(DiT)。对 VAR 模型进行扩展呈现出与大型语言模型(LLMs)中所观察到的类似的清晰幂律扩展规律,线性相关系数接近 -0.998 就是有力证据。VAR 还进一步展示了在下游任务(包括图像补绘、外绘和编辑)中的零次学习泛化能力。这些结果表明 VAR 初步模拟了大型语言模型的两个重要特性:扩展规律和零次学习泛化。

在这里插入图片描述

  • 图 1:由在 ImageNet 数据集上训练的视觉自回归(VAR)变换器生成的样本。展示了 512×512 像素的样本(上方)、256×256 像素的样本(中间)以及零样本图像编辑结果(下方)。

Introduction

在这里插入图片描述

  • 图 2:标准自回归建模(AR)与我们提出的视觉自回归建模(VAR)对比。
    (a)应用于语言的自回归模型:按从左到右的顺序逐词生成连续的文本标记;
    (b)应用于图像的自回归模型:按光栅扫描顺序(从左到右、从上到下)依次生成视觉标记;
    (c)用于图像的视觉自回归建模(VAR):多尺度标记图以从粗到细的尺度(从低分辨率到高分辨率)进行自回归生成,并且在每个尺度内并行生成标记。视觉自回归建模(VAR)需要一个多尺度矢量量化变分自编码器(VQVAE)才能发挥作用。

GPT 系列以及更多自回归(AR)大型语言模型(LLMs)的问世,开启了人工智能领域的新纪元。这些模型在通用性和多功能性方面展现出了可观的智能,尽管存在诸如幻觉之类的问题,但仍被视为朝着通用人工智能(AGI)迈出了坚实的一步。这些模型的核心是一种自监督学习策略 —— 预测序列中的下一个标记,这是一种简单却意义深远的方法。对这些大型自回归模型成功原因的研究凸显了它们的可扩展性和泛化能力:前者以扩展规律为例,能让我们根据较小模型的性能来预测大型模型的性能,进而指导更好的资源分配;而后者通过零次学习和少次学习得以证明,强调了无监督训练模型对各种未见过任务的适应性。这些特性揭示了自回归模型从大量无标记数据中学习的潜力,体现了 “通用人工智能(AGI)” 的精髓。

与此同时,计算机视觉领域一直在努力开发大型自回归模型或世界模型,旨在模拟它们令人瞩目的可扩展性和泛化能力。诸如 VQGAN 和 DALL-E 等开创性的成果以及它们的后续模型展示了自回归模型在图像生成方面的潜力。这些模型利用视觉标记器将连续图像离散化为二维标记网格,然后将其扁平化为一维序列以进行自回归学习(图 2b),这与顺序语言建模的过程类似(图 2a)。然而,这些模型的扩展规律仍未得到充分探究,更令人沮丧的是,如图 3 所示,它们的性能明显落后于扩散模型。与大型语言模型取得的显著成就形成对比的是,自回归模型在计算机视觉领域的威力似乎在一定程度上受到了限制。

在这里插入图片描述

  • 图 3:不同模型系列在 ImageNet 256×256 图像生成基准测试上的扩展特性。验证集的 FID 可作为参考下限(1.78)。具有 20 亿(2B)参数的视觉自回归建模(VAR)达到了 1.73 的 FID,超过了具有 30 亿(3B)或 70 亿(7B)参数的分层离散图像变换器(L-DiT)。

自回归建模需要定义数据的顺序。我们的工作重新思考了如何给图像 “排序”:人类通常以一种分层的方式感知或创作图像,先是捕捉整体结构,然后是局部细节。这种多尺度、从粗到细的特性为图像暗示了一种 “顺序”。同样受到广泛的多尺度设计的启发,我们将图像的自回归学习定义为图 2(c)中的 “下一级尺度预测”,有别于图 2(b)中传统的 “下一个标记预测”。我们的方法首先将一幅图像编码为多尺度标记图。然后自回归过程从 1×1 标记图开始,并逐步提高分辨率进行扩展:在每一步中,变换器会基于之前所有的标记图来预测下一个更高分辨率的标记图。我们将这种方法称为视觉自回归(VAR)建模。

视觉自回归(VAR)建模直接利用类似 GPT-2 的变换器架构来进行视觉自回归学习。在 ImageNet 256×256 基准测试中,VAR 显著提升了其自回归基线,实现了 1.73 的弗雷歇初始距离(FID)以及 350.2 的初始分数(IS),并且推理速度快了 20 倍。值得注意的是,VAR 在弗雷歇初始距离 / 初始分数、数据效率、推理速度以及可扩展性方面超越了扩散变换器(DiT)—— 而扩散变换器(DiT)是诸如稳定扩散 3.0 和索拉(SORA)等领先扩散系统的基础。VAR 模型还展现出了与大型语言模型(LLMs)中类似的扩展规律。最后,我们展示了 VAR 在图像补绘、外绘以及编辑等任务中的零次学习泛化能力。总之,我们对该领域的贡献包括:

  • 一种采用带有下一级尺度预测的多尺度自回归范式的新型视觉生成框架,为计算机视觉领域的自回归算法设计提供了新的思路。
  • 对视觉自回归(VAR)模型的扩展规律和零次学习泛化潜力进行了实证验证,初步模拟了大型语言模型(LLMs)的吸引人的特性。
  • 在视觉自回归模型性能方面取得了突破,使得类似 GPT 风格的自回归方法首次在图像合成方面超越了强大的扩散模型。
  • 一套全面的开源代码集,包含了矢量量化(VQ)标记器和自回归模型训练流程,有助于推动视觉自回归学习的发展。
Properties of large autoregressive language models

Scaling laws 在自回归语言模型中发现并研究了扩展规律 Scaling laws,这些规律描述了模型(或数据集、计算量等)的规模与测试集上的交叉熵损失值之间的幂律关系。扩展规律使我们能够根据较小模型的性能直接预测较大模型的表现,从而指导更好的资源分配。更令人欣喜的是,它们表明大型语言模型(LLMs)的性能能够随着模型、数据以及计算量的增长而良好地扩展,且永远不会达到饱和,这被认为是等取得成功的一个关键因素。扩展规律所带来的成功启发了视觉领域去探索更多用于多模态理解和生成的类似方法。

Zero-shot generalization 零次学习泛化指的是一个模型(尤其是大型语言模型)执行其未曾被明确训练过的任务的能力。在计算机视觉领域内,人们对基础模型(如 CLIP、SAM、Dinov2)的零次学习和情境学习能力产生了浓厚的兴趣。像 Painter 和 LVM 等创新成果将视觉提示器进行了扩展,以实现在视觉方面的情境学习。

Visual generation

用于视觉生成的光栅扫描自回归模型需要将二维图像编码为一维标记序列。早期的尝试已经展示了以标准的逐行光栅扫描方式生成 RGB(或分组)像素的能力。VQGAN 通过在 VQVAE 的潜在空间中进行自回归学习。它采用仅含解码器的 GPT-2 变换器,按照光栅扫描顺序生成标记,就像 ViT将二维图像序列化为一维补丁那样。VQVAE-2 和 RQ-Transformer 也遵循这种光栅扫描方式,但使用了额外的尺度或堆叠代码。Parti 基于 ViT-VQGAN 的架构,将变换器扩展到 200 亿个参数,并且在文本到图像合成方面表现良好。

掩码预测模型。MaskGIT 采用了一个矢量量化(VQ)自动编码器以及一个与 BERT 类似的掩码预测变换器,通过贪心算法生成 VQ 标记。MagViT将这种方法应用于视频,。MUSE进一步将 MaskGIT 扩展到 30 亿个参数。

扩散模型的进展主要集中在改进学习或采样、引导、潜在学习以及架构方面。扩散变换器(DiT)和 U-ViT 用变换器替换或整合了 U-Net,并启发了近期的图像或视频合成系统,包括稳定扩散 3.0, SORA 和 Vidu。

Method

Preliminary: autoregressive modeling via next-token prediction

Formulation 考虑一个离散标记序列 x = ( x 1 , x 2 , . . . , x T ) x = (x_1, x_2, ..., x_T) x=(x1,x2,...,xT),其中 x t ∈ [ V ] x_t \in [V] xt[V] 是来自大小为 V V V 的词汇表中的一个整数。下一个标记自回归假定观察到当前标记 x t x_t xt 的概率仅取决于它的前缀 ( x 1 , x 2 , . . . , x t − 1 ) (x_1, x_2, ... , x_{t-1}) (x1,x2,...,xt1)。这种单向标记依赖假设允许对序列的似然性进行因式分解。

p ( x 1 , x 2 , . . . , x T ) = ∏ t = 1 T p ( x t ∣ x 1 , x 2 , . . . , x t − 1 ) (1) p(x_1, x_2, ... , x_T) = \prod_{t=1}^{T} p(x_t | x_1, x_2, ..., x_{t-1}) \tag{1} p(x1,x2,...,xT)=t=1Tp(xtx1,x2,...,xt1)(1)

训练一个自回归模型 p θ p_{\theta} pθ 需要针对一个数据集对 p θ ( x t ∣ x 1 , x 2 , . . . , x t − 1 ) p_{\theta}(x_t | x_1, x_2, ..., x_{t-1}) pθ(xtx1,x2,...,xt1) 进行优化。这被称为 “下一个标记预测”,经过训练 p θ p_{\theta} pθ 的能够生成新的序列。

Tokenization 图像本质上是二维连续信号。要通过下一个标记预测将自回归建模应用于图像,我们必须:1)将一幅图像标记化为若干离散标记;2)为单向建模定义标记的一维顺序。对于第一点,通常会使用诸如 [30] 中那样的量化自动编码器,将图像特征图 f ∈ R h × w × c f \in \mathbb{R}^{h \times w \times c} fRh×w×c 转换为离散标记 q ∈ [ V ] h × w q \in [V]^{h \times w} q[V]h×w

f = ε ( i m ) q = Q ( f ) (2) f = {\Large \varepsilon} (im) \quad q = \mathcal{Q}(f) \tag{2} f=ε(im)q=Q(f)(2)

其中, i m im im 表示原始图像, ε {\Large \varepsilon} ε 是一个编码器, Q \mathcal{Q} Q 是一个量化器,量化器通常包含一个可学习的码本 Z ∈ R V × C Z \in \mathbb{R}^{V \times C} ZRV×C,其中有 V V V 个向量。量化过程 q = Q ( f ) q = \mathcal{Q}(f) q=Q(f) 会将每个特征向量 f ( i , j ) f^{(i,j)} f(i,j) 按照欧几里得距离映射到与其最接近的码的码索引 q ( i , j ) q^{(i,j)} q(i,j) 上。

q ( i , j ) = ( arg min ⁡ v ∈ [ V ] ∥ l o o k u p ( Z , v ) − f ( i , j ) ∥ 2 ) ∈ [ V ] q^{(i, j)} = \left( \argmin_{v \in [V]} \left \| lookup(Z, v) - f^{(i,j)} \right \|_2 \right) \in [V] q(i,j)=(v[V]argmin lookup(Z,v)f(i,j) 2)[V]

其中, l o o k u p ( Z , v ) lookup(Z, v) lookup(Z,v) 表示取出码本 Z Z Z 中的第 v v v 个向量。为了训练量化自动编码器,通过每个 q ( i , j ) q^{(i,j)} q(i,j) 查找 Z Z Z 来得到 f f f 的近似值 f ^ \hat{f} f^,然后,利用给定 f ^ \hat{f} f^ 的解码器 D D D 重建一幅新图像,并且使得复合损失 L \mathcal{L} L最小化:

f ^ = L o o k u p ( Z , q ) i m ^ = D ( f ^ ) (4) \hat{f} = Lookup(Z, q) \qquad \hat{im} = D(\hat{f}) \tag{4} f^=Lookup(Z,q)im^=D(f^)(4)

L = ∥ i m − i m ^ ∥ 2 + ∥ f − f ^ ∥ 2 + λ p L p ( i m ^ ) + λ G L G ( i m ^ ) (5) \mathcal{L} = \left \| im - \hat{im} \right \|_2 + \left \| f - \hat{f} \right \|_2 + \lambda_p \mathcal{L}_{p}(\hat{im}) + \lambda_G \mathcal{L}_G(\hat{im}) \tag{5} L= imim^ 2+ ff^ 2+λpLp(im^)+λGLG(im^)(5)

L p \mathcal{L}_{p} Lp 类似感知损失函数 LPIPS, L G \mathcal{L}_{G} LG 是一种判别损失,比如 StyleGAN 的判别器损失, 一旦自动编码器 { ε , Q , D } \{ {\Large \varepsilon}, \mathcal{Q}, D \} {ε,Q,D} 完成了全面训练,它就会被用于对图像进行标记化,以便后续对单向自回归模型进行训练。

处于 q ∈ [ V ] h × w q \in [V]^{h \times w} q[V]h×w 中的图像标记排列成一个二维网格。与具有固有从左到右顺序的自然语言句子不同,对于单向自回归学习而言,图像标记的顺序必须被明确地定义。先前的自回归方法会使用诸如行优先光栅扫描、螺旋或 Z 曲线顺序等策略,将二维的 q q q 网格扁平化为一维序列 x = ( x 1 , x 2 , . . . , x h × w ) x = (x_1, x_2, ..., x_{h \times w}) x=(x1,x2,...,xh×w)。一旦完成扁平化,它们就能从数据集中提取出一组序列 x x x,然后训练一个自回归模型,通过下一个标记预测来使(1)式中的似然性最大化。

上述这种标记化及扁平化的方法使得在图像上进行下一个标记自回归学习成为可能,但也带来了几个问题:

  • 数学前提违背情况。在量化自编码器(矢量量化变分自编码器,VQVAEs)中,编码器通常会生成一个图像特征图 f f f,对于所有的 i , j i,j i,j,其特征向量 f ( i , j ) f(i, j) f(i,j) 是相互依赖的。因此,在量化和平展之后,标记序列 ( x 1 , x 2 , . . . , x h × w ) (x_1, x_2, ..., x_{h \times w}) (x1,x2,...,xh×w) 保留了双向相关性。这与自回归模型的单向依赖假设相矛盾,自回归模型规定每个标记 x t x_t xt 应该只依赖于其前缀 ( x 1 , x 2 , . . . , x t − 1 ) (x_1, x_2, ..., x_{t-1}) (x1,x2,...,xt1)

  • 无法进行某些零样本泛化。与问题 1)类似,图像自回归建模的单向性限制了它们在需要双向推理的任务中的泛化能力。例如,给定图像的底部部分,它无法预测图像的顶部部分。

  • 结构退化。展平操作破坏了图像特征图中固有的空间局部性。例如,标记 q ( i , j ) q^{(i,j)} q(i,j) 与其 4 个紧邻的邻居 q ( i ± 1 , j ) , q ( i , j ± 1 ) q^{(i \pm 1,j)} , q^{(i,j \pm 1)} q(i±1,j),q(i,j±1) 由于位置邻近而密切相关。这种空间关系在线性序列中遭到了破坏,其中单向约束削弱了这些相关性。

  • 效率低下问题。使用常规的自注意力变换器生成图像标记序列 x = ( x 1 , x 2 , . . . , x n × n ) x = (x_1, x_2, ..., x_{n \times n}) x=(x1,x2,...,xn×n) 会产生量级 O ( n 2 ) \mathcal{O}(n^2) O(n2) 的自回归步骤,并且计算成本达到 O ( n 6 ) \mathcal{O}(n^6) O(n6) 量级。

Visual autoregressive modeling via next-scale prediction

在这里插入图片描述

  • 图 4:视觉自回归建模(VAR)涉及两个独立的训练阶段。
    阶段 1:一个多尺度矢量量化(VQ)自编码器将一幅图像编码为 K K K 个标记图 R = { r 1 , r 2 , . . . , r K } R=\{r_1, r_2, ..., r_K \} R={r1,r2,...,rK},并通过复合损失(公式 5)进行训练。有关 “多尺度量化” 和 “嵌入” 的详细内容,请查看算法 1 和算法 2。
    阶段 2:一个视觉自回归建模(VAR)变换器通过下一尺度预测(公式 6)进行训练:它将([起始标记 s s s], r 1 , r 2 , . . . r k − 1 r_1, r_2,...r_{k-1} r1,r2,...rk1)作为输入来预测 ( r 1 , r 2 , . . . r k ) (r_1, r_2,...r_{k}) (r1,r2,...rk)。在训练过程中会使用注意力掩码来确保每个 r k r_k rk 只能关注 r ≤ k r_{\leq k} rk。训练时采用标准的交叉熵损失。

通过将 “下一个标记预测” 策略转变为 “下一尺度预测” 策略,对图像的自回归建模进行了重新概念化。在此,自回归单元是一整个标记图,而非单个标记。我们首先将一个特征图 f ∈ R h × w × c f \in \mathbb{R}^{h \times w \times c} fRh×w×c 量化为 K K K 个多尺度标记图 ( r 1 , r 2 , . . . , r K ) (r_1, r_2, ..., r_K) (r1,r2,...,rK),每个标记图的分辨率 h k × w k h_k \times w_k hk×wk 依次递增,最终 r K r_K rK 会与原始特征图的分辨率 h × w h \times w h×w 相匹配。自回归似然性的构建公式如下:

p ( r 1 , r 2 , . . . , r K ) = ∏ t = 1 K p ( r k ∣ x 1 , x 2 , . . . , x k − 1 ) (6) p(r_1, r_2, ... , r_K) = \prod_{t=1}^{K} p(r_k | x_1, x_2, ..., x_{k-1}) \tag{6} p(r1,r2,...,rK)=t=1Kp(rkx1,x2,...,xk1)(6)

其中,每个自回归单元 r k ∈ [ V ] h k × w k r_k \in [V]^{h_k \times w_k} rk[V]hk×wk 是尺度 k k k 下包含 h k × w k h_k \times w_k hk×wk 个标记的标记图,并且序列 ( r 1 , r 2 , . . . , r k − 1 ) (r_1, r_2,...,r_{k-1}) (r1,r2,...,rk1) 充当 r k r_k rk 的 “前缀”。在第 k k k 个自回归步骤期间, r k r_k rk h k × w k h_k \times w_k hk×wk 个标记上的所有分布都将基于 r k r_k rk 的前缀以及相关的第 k k k 个位置嵌入图并行生成。这种 “下一尺度预测” 方法就是我们所定义的视觉自回归建模(VAR),如图 4 右侧所示。请注意,在视觉自回归建模(VAR)的训练过程中,会使用分块因果注意力掩码来确保每个 r k r_k rk只能关注其前缀。在推理阶段,可以使用键值(kv)缓存,并且不需要掩码。

视觉自回归建模(VAR)对前文提到的三个问题的解决方式如下:

  • 如果我们约束每个 r k r_k rk 仅依赖于它的前缀,也就是获取的过程仅与 r ≤ k r_{\leq k} rk 相关,那么数学前提就能得到满足。由于这一约束与人类视觉感知和艺术绘画中从粗到细的自然递进特征相符(正如我们在第 1 节中所讨论的那样),所以是可接受的。更多细节将在下文的 “标记化” 部分提供。
  • 空间局部性得以保留,原因在于:(i) 在视觉自回归建模(VAR)中不存在展平操作;(ii) 每个中的标记是完全相关的。此外,多尺度设计进一步强化了空间结构。
  • 生成具有 n × n n \times n n×n个潜在变量的图像的复杂度显著降低至 O ( n 4 ) \mathcal{O}(n^4) O(n4),证明详见附录。这种效率提升源于在每个中并行生成标记。

Tokenization 我们开发了一种新的多尺度量化自编码器,用于将图像编码为视觉自回归建模(VAR)学习(公式 6)所需的 K K K 个多尺度离散标记图 R = ( r 1 , r 2 , . . . , r K ) R=(r_1, r_2,...,r_K) R=(r1,r2,...,rK)。我们采用了与矢量量化生成对抗网络(VQGAN)相同的架构,但使用了经过修改的多尺度量化层。针对 f f f f ^ \hat{f} f^ 采用残差设计的编码和解码流程在算法 1 和算法 2 中有详细说明。我们通过实验发现,这种残差式设计,其性能优于独立插值。算法 1 表明,每个 r k r_k rk 将仅依赖于它的前缀 ( r 1 , r 2 , . . . r k − 1 ) (r_1,r_2,...r_{k-1}) (r1,r2,...rk1)。请注意,在所有尺度上都使用了一个共享码本 Z Z Z,以确保每个 r k r_k rk 的标记都属于同一个码本 [ V ] [V] [V]。为了解决将 z k z_k zk 上采样到 h K × w K h_K \times w_K hK×wK 时出现的信息丢失问题,我们使用了 K K K 个额外的卷积层 { ϕ } k = 1 K \{ \phi \}_{k=1}^{K} {ϕ}k=1K。在将 f f f 下采样到 h k × w k h_k \times w_k hk×wk 之后则不使用卷积操作。

  • 算法 1 & 算法 2

在这里插入图片描述

Implementation details

VAR tokenizer, 如前文所述,我们使用了原始的矢量量化变分自编码器(VQVAE)架构以及带有 K K K 个额外卷积层(额外含个参数)的多尺度量化方案。我们针对所有尺度使用了一个共享码本,其大小 V = 4096 V=4096 V=4096。依照基线模型,我们的标记器同样是在开放图像数据集 (OpenImages) 上进行训练,采用复合损失(公式 5),并且空间下采样比率为 16 X 16X 16X

VAR transformer, 我们的主要关注点在于视觉自回归建模(VAR)算法,所以采用了简单的模型架构设计。我们采用了类似于 GPT - 2 和矢量量化生成对抗网络(VQGAN)的仅含解码器的标准变换器架构,并搭配自适应归一化(AdaLN),这种架构在许多视觉生成模型中被广泛采用且已证实其有效性。对于类别条件合成,我们将类别嵌入用作起始标记以及自适应归一化(AdaLN)的条件。我们发现,在进行注意力计算之前将查询向量和键向量归一化为单位向量能够稳定训练过程。我们没有使用大型语言模型中的高级技术,例如旋转位置嵌入(RoPE)、SwiGLU 多层感知机(MLP)或均方根归一化(RMS Norm)。我们的模型形状遵循一个类似的简单规则,即宽度 w w w、头数 h h h 以及丢弃率 d r dr dr 会随着深度 d d d 按如下方式进行线性缩放:

w = 64 d , h = d , d r = 0.1 d / 24 (7) w=64d, \quad h=d, \quad dr = 0.1d/24 \tag{7} w=64d,h=d,dr=0.1d/24(7)

因此,深度为 d d d 的视觉自回归建模(VAR)变换器的主要参数数量 N N N 由以下式子给出:

N ( d ) = d ⋅ 4 w 2 + d ⋅ 8 w 2 + d ⋅ 6 w 2 = 18 d w 2 = 73728 d 3 (8) N(d) = d \cdot 4 w^2 + d \cdot 8 w^2 + d \cdot 6 w^2 = 18 d w^2 = 73728 d^3 \tag{8} N(d)=d4w2+d8w2+d6w2=18dw2=73728d3(8)

所有模型都采用类似的设置进行训练:batch size 256 对应的基础学习率为 1e-4,使用 β 1 = 0.9 \beta_1=0.9 β1=0.9, β 2 = 0.95 \beta_2=0.95 β2=0.95 衰减率为 0.05 0.05 0.05 的 AdamW 优化器,batch size 从 768 到 1024 不等,训练轮次从 200 到 350(取决于模型大小)。实验表明这样一种简单的模型设计能够很好地进行扩展和泛化。

### 快速自回归Transformer与线性注意力机制 #### 定义与背景 快速自回归Transformer是一种特殊的架构设计,旨在加速传统Transformer模型中的自回归解码过程。这类模型通过优化计算流程和减少不必要的冗余操作来提升效率。具体来说,在RNN框架下实现具有线性复杂度的注意力机制可以显著降低计算资源消耗并加快推理速度。 #### 线性注意力机制的工作原理 为了使注意力机制达到线性的计算复杂度,研究者们提出了多种方案。其中一种有效的方法是在不牺牲表达能力的前提下简化原有的多头自我注意结构。例如,Infini-attention将压缩内存融入到了传统的注意力机制之中,并在线性时间内完成了对于长依赖关系的有效捕捉[^2]。这种方法不仅提高了处理无限长度输入序列的能力,而且还在一定程度上控制了内存开销与计算成本。 #### 实现细节 当把上述理念应用于基于RNN的体系时,可以通过调整标准的慢速多头注意力算法来获得更快的结果。下面给出了一种可能的技术路线: 1. **初始化**:设定初始状态$\mathbf{s}_0$作为隐藏层表示; 2. **迭代更新**:对于每一个时间步$t=1,\dots,T$, - 计算查询向量$q_t=\text{WQ}\cdot\mathbf{s}_{t-1}+\mathbf{b_q}$; - 利用预先准备好的键值对$(k,v)$集合直接获取加权求和后的上下文$c'_t=\sum_k \exp(q_t^\top k)v/\sqrt{|q|})$; 这里采用的是近似形式而非精确匹配; 3. **输出预测**:最终得到当前时刻的状态$s_t=f(\mathbf{c}'_t;\theta)$并通过softmax函数映射至词汇表空间完成单词生成任务。 ```python import torch.nn as nn class LinearAttention(nn.Module): def __init__(self, dim_model, num_heads): super().__init__() self.dim_model = dim_model self.num_heads = num_heads # 初始化权重矩阵和其他必要的组件... def forward(self, query, key_values): batch_size, seq_len, _ = query.size() q_proj = ... # 对query做投影变换 kv_pairs = ... # 准备好key-value pair列表 context = [] for t in range(seq_len): qt = q_proj[:, t:t+1, :] # 当前位置的查询向量 scores = torch.bmm(kv_pairs['keys'], qt.transpose(1, 2)) / (self.dim_model ** .5) weights = F.softmax(scores, dim=-1) ct_prime = torch.sum(weights * kv_pairs['values'], axis=-1).unsqueeze(dim=1) context.append(ct_prime) contexts = torch.cat(context, dim=1) output = f(contexts) # 应用激活函数或其他转换 return output ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值