MaskGIT:掩码图像生成经典方法
TL; DR:掩码图像生成,并行预测所有 token 的分布,采样时多轮迭代生成,每轮按照 mask schedule 取置信度较高的 token。
方法
与常规的 raster order AR 类似的,MaskGIT 的训练整体上也是需要两个过程,先训练 image tokenizer(vqvae/vqgan),然后训练双向 Transformer 进行掩码图像生成。第一阶段不是本文的重点,就直接用了 vqgan 的模型,接下来详细介绍一下第二阶段 MVTM (Masked Visual Token Modeling)的训练和采样。
MVTM 训练
Y = [ y i ] i = 1 N \mathbf{Y}=[y_i]_{i=1}^N Y=[yi]i=1N 表示图片经过 vq encoder 得到的隐层 token,N 是将 token matrix 展平后的序列长度。训练时,采样 Y \mathbf{Y} Y 的一个子集,产生一个掩码图 M = [ m i ] i = 1 N \mathbf{M}=[m_i]_{i=1}^N M=[mi]i=1N,根据掩码图,将 Y \mathbf{Y} Y 中 m i = 1 m_i=1 mi=1 对应的位置 i i i 处的 token y i y_i yi 替换为 [MASK], m i = 0 m_i=0 mi=0 的则保持原 token 不变
产生掩码的过程是根据一个 mask scheduling function γ ( r ) ∈ ( 0 , 1 ] \gamma(r)\in(0,1] γ(r)∈(0,1] 来进行的。大致来说,首先采样一个 0 到 1 之间的比率 r r r,然后根据 mask schedule 均匀地采样 ⌈ γ ( r ) ∗ N ⌉ \lceil \gamma(r)*N\rceil ⌈γ(r)∗N⌉ 个 token 进行掩码。这个 mask schedule 很重要,对结果的影响很大,后面会详细介绍。
按照
M
\mathbf{M}
M 对
Y
\mathbf{Y}
Y 进行掩码后,记作
Y
M
\mathbf{Y}_{{\mathbf{{M}}}}
YM(这里原文记号带了个 bar,但我看下文应该没有)。训练时把
Y
M
\mathbf{Y}_{{\mathbf{{M}}}}
YM 送入到一个双向注意力 Transformer 中,根据已知的 token,预测每个被 mask token 的分布
p
(
y
i
∣
Y
M
)
p(y_i|\mathbf{Y}_{{\mathbf{{M}}}})
p(yi∣YM),训练目标就是交叉熵分类损失:
L
mask
=
−
E
Y
∈
D
[
∑
∀
i
∈
[
1
,
N
]
,
m
i
=
1
log
p
(
y
i
∣
Y
M
)
]
\mathcal{L}_\text{mask}=-\mathbb{E}_{\mathbf{Y}\in \mathcal{D}}\left[\sum_{\forall i\in[1,N],m_i=1}\log p(y_i|\mathbf{Y}_{\mathbf{M}})\right] \notag \\
Lmask=−EY∈D
∀i∈[1,N],mi=1∑logp(yi∣YM)
相比于常规的 raster order AR,MVTM 使用的是双向注意力,可以更好地根据上下文信息来预测各掩码 token 的分布。
多轮迭代解码
理论上来说,MVTM 这种范式可以一次性生成一张完整图片,但这样就与训练时(部分掩码)不一致了,并且确实难度也太大。效果不好。因此,作者提出了多轮逐步生成,迭代解码(iterative decoding)的采样方法。一共需要 T T T 轮,在每一轮,模型根据上轮结果,预测出所有位置的 token 分布,但是只保留置信度分数最高的部分,保留多少由 mask schedule 根据当前轮次进程 t / T t/T t/T 计算出,在最后第 T T T 轮时,全部 token 都被生成出来。
具体来说,记第 t t t 轮的输入为 Y M ( t ) \mathbf{Y}_\mathbf{M}^{(t)} YM(t),在每轮进行如下三步操作:
-
预测。将 Y M ( t ) \mathbf{Y}_\mathbf{M}^{(t)} YM(t) 输入到模型中,模型预测出每个位置的概率分布 p ( t ) ∈ R N × K p^{(t)}\in\mathbb{R}^{N\times K} p(t)∈RN×K, K K K 是 codebook 大小;
-
采样。在每个掩码位置 i i i,根据预测出来的该位置的分布 p i ( t ) ∈ R K p^{(t)}_i\in\mathbb{R}^K pi(t)∈RK,采样一个 token y i ( t ) y_i^{(t)} yi(t),将其对应的预测分数作为一个 ”置信度“ 分数。对于未被掩码的(也就是已经生成好的)位置,置信度分数都是 1;
-
计算掩码数并进行掩码。mask schedule γ \gamma γ 根据当前轮次进程 t / T t/T t/T 计算下一步需要掩码的 token 数 n = ⌈ γ ( t / T ) ∗ N ⌉ n=\lceil\gamma(t/T)*N\rceil n=⌈γ(t/T)∗N⌉。取第 t t t 步结果中置信度分数较小的 n n n 个位置 mask 掉,其余部分保持,得到下一轮的掩码 M ( t + 1 ) \mathbf{M}^{(t+1)} M(t+1)
m i ( t + 1 ) = { 1 , if c i < sorted j ( c j ) [ n ] 0 , otherwise m^{(t+1)}_i= \begin{cases} 1,\quad \text{if}\ c_i<\text{sorted}_j(c_j)[n] \\ 0,\quad \text{otherwise} \end{cases} \notag \\ mi(t+1)={1,if ci<sortedj(cj)[n]0,otherwise
进而得到下一轮的输入 Y M ( t + 1 ) \mathbf{Y}_\mathbf{M}^{(t+1)} YM(t+1).
mask schedule
还有一个很重要的东西没有讨论,那就是 mask schedule γ ( r ) \gamma(r) γ(r)。前面介绍过,我们在训练和采样时都要用到它,训练时需要它基于随机采样的 ratio r r r 来覆盖从 0 到 1 掩码率,迭代采样时需要它来根据当前的轮次进度 t / T t/T t/T 来计算应该掩码的 token 数量。从属性上来看,我们需要 γ ( ⋅ ) \gamma(\cdot) γ(⋅) 是一个定义域和值域都在 [ 0 , 1 ] [0,1] [0,1] 之间的单调递减函数。在常见的函数中,有线性、凹函数(余弦、平方、立方、指数等)和凸函数(对数函数等)三大类选择。
直觉上来理解,是取凹函数比较合适,这样我们的整个多轮迭代采样就是一个 less-to-more 的过程,即一开始只预测少量高置信度的 token,把整张图片的框架仔细确定下来后,到后期大面积补全其他 token。
作者也通过实验验证了这一点,可以看到,从 FID (越小越好)上来看,几种凹函数显著优于线性函数和凸函数,其中又以 cosine 函数效果最好。另外,从这个实验也可以看到,在采用凹函数作为 mask schedule 时,生成质量随总步数 T T T 也不是单调地越大越好的,在 8 步左右达到最优,并且达到最优的步数,也是约 ”凹“ 越小,从这个角度来看,cosine 也是最好的选择。
至于为什么当总步数更大之后效果反而变差,作者认为是因为过多的步数,每步只会取那些概率非常高的 token,这样会使得模型无法采样到那些置信度稍低的 token,从而降低 token 的多样性。个人理解这里就类似于 LLM 中的采样并不是每次 greedy 地取最大概率的 token、diffusion 生成过程也是每步需要加一些随机性(sde instead of ode)。即生成模型采样过程中,还是有一点随机性会更好。
总结
掩码图像生成的经典工作,个人理解应该也可以算是一种广义的自回归,即 random (adaptive?) order multi token prediction,感觉比常规的 raster order next token prediction 做图像生成要更合理一些。