MaskGIT: Masked Generative Image Transformer

Abstract

生成式变压器在计算机视觉领域中合成高保真度和高分辨率图像方面取得了快速的流行度增长。然而,迄今为止最好的生成式变压器模型仍然将图像天真地视为令牌序列,并按照光栅扫描顺序(即逐行)顺序解码图像。我们发现这种策略既不是最优的,也不是高效的。本文提出了一种新的图像合成范式,双向Transformer解码器,我们称之为MaskGIT。在训练期间,MaskGIT 通过关注各个方向的token来学习预测随机掩码的token。在推理时,模型首先同时生成图像的所有token,然后以上一代为条件迭代地细化图像。我们的实验表明,MaskGIT在ImageNet数据集上明显优于最先进的变压器模型,并将自回归解码加速了多达64倍。此外,我们还证明了MaskGIT可以轻松扩展到各种图像编辑任务,如修补、外推和图像操作。

在这里插入图片描述

1. Introduction

近年来,深度图像合成作为一个领域取得了很大的进展。目前,持有最先进结果的是生成对抗网络(GANs),它们能够以惊人的速度合成高保真度的图像。然而,它们存在一些众所周知的问题,包括训练不稳定性和模式崩溃,这导致样本多样性的缺乏。解决这些问题仍然是一个开放的研究问题。

受到Transformer和GPT在自然语言处理中的成功启发,生成式变压器模型在图像合成领域引起了越来越多的关注。一般来说,这些方法的目标是将图像建模为一个序列,并利用现有的自回归模型生成图像。图像的生成分为两个阶段:第一阶段是将图像量化为一系列离散的令牌(或视觉单词)。在第二阶段,学习一个自回归模型(如Transformer)根据先前生成的结果顺序生成图像令牌(即自回归解码)。与GANs中使用微妙的极小极大优化不同,这些模型通过最大似然估计来学习。由于设计上的差异,现有的研究已经证明它们在提供稳定的训练和改进的分布覆盖或多样性方面优于GANs。

现有的关于生成式变压器的研究主要集中在第一阶段,即如何量化图像以最小化信息损失,并借用了自然语言处理中的相同第二阶段。因此,即使是最先进的生成式变压器[15, 35]仍然将图像天真地视为序列,其中图像被展平为按照光栅扫描顺序(从左到右,逐行)的一维token序列(参见图2)。 画家从草图开始,并逐步通过填充或微调细节来完善它,这与之前的工作[7, 15]中使用的逐行打印明显不同。这个描述很形象了 此外,将图像视为一个扁平的序列意味着自回归序列的长度呈二次增长,很容易形成一个非常长的序列,比任何自然语言句子都要长。这不仅对建模长期相关性提出了挑战,而且使得解码变得棘手。例如,在具有32x32 tokens 的GPU上,使用自回归方式生成单个图像需要30秒的时间。

这篇论文介绍了一种名为Masked Generative Image Transformer(MaskGIT)的新型双向变压器用于图像合成。在训练过程中,MaskGIT在类似于BERT中的掩码预测的代理任务上进行训练。在推理时,MaskGIT采用了一种新颖的非自回归解码方法,以恒定的步数合成图像。具体而言,在每个迭代中,模型同时并行地预测所有token,但只保留最自信的token。剩余的token被mask掉,并将在下一个迭代中重新预测。屏蔽比例逐步减少,直到通过几次迭代细化生成所有token。如图2所示,MaskGIT的解码速度比自回归解码快一个数量级,因为它只需要8个步骤(而不是256个步骤)生成一张图像,并且每个步骤内的预测可以并行化。此外,双向自注意力使得模型不仅仅依赖于顺序光栅扫描中的先前token进行条件生成,还可以从各个方向上生成新的令牌。我们发现掩码调度(即每次迭代中屏蔽的图像比例)对生成质量有显著影响。我们提出使用余弦调度,并在消融研究中证明了其有效性。

在ImageNet基准测试中,我们通过实验证明,与最先进的自回归变压器(如VQ-GAN)相比,MaskGIT在256×256和512×512分辨率下的类条件生成中,不仅速度显著提高(最多64倍),而且能够生成更高质量的样本。即使与领先的GAN模型(如BigGAN)和扩散模型(如ADM)相比,MaskGIT在样本质量上也具有可比性,同时产生更有利的多样性。值得注意的是,我们的模型在分类准确率(CAS)和FID上建立了新的最先进水平,用于合成512×512的图像。据我们所知,本文首次提供了关于在常见的ImageNet基准测试上使用掩码建模进行图像生成的有效性的证据。

此外,MaskGIT的多方向性使其可以轻松扩展到对于自回归模型而言困难的图像操作任务。图1展示了类条件图像编辑的一个新应用,其中MaskGIT在给定类别的情况下重新生成边界框内的内容,同时保持上下文(框外部分)不变。对于自回归模型来说,这个任务要么不可行,要么对于GAN模型来说很困难,但对于我们的模型来说却是轻而易举的。定量上,我们通过将MaskGIT应用于图像修补和任意方向的图像外推来展示这种灵活性。尽管我们的模型并非专为这些任务设计,但它在每个任务上的性能与专用模型相当。
在这里插入图片描述

2. Related Work

2.1 Image Synthesis

深度生成模型[12, 17, 29, 34, 41, 45, 46, 53]在图像合成任务中取得了很多成功。基于GAN的方法展示了生成高保真样本的惊人能力[4, 17, 27, 44, 53]。相反,基于似然的方法,如变分自编码器(VAEs)[29, 45]、扩散模型[12, 24, 41]和自回归模型[34, 46],提供了分布覆盖,因此可以生成更多样化的样本[41, 45, 46]。

然而,在像素空间直接最大化似然可能具有挑战性。因此,VQVAE [37, 47]提出了在两个阶段中在潜空间生成图像的方法。在第一阶段,即标记化阶段,它试图将图像 x ∈ R H × W × 3 x \in \mathbb{R}^{H \times W \times 3} xRH×W×3 压缩为离散的潜空间,并主要由以下三个组件组成:

  • 编码器 E学习将图像 x ∈ R H × W × 3 x \in \mathbb{R}^{H \times W \times 3} xRH×W×3 映射到潜空间嵌入 E ∈ R p × q E\in \mathbb{R}^{p \times q} ERp×q
  • 一个码本 e k ∈ R D , k ∈ 1 , 2 , . . . , K e_k ∈ \mathbb{R}^D,k ∈ {1, 2, ..., K} ekRD,k1,2,...,K,用于最近邻查找,将嵌入量化为视觉标记,
  • 解码器 G,从视觉标记 e 预测重建图像 x ^ \hat x x^

在第二阶段,该方法首先使用深度自回归模型预测视觉标记的潜在先验,然后使用第一阶段的解码器将标记序列映射到图像像素。由于这种两阶段方法的有效性,一些方法采用了相同的范式。DALL-E [35]使用Transformer [48]来改进第二阶段的标记预测。VQGAN [15]在第一阶段中添加了对抗损失和感知损失[26,54]以提高图像保真度。与我们的工作相近的VIM [51]提出使用VIT骨干[13]进一步改进标记化阶段。由于这些方法仍然采用自回归模型,第二阶段的解码时间与标记序列的长度成比例增长

2.2 Masked Modeling with Bi-directional Transformers

Transformer架构[48]最初是在自然语言处理领域提出的,最近已经扩展到计算机视觉领域[6, 13]。Transformer由多个自注意力层组成,可以捕捉序列中所有元素之间的相互作用。特别是,BERT [11]引入了掩码语言建模(MLM)任务用于语言表示学习。BERT [11]中使用的双向自注意力机制允许利用来自两个方向的上下文来预测MLM中的掩码标记。

在计算机视觉领域,BERT [11]中的掩码建模已经扩展到图像表示学习[2, 21],其中图像被量化为离散的标记。然而,由于使用双向注意力进行自回归解码的困难,很少有工作成功地将相同的掩码建模应用于图像生成[56]。据我们所知,本文提供了首个证据,证明了掩码建模在常见的ImageNet基准测试中对图像生成的有效性。我们的工作受到NLP领域双向机器翻译[16, 19, 20]的启发,我们的创新之处在于提出了新的掩码策略和解码算法,正如我们的实验证明的那样,这对于图像生成是至关重要的。

3. Method

我们的目标是设计一种利用并行解码和双向生成的新的图像合成范例。

按照2.1节中讨论的两阶段方法进行操作,如图3所示。由于我们的目标是改进第二阶段,我们采用了与VQGAN模型[15]相同的第一阶段设置,并将对标记化步骤的潜在改进留给未来的工作。

对于第二阶段,我们提出通过Masked Visual Token Modeling (MVTM)学习双向Transformer。我们在3.1节介绍了MVTM的训练方法,在3.2节介绍了采样过程。然后我们在3.3节讨论了掩码设计的关键技术。

3.1 MVTM in Training

Y = [ y i ] i = 1 N Y = [y_i]_{i=1}^{N} Y=[yi]i=1N表示通过将图像输入VQ编码器获得的潜在token序列,其中 N N N是reshape后的token矩阵的长度, M = [ m i ] i = 1 N M = [m_i]_{i=1}^{N} M=[mi]i=1N表示相应的二进制掩码。在训练过程中,我们从token序列中随机采样一部分,并用特殊的[MASK] token替换它们。如果 m i = 1 m_i=1 mi=1,则将 y i y_i yi替换为[MASK];如果 m i = 0 m_i=0 mi=0,则保持 y i y_i yi不变。

采样过程由一个掩码调度函数 γ ( r ) ∈ ( 0 , 1 ] \gamma (r)\in (0, 1] γ(r)(0,1] 参数化,并按照以下方式执行:我们首先从0到1中采样一个比率,然后在 Y Y Y 中均匀选择 ⌈ γ ( r ) ⋅ N ⌉ \left \lceil \gamma (r)\cdot N \right \rceil γ(r)N 个标记来放置掩码,其中 N N N 是长度。掩码调度显著影响图像生成的质量,将在3.3节中讨论。

Y M ˉ Y_{\bar M} YMˉ 为将掩码 M M M 应用于 Y Y Y 后的结果。训练目标是最小化被掩码标记的负对数似然:

L mask = − E Y ∈ D [ ∑ ∀ i ∈ [ 1 , N ] , m i = 1 log ⁡ P ( y i ∣ Y M ˉ ) ] (1) L_{\text{mask}} = -\mathbb E_{Y\in \mathcal D}[\sum_{\forall i\in[1,N],m_i=1} \log P(y_i | Y_{\bar M})] \tag{1} Lmask=EYD[i[1,N],mi=1logP(yiYMˉ)](1)

具体来说,我们将掩码后的 Y M ˉ Y_{\bar M} YMˉ 输入到一个多层双向 Transformer 中,用于预测每个掩码标记的概率 P ( y i ∣ Y M ˉ ) P(y_i| Y_{\bar M}) P(yiYMˉ)。负对数似然损失通过计算实际标签与预测标记之间的交叉熵来衡量。需要注意的是,与自回归建模不同,MVTM 中的条件依赖关系是双向的,这使得图像生成能够利用更丰富的上下文信息,同时关注图像中的所有标记。

3.2 Iterative Decoding

在自回归解码中,标记是根据先前生成的输出顺序逐个生成的。这个过程不可并行化,因此对于图像而言非常缓慢,因为图像的标记长度通常要比语言的标记长度大得多,例如256或1024。我们引入了一种新颖的解码方法,其中图像中的所有标记都同时并行生成。这是由于MTVM的双向自注意力所能实现的。

理论上,我们的模型能够在一次遍历中推断出所有标记并生成整个图像。然而,由于与训练任务不一致,我们发现这是具有挑战性的。下面介绍了我们提出的迭代解码方法。在推理时生成图像时,我们从一个空白画布开始,所有标记都被掩码,即 Y M ( 0 ) Y^{(0)}_M YM(0).对于迭代 t t t,我们的算法运行如下:

  • 预测:给定当前迭代中的掩码标记 Y M ( t ) Y_M^{(t)} YM(t),我们的模型并行预测所有掩码位置的概率,表示为 p ( t ) ∈ R N × K p^{(t)} \in \mathbb{R}^{N \times K} p(t)RN×K.

  • 抽样:对于每个掩码位置 i i i,我们根据其在编码本中所有可能标记的预测概率 p i ( t ) ∈ R K p^{(t)}_i \in \mathbb{R}^{K} pi(t)RK 进行抽样,得到标记 y i ( t ) y^{(t)}_i yi(t)。抽样后,相应的预测分数被用作“置信度”分数,表示模型对此预测的信任程度。对于未被掩码的位置,在 Y M ( t ) Y^{(t)}_M YM(t)中,我们将其置信度分数设为1.0

  • 掩码进度 Mask Schedule:我们根据掩码进度函数 γ \gamma γ 计算需要掩码的标记数量,记作 n = ⌈ γ ( t T ) N ⌉ n = \lceil\gamma(\frac{t}{T})N\rceil n=γ(Tt)N,其中 N 是输入长度,T 是总迭代次数。

  • 掩码:我们通过在 Y M ( t ) Y^{(t)}_M YM(t)中掩码 n 个标记,得到 Y M ( t + 1 ) Y^{(t+1)}_M YM(t+1)。迭代 t + 1 t+1 t+1的掩码 M ( t + 1 ) M^{(t+1)} M(t+1)计算方式如下:
    m i ( t + 1 ) = { 1 , if  c i < sorted j ( c j ) [ n ] ,   0 , otherwise . m^{(t+1)}_i = \begin{cases} 1, & \text{if } c_i < \text{sorted}_j(c_j)[n], \\\ 0, & \text{otherwise}. \end{cases} mi(t+1)={1, 0,if ci<sortedj(cj)[n],otherwise.
    其中, c i c_i ci是第 i i i个token的置信度分数。

解码算法在 T 步中合成图像。在每个迭代中,模型同时预测所有标记,但只保留置信度最高的标记。其余的标记被掩码,并在下一次迭代中重新预测。掩码比率逐渐减小,直到在 T 次迭代中生成所有标记。在实践中,掩码的标记通过温度退火进行随机采样,以鼓励更多的多样性,在第4.4节中我们将讨论其影响。图2展示了我们解码过程的一个示例。它在 T=8 次迭代中生成一张图像,在每个迭代中未掩码的标记在网格中被突出 显示,例如当 t=1 时,我们只保留一个标记并掩码其余部分。

Masking Design

我们发现图像生成的质量受到掩码设计的显著影响。我们通过一个掩码调度函数 γ ( ⋅ ) \gamma(\cdot) γ() 来建模掩码过程,该函数计算给定潜在标记的掩码比率。正如前面讨论的,函数 γ \gamma γ 在训练和推断中都会使用。在推断阶段,它接受参数 0 / T , 1 / T , . . . , ( T − 1 ) / T 0/T,1/T, ...,(T-1)/T 0/T,1/T,...,(T1)/T,表示解码进度。在训练中,我们随机采样一个比率 r ∈ [ 0 , 1 ) r \in [0,1) r[0,1),以模拟不同的解码情况。

BERT使用固定的15% 掩码比率 [11],即始终掩码15% 的标记,但这对于我们的任务来说不合适,因为我们的解码器需要从零开始生成图像。因此,我们需要新的掩码调度策略。在讨论具体方案之前,我们首先检查掩码调度函数的属性。首先, γ ( r ) \gamma(r) γ(r) 需要是一个连续函数,范围在 0 0 0 1 1 1 之间,其中 r ∈ [ 0 , 1 ] r \in [0,1] r[0,1]。其次, γ ( r ) \gamma(r) γ(r) 应该随着 r r r 的减小而(单调地)减小,并且满足 γ ( 0 ) → 1 \gamma(0) \rightarrow 1 γ(0)1 γ ( 1 ) → 0 \gamma(1) \rightarrow 0 γ(1)0。第二个属性确保了我们解码算法的收敛性。

本文考虑常见的函数并进行简单的变换,以满足这些属性。图8可视化了这些函数,它们被分为三组:

  • 线性函数是一个直接的解决方案,每次掩码相同数量的标记。

  • 凹函数捕捉到图像生成遵循从少到多信息流的直觉。在开始时,大多数标记被掩码,因此模型只需对一些它感到有信心的预测进行正确。随着时间的推移,掩码比率急剧下降,迫使模型进行更多的正确预测。在这个过程中,有效信息是增加的。凹函数家族包括余弦、平方、立方和指数函数。

  • 凸函数实现了从多到少的过程。模型需要在前几个迭代中完成绝大部分标记。这个家族包括平方根和对数函数。

我们在4.4中以实证的方式比较了上述掩码调度函数,并发现余弦函数在我们的所有实验中表现最好。
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值