论文精读(保姆级解析)—— Flash Diffusion

0 前言

  今天分析的论文是《Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation》。该论文发表在2024年,目前已开源在arxiv上,主要提出了一种高效、快速且多功能的蒸馏方法,用于加速预训练扩散模型的生成:Flash Diffusion。下面给出论文的地址和代码仓库链接:

1 摘要

  本文提出了一种高效、快速且多功能的蒸馏方法,以加速预训练扩散模型的生成:Flash Diffusion。在COCO2014和COCO2017数据集上,该方法在FID和CLIP-Score方面达到了最先进的性能,与现有的方法相比,只需几个GPU小时的训练和更少的可训练参数。除了其高效性外,该方法的多功能性也在多个任务中得到了体现,例如文本生成图像、修复、换脸、超分辨率,以及使用不同的骨干网络(如基于UNet的去噪器SD1.5、SDXL或DiT (Pixart-α))和适配器。在所有情况下,该方法都显著减少了采样步骤的数量,同时保持了非常高质量的图像生成。
在这里插入图片描述

2 引言

  扩散模型很有效,应用非常广泛。然而,由于采样机制的内在迭代行,导致计算成本比较高。。。
  最近,出现了一些更高效的求解器或扩散蒸馏方法,旨在从训练的扩散模型中减少所需的采样步骤以生成令人满意的样本。然而,求解器通常需要至少10次神经函数评估(NFEs)才能生成令人满意的样本,而蒸馏方法可能需要大量的训练资源或需要一种迭代训练程序来在整个训练过程中更新teacher model,这限制了它们的应用范围。此外,大多数现有的蒸馏方法是为特定任务(如文本生成图像)量身定制的,目前尚不清楚它们在使用不同条件和扩散模型架构的其他任务中的表现。此外,最有效的方法依赖对抗训练程序,这可能导致训练不稳定并需要大量的超参数调整。

  在本文中,我们提出了Flash Diffusion,这是一种快速、稳健且多功能的扩散蒸馏方法,能够大幅减少采样步骤的数量,同时保持非常高的图像生成质量。提出的方法旨在训练一个student model,使其在单步预测中对损坏的输入样本进行去噪的多步教师预测。该方法还通过对抗目标将学生分布引导至真实输入样本流形,并通过分布匹配确保其不会过度偏离已学习的教师分布。

  该方法与LoRA兼容,与现有的方法相比,该方法能够在仅需几个步骤的情况下生成高质量的样本,同时只需要几个GPU小时的训练时间更少的可训练参数。该方法能够在COCO2014和COCO2017数据集上以少步图像生成的FID和CLIP分数达到最新的(SOTA)性能。除了其高效性外,该方法在多个任务(如文本生成图像、修复、换脸、图像放大)以及使用不同的扩散模型骨干(SD1.5 [57]、SDXL [50] 和Pixart-α [5])和适配器 [46] 中也展示了其多功能性。在所有情况下,该方法都能显著减少采样步骤的数量,同时保持非常高质量的图像生成(论文都得吹一下)。

  本文的主要贡献如下:

  • 提出了一种高效、快速、多功能且与LoRA兼容的蒸馏方法,旨在减少从训练的扩散模型中生成高质量样本所需的采样步骤。
  • 验证了该方法在文本生成图像任务中的效果,显示其能够在标准基准数据集上仅用两个神经函数评估(NFE)以及少量的图像生成步骤达到SOTA效果,相当于使用无分类器指导的一步,同时所需的训练参数远少于竞争对手,仅需几个GPU小时的训练。
  • 进行了广泛的消融研究,展示了该方法各个组成部分的影响,并证明了其稳健性和可靠性。
  • 通过广泛的实验研究,强调了该方法的多功能性,涵盖了各种任务(文本生成图像、图像修复、超分辨率、换脸)和扩散模型架构(SD1.5、SDXL和Pixart-α),并展示了其与适配器的兼容性。

3 相关工作

3.1 扩散模型

  扩散模型包括根据给定的噪声调度人为地破坏输入数据[64,17,67],使数据分布最终类似于标准高斯分布。然后,它们被训练来估计添加的噪声量,以学习反向扩散过程,从而在训练完成后能够从高斯噪声生成新样本。这些模型可以根据各种输入进行条件化,如图像、深度图、边缘、姿势或文本,在这些条件下它们展示了非常令人印象深刻的结果。然而,为了生成高质量的样本,在推理时需要大量的采样步骤(通常为50步),这限制了它们在实时应用中的使用和推广。

3.2 扩散蒸馏法

  为了解决这一限制,最近出现了几种方法来减少推理时所需的函数计算次数。一方面,几篇论文尝试构建更高效的求解器来加速生成过程,但这些方法仍然需要使用多个步骤(通常为10步)来生成令人满意的样本。另一方面,一些依赖模型蒸馏的方法提出训练一个学生网络,使其学会在更少的步骤中匹配教师模型生成的样本。一种简单的方法是建立噪声/教师样本对,并训练一个学生模型,使其在具有回归损失的单步中匹配来自相同噪声的教师预测。尽管如此,这种方法仍然非常有限,并且很难与教师模型的质量相匹配,因为在充满噪音的环境中,学生没有潜在的有用信息可以学习。在此基础上,一些方法提出先对输入样本应用正向扩散过程,然后将其传递给学生网络。学生的预测然后使用回归损失、对抗目标或分布匹配与教师模型的学习分布进行比较。

3.3 渐进式蒸馏

  渐进蒸馏(Progressive Distillation)也是一种被证明相当有前景的方法。它包括训练一个学生模型在一个步骤中预测一个噪声样本的两步教师去噪,理论上减少了一半所需的采样步骤。然后教师模型被新的学生模型替换,这个过程重复多次。这种方法也被丰富为基于GAN的目标,使得所需的采样步骤从4-8步进一步减少到一个步骤。InstaFlow提出依靠修正流来简化单步蒸馏过程。然而,这种方法可能需要大量的训练参数和长时间的训练过程,使其计算密集。

3.4 一致性模型

  一致性模型(Consistency Models)也是一种在文献中提出的有前景、有效且多功能的蒸馏方法。主要思想是训练一个模型,将位于概率流常微分方程上的任何点映射到其原点,理论上解锁单步生成。Luo等人结合潜在一致性模型和LoRAs,展示了在非常有限的训练参数和几个GPU数小时的训练下,训练出一个强大的学生模型的可能性。然而,这些模型仍然难以实现单步生成并达到同类方法的采样质量。

  在最近进行的一项平行研究中,Yin等人还引入了联合使用分布匹配损失和对抗损失的方法,作者也在论文中使用了这种方法。然而,它们不依赖于在我们的实验中证明非常有效的蒸馏损失的使用,也不计算相对于相同输入的对抗损失。此外,他们的方法仍然需要训练另一个去噪器来评估假样本的分数,显著增加了可训练参数的数量和方法的计算负担。此外,他们方法在不同任务和扩散模型架构中进行泛化和有效表现的能力仍不明确。

4.1 扩散模型

  设 x 0 ∈ X x_0 \in X x0X 是一组输入数据,使得 x 0 ∼ p ( x 0 ) x_0 \sim p(x_0) x0p(x0),其中 p ( x 0 ) p(x_0) p(x0) 是一个未知分布。扩散模型是一类生成模型,它们定义了一个马尔可夫过程 ( x t ) t ∈ [ 0 , T ] (x_t)_{t \in [0, T]} (xt)t[0,T],通过向数据 x 0 x_0 x0中逐步注入高斯噪声来创建 x 0 x_0 x0的噪声版本 x t x_t xt。随着 t t t 增加,噪声样本 x t x_t xt 的分布最终变得等同于各向同性高斯分布。噪声调度由两个可微函数 α ( t ) \alpha(t) α(t) σ ( t ) \sigma(t) σ(t) 控制,对于任意 t ∈ [ 0 , T ] t \in [0, T] t[0,T],使得信噪比的对数 log ⁡ [ α ( t ) 2 / σ ( t ) 2 ] \log[\alpha(t)^2 / \sigma(t)^2] log[α(t)2/σ(t)2] 随时间递减。给定任意 t ∈ [ 0 , T ] t \in [0, T] t[0,T],噪声样本相对于输入 x 0 x_0 x0 的分布 q ( x t ∣ x 0 ) q(x_t | x_0) q(xtx0) 称为前向过程,定义为 q ( x t ∣ x 0 ) = N ( x t ; α ( t ) ⋅ x 0 , σ ( t ) 2 ⋅ I ) q(x_t | x_0) = \mathcal{N}(x_t; \alpha(t) \cdot x_0, \sigma(t)^2 \cdot I) q(xtx0)=N(xt;α(t)x0,σ(t)2I),可以如下进行采样:
x t = α ( t ) ⋅ x 0 + σ ( t ) ⋅ ϵ 其中 ϵ ∼ N ( 0 , I ) (1) x_t = \alpha(t) \cdot x_0 + \sigma(t) \cdot \epsilon \quad \text{其中} \quad \epsilon \sim \mathcal{N}(0, I) \tag{1} xt=α(t)x0+σ(t)ϵ其中ϵN(0,I)(1)

  扩散模型的主要思想是学习对噪声样本 x t ∼ q ( x t ∣ x 0 ) x_t \sim q(x_t | x_0) xtq(xtx0) 进行去噪,以学习反向过程,最终从纯噪声中生成样本 x ~ 0 \tilde{x}_0 x~0。在实践中,在训练过程中,扩散模型包括学习一个以时间步长 t t t为条件的参数化函数 x θ x_\theta xθ,并将噪声样本 x t x_t xt作为输入,使其预测原始样本 x 0 x_0 x0的去噪版本。参数 θ \theta θ 通过去噪得分匹配学习:
L = E x 0 ∼ p ( x 0 ) , t ∼ π ( t ) , ϵ ∼ N ( 0 , I ) [ λ ( t ) ∥ x θ ( x t , t ) − x 0 ∥ 2 ] (2) L = \mathbb{E}_{x_0 \sim p(x_0), t \sim \pi(t), \epsilon \sim \mathcal{N}(0, I)} \left[ \lambda(t) \left\| x_\theta(x_t, t) - x_0 \right\|^2 \right] \tag{2} L=Ex0p(x0),tπ(t),ϵN(0,I)[λ(t)xθ(xt,t)x02](2)
  其中 λ ( t ) \lambda(t) λ(t) 是一个取决于时间步 t ∈ [ 0 , 1 ] t \in [0, 1] t[0,1] 的缩放因子, π ( t ) \pi(t) π(t) 是时间步的分布。注意,公式 (2) 实际上等同于学习一个函数 ϵ θ \epsilon_\theta ϵθ,它估计添加到原始样本上的噪声 ϵ \epsilon ϵ,通过重参数化 ϵ θ ( x t , t ) = x t − α ( t ) ⋅ x θ ( x t , t ) σ ( t ) \epsilon_\theta(x_t, t) = \frac{x_t - \alpha(t) \cdot x_\theta(x_t, t)}{\sigma(t)} ϵθ(xt,t)=σ(t)xtα(t)xθ(xt,t)得到。Song 等人表明, ϵ θ \epsilon_\theta ϵθ 可以通过求解以下 PF-ODE 来从高斯噪声生成新的数据点:
d x t = [ f ( x t , t ) − 1 2 g 2 ( t ) ∇ log ⁡ p θ ( x t ) ] d t (3) d x_t = \left[ f(x_t, t) - \frac{1}{2} g^2(t) \nabla \log p_\theta(x_t) \right] dt \tag{3} dxt=[f(xt,t)21g2(t)logpθ(xt)]dt(3)
其中 f ( x t , t ) f(x_t, t) f(xt,t) g ( t ) g(t) g(t) 分别是 PF-ODE 的漂移函数和扩散函数,定义如下:
f ( x t , t ) = d log ⁡ α ( t ) d t ⋅ x t g 2 ( t ) = d σ ( t ) 2 d t − 2 ⋅ d log ⁡ α ( t ) d t ⋅ σ 2 ( t ) f(x_t, t) = \frac{d \log \alpha(t)}{dt} \cdot x_t \\ g^2(t) = \frac{d \sigma(t)^2}{dt} - 2 \cdot \frac{d \log \alpha(t)}{dt} \cdot \sigma^2(t) f(xt,t)=dtdlogα(t)xtg2(t)=dtdσ(t)22dtdlogα(t)σ2(t)
∇ log ⁡ p θ ( x t ) = − ϵ θ ( x t , t ) σ ( t ) \nabla \log p_\theta(x_t) = -\frac{\epsilon_\theta(x_t, t)}{\sigma(t)} logpθ(xt)=σ(t)ϵθ(xt,t)称为 p θ ( x t ) p_θ(x_t) pθ(xt)的分数函数。PF-ODE 可以使用神经 ODE 积分器 [7] 求解,该积分器通过给定的更新规则,如欧拉方法 [67] 或 Heun 解算器 [23],迭代地应用学习到的函数 ϵ θ \epsilon_\theta ϵθ
  通过学习条件去噪函数 ϵ θ ( x t , t , c ) \epsilon_\theta(x_t, t, c) ϵθ(xt,t,c) x θ ( x t , t , c ) x_\theta(x_t, t, c) xθ(xt,t,c) ,可以训练条件扩散模型从条件分布 p ( x 0 ∣ c ) p(x_0 | c) p(x0c)生成样本。在这种特定设置下,Classifier-Free Guidance (CFG) 已证明是一种非常有效的方法,可以更好地强制模型遵守条件,从而提高采样质量。CFG 是一种技术,它在训练期间以一定概率丢弃条件 c c c,并在推理时用以下线性组合替换条件噪声估计 ϵ θ ( x t , t , c ) \epsilon_\theta(x_t, t, c) ϵθ(xt,t,c)
ϵ θ ( x t , t , c ) = ω ⋅ ϵ θ ( x t , t , c ) + ( 1 − ω ) ⋅ ϵ θ ( x t , t , ∅ ) (4) \epsilon_\theta(x_t, t, c) = \omega \cdot \epsilon_\theta(x_t, t, c) + (1 - \omega) \cdot \epsilon_\theta(x_t, t, \emptyset) \tag{4} ϵθ(xt,t,c)=ωϵθ(xt,t,c)+(1ω)ϵθ(xt,t,)(4)
其中 ω > 0 \omega > 0 ω>0 被称为引导尺度。

4.2 一致性模型

  由于本文的方法受到一致性模型(Consistency Models,CM)的启发,作者回顾了一些这些模型的要素。CM 是一种新型的生成模型,主要用于学习一致性函数 f θ f_\theta fθ,该函数将位于公式(3)给出的PF-ODE轨迹上的任意样本 x t x_t xt 直接映射到原始样本 x 0 x_0 x0,同时确保任意 t ∈ [ ϵ , T ] t \in [\epsilon, T] t[ϵ,T] ϵ > 0 \epsilon > 0 ϵ>0时的自一致性属性:
f θ ( x t , t ) = f θ ( x t ′ , t ′ ) , ∀ ( t , t ′ ) ∈ [ ϵ , T ] 2 (5) f_\theta(x_t, t) = f_\theta(x_{t'}, t'), \forall(t, t') \in [\epsilon, T]^2 \tag{5} fθ(xt,t)=fθ(xt,t),(t,t)[ϵ,T]2(5)
为了确保一致性属性,Song等人提出对 f θ f_\theta fθ 进行如下参数化:
f θ ( x t , t ) = c skip ( t ) ⋅ x t + c out ( t ) ⋅ F θ ( x t , t ) , f_\theta(x_t, t) = c_{\text{skip}}(t) \cdot x_t + c_{\text{out}}(t) \cdot F_\theta(x_t, t) , fθ(xt,t)=cskip(t)xt+cout(t)Fθ(xt,t),
其中 F θ F_\theta Fθ 是使用神经网络进行参数化的, c skip c_{\text{skip}} cskip c out c_{\text{out}} cout 是可微函数【68, 42】。一致性模型可以从头开始训练(Consistency Training)或可用于蒸馏现有的DM(Consistency Distillation)。在这两种情况下,模型的目标是学习 f θ f_\theta fθ 以匹配目标函数 f θ − f_{\theta^-} fθ 的输出,其权重使用指数移动平均(EMA)进行更新,针对任意位于 PF-ODE 轨迹上的点 ( x t , x t ′ ) (x_t, x_{t'}) (xt,xt)
L = E x 0 ∼ p ( x 0 ) , t ∼ π ( t ) , ϵ ∼ N ( 0 , I ) [ ∥ f θ ( x t , t ) − f θ − ( x t ′ , t ′ ) ∥ 2 ] L = \mathbb{E}_{x_0 \sim p(x_0), t \sim \pi(t), \epsilon \sim \mathcal{N}(0, I)} \left[ \| f_\theta(x_t, t) - f_{\theta^-}(x_{t'}, t') \|^2 \right] L=Ex0p(x0),tπ(t),ϵN(0,I)[fθ(xt,t)fθ(xt,t)2]
  换句话说,给定使用公式(1)得到的噪声样本 x t x_t xt,其思想是强制 f θ ( x t , t ) = f θ − ( x t ′ , t ′ ) f_\theta(x_t, t) = f_{\theta^-}(x_{t'}, t') fθ(xt,t)=fθ(xt,t),其中 x t ′ x_{t'} xt 是使用相同噪声 ϵ \epsilon ϵ 和输入 x 0 x_0 x0 通过公式(1)进行一致性训练得到的或使用训练好的扩散模型 ϵ ∅ teacher \epsilon_{\emptyset}^{\text{teacher}} ϵteacher 和 ODE 求解器 Ψ \Psi Ψ 进行一致性蒸馏。一旦模型训练完毕,可以通过首先绘制噪声样本 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xTN(0,I),然后应用学习到的函数 f θ f_\theta fθ 来理论上在一步内生成样本 x ~ 0 \tilde{x}_0 x~0。在实际操作中,需要进行多次迭代以生成令人满意的样本,因此估计的样本 x ~ 0 \tilde{x}_0 x~0 被使用学习到的函数 f θ f_\theta fθ 多次反复地添加噪声和去噪。

5 方法

  在这一部分中,作者介绍了基于文献中提出的若干理念构建的方法。接下来,作者将自己置于潜在扩散模型[57]的背景下进行图像生成,并将教师模型称为 ϵ ϕ teacher \epsilon_\phi^{\text{teacher}} ϵϕteacher,学生模型称为 ϵ θ student \epsilon_\theta^{\text{student}} ϵθstudent,训练图像称为 x 0 x_0 x0 及其未知分布为 p ( x 0 ) p(x_0) p(x0)。将 z 0 = ϵ ( x 0 ) z_0 = \epsilon(x_0) z0=ϵ(x0) 称为通过编码器 ϵ \epsilon ϵ 得到的相关潜变量。记时间步长的概率密度函数为 π \pi π,并设定 T = 1 T = 1 T=1。请注意,所提出的方法也可以直接应用于像素空间扩散模型。

5.1 蒸馏一个预训练的扩散模型

  该方法主要是为了实现一种快速、鲁棒且可靠的方案,能够轻松适应不同的使用场景。给定一组数据 x 0 ∈ X x_0 \in X x0X,使得 x 0 ∼ p ( x 0 ) x_0 \sim p(x_0) x0p(x0),以及通过编码器 E E E 得到的相关潜变量 z 0 = ϵ ( x 0 ) z_0 = \epsilon(x_0) z0=ϵ(x0),该方法的主要思想与扩散模型相似。给定由 α ( t ) \alpha(t) α(t) σ ( t ) \sigma(t) σ(t) 定义的噪声时间表,作者建议创建一个带噪声的潜变量样本 z t z_t zt,其中 t ∼ π ( t ) t \sim \pi(t) tπ(t),如公式(1)所示,并训练一个函数 f θ student f_\theta^{\text{student}} fθstudent 来预测原始样本 z 0 z_0 z0 的去噪版本 z ~ 0 \tilde{z}_0 z~0。与扩散模型的主要区别在于它不是使用 z 0 z_0 z0作为目标,作者建议利用教师模型的知识,使用属于教师模型学习的数据分布 p ϕ t e a c h e r ( z 0 ) p_{\phi}^{teacher} (z_0) pϕteacher(z0) 的样本。换句话说,就是使用教师模型和一个 ODE 求解器生成一个去噪的潜变量样本 z ~ 0 teacher ( z t ) \tilde{z}_{0}^{\text{teacher}}(z_t) z~0teacher(zt),它属于学习数据分布,并将其作为学生模型的目标。主要的蒸馏损失函数写作:
L distil = E z 0 , t , ϵ [ ∥ f student θ ( z t , t ) − z ~ 0 teacher ( z t ) ∥ 2 ] , ( 6 ) L_{\text{distil}} = \mathbb{E}_{z_0, t, \epsilon} \left[ \| f_{\text{student}}^\theta(z_t, t) - \tilde{z}_{0}^{\text{teacher}}(z_t) \|^2 \right] ,(6) Ldistil=Ez0,t,ϵ[fstudentθ(zt,t)z~0teacher(zt)2](6)
其中 π ( t ) \pi(t) π(t) 表示时间步长的分布, z 0 ~ teacher ( z t ) \tilde{z_0}^{\text{teacher}}(z_t) z0~teacher(zt) 是通过在教师模型 ϵ ϕ t e a c h e r \epsilon_\phi^{teacher} ϵϕteacher 上从 z t = α ( t ) ⋅ z 0 + σ ( t ) ⋅ ϵ z_t = \alpha(t) \cdot z_0 + \sigma(t) \cdot \epsilon zt=α(t)z0+σ(t)ϵ 开始运行ODE 求解器 Ψ \Psi Ψ 的若干步得到的。类似的思想在Sauer等人的方法中也有应用,但作者生成完全合成的样本,这意味着样本 z t z_t zt 是纯噪声, z t ∼ N ( 0 , I ) z_t \sim \mathcal{N}(0, I) ztN(0,I)。相反,在本文的方法中,作者假设允许 z t z_t zt 保留一些来自真实编码样本 z 0 z_0 z0 的信息可以增强蒸馏过程。如Luo等人所述,在蒸馏条件扩散模型时,我们还与教师模型一起执行无分类器指导(Classifier-Free Guidance,CFG),以更好地确保模型遵守条件。这项技术实际上显著提高了学生生成的样本质量。此外,它消除了在学生推理过程中执行 CFG 的需要,进一步减少了每步计算量的一半。训练期间使用的指导尺度 ω \omega ω 的值作者在消融实验中进行了展示,但在实践中, ω \omega ω [ ω min ⁡ , ω max ⁡ ] [\omega_{\min}, \omega_{\max}] [ωmin,ωmax] 中均匀采样,其中 0 ≤ ω min ⁡ ≤ ω max ⁡ 0 \leq \omega_{\min} \leq \omega_{\max} 0ωminωmax。如第4.2节所述,作者说他的方法与现有的一致性模型相似。但是他们不是依赖学生模型的先前实例来估计 PF-ODE 的起源,而是直接使用教师模型结合 ODE 求解器生成目标。并且观察到这些要素增强了训练过程的稳定性。

5.2 时间步长采样

  方法的核心在于时间步长概率密度函数 π ( t ) \pi(t) π(t) 的选择。根据【67】中介绍的连续建模,扩散模型(DMs)被训练在任意连续时间 t t t 上从潜在样本 z t z_t zt 中去除噪声。然而,由于我们目标是在推理时实现少步数数据生成(通常为1-4步),学习的函数 ϵ θ \epsilon_\theta ϵθ 仅在少数离散的时间步长 { t i } i = 1 K \{t_i\}_{i=1}^K {ti}i=1K 上进行评估。

  为了解决这个问题并确保蒸馏过程集中于最相关的时间步长,我们建议在区间 [ 0 , 1 ] [0, 1] [0,1] 内选择 K(通常为16、32或64)个均匀分布的时间步长,并根据概率质量函数 π ( t ) \pi(t) π(t) 为每个时间步长分配概率。我们选择 π ( t ) \pi(t) π(t) 作为由一系列权重 { β i } i = 1 K \{\beta_i\}_{i=1}^K {βi}i=1K 控制的高斯分布混合:
π ( t ) = 1 2 π σ 2 ∑ i = 1 K β i exp ⁡ ( − ( t − μ i ) 2 2 σ 2 ) , (7) \pi(t) = \frac{1}{\sqrt{2\pi\sigma^2}} \sum_{i=1}^K \beta_i \exp \left( - \frac{(t - \mu_i)^2}{2\sigma^2} \right) , \tag{7} π(t)=2πσ2 1i=1Kβiexp(2σ2(tμi)2),(7)

其中每个高斯分布的均值由 { μ i = i K } i = 1 K \{\mu_i = \frac{i}{K}\}_{i=1}^K {μi=Ki}i=1K 控制,方差固定为 σ = 0.5 K 2 \sigma = \sqrt{\frac{0.5}{K^2}} σ=K20.5 。这种方法使得在蒸馏教师模型时,只有少数 K 个离散时间步长被采样,而不是连续区间 [ 0 , 1 ] 3 [0, 1]^3 [0,1]3。此外,分布 π \pi π 定义为在 K 个选定的时间步长中,用于1、2和4步生成的4个时间步长被过采样(通常我们设定 β i > 0 \beta_i > 0 βi>0 如果 i ∈ [ K 4 , K 2 , 3 K 4 , K ] i \in [\frac{K}{4}, \frac{K}{2}, \frac{3K}{4}, K] i[4K,2K,43K,K] β i = 0 \beta_i = 0 βi=0)。与其Sauer等人的方法相比,本文不仅关注这4个时间步长,因为我们注意到这可能会导致生成样本的多样性减少,对此,作者进行了消融研究验证。实际上,作者注意到热身阶段对训练过程是有益的。因此,决定首先对对应于最少噪声增加的时间步长施加更高的概率,通过设定 β K / 4 = β K / 2 = 0.5 \beta_{K/4} = \beta_{K/2} = 0.5 βK/4=βK/2=0.5 和其他 β i = 0 \beta_i = 0 βi=0。然后我们逐渐将概率质量转移到全噪声,以促进单步生成,同时仍然对目标的4个时间步长进行过采样,设定严格正值的 β i \beta_i βi,其中 i ≡ 0 [ K 4 ] i \equiv 0[\frac{K}{4}] i0[4K],其他 β i = 0 \beta_i = 0 βi=0。图2中展示了 K=32 的 π \pi π 示例。如图所示, [ 0 , 1 ] [0, 1] [0,1] 区间被分为32个时间步长。在热身阶段,概率质量将更高的概率分配给时间步长 [ 0.25 , 0.5 ] [0.25, 0.5] [0.25,0.5] 以简化蒸馏过程。随着训练的进行,概率质量函数逐渐向全噪声转移,以促进单步生成,同时始终为4个时间步长 [ 0.25 , 0.5 , 0.75 , 1 ] [0.25, 0.5, 0.75, 1] [0.25,0.5,0.75,1] 分配更高的概率。时间步长分布的影响在第6.2节中进一步讨论。
在这里插入图片描述

5.3对抗性目标

  为了进一步提高样本的质量,并且由于一些文献中提出的几项工作证明了实现几步图像生成的效率很高,于是作者决定引入对抗性目标。核心思想是训练学生模型生成与真实数据分布 p ( x 0 ) p(x_0) p(x0) 难以区分的样本。为此,我们提出训练一个判别器 D ν D_\nu Dν 来区分生成样本 x ~ 0 \tilde{x}_0 x~0 与真实样本 x 0 ∼ p ( x 0 ) x_0 \sim p(x_0) x0p(x0)。如Sauer和Lin等人所建议的,我们也直接在潜在空间中应用判别器。这种方法避免了使用VAE解码样本的必要性,这一过程在Sauer等人的文章中有所概述,被证明是昂贵的并且阻碍了方法在高分辨率图像上的可扩展性。

  借鉴Sauer和Lin等人提出的文章中的灵感,作者提出了一种方法,在这种方法中,一步的学生预测 z ~ 0 \tilde{z}_0 z~0 和输入的潜变量 z 0 z_0 z0 按照教师的噪声计划重新添加噪声。这个过程使用一个时间步长 t ′ t' t,它从集合 [ 0.01 , 0.25 , 0.5 , 0.75 ] [0.01, 0.25, 0.5, 0.75] [0.01,0.25,0.5,0.75] 中均匀选择。样本首先通过冻结的教师模型,然后通过判别器,得到真假预测。当使用 UNet 架构作为教师模型时,我们的方法专注于仅使用 UNet 的编码器部分,生成更压缩的潜变量表示,并进一步减少判别器的参数数量。我们仔细选择特定的时间步长,以使判别器能够有效地根据高频和低频细节区分样本,如Lin等人所讨论的。需要注意的是,在本文提出的方法中,判别器是唯一需要训练的组件,而教师模型保持冻结状态。对抗损失 L a d v L_{adv} Ladv 和判别器损失 L discriminator L_{\text{discriminator}} Ldiscriminator 写作:
L adv = 1 2 E z 0 , t ′ , ϵ [ ∣ ∣ D ν ( f θ s t u d e n t ( z t ′ , t ′ ) ) − 1 ∣ ∣ 2 ] , L discriminator = 1 2 E z 0 , t ′ , ϵ [ ∥ D ν ( z 0 ) − 1 ∥ 2 + ( D ν ( f θ s t u d e n t ( z t ′ , t ′ ) ) − 0 ) 2 ] (8) L_{\text{adv}} = \frac{1}{2} \mathbb{E}_{z_0,t',\epsilon} \left[ || D_\nu(f_\theta^{student}(z_{t'}, t')) - 1 ||^2 \right], \\ L_{\text{discriminator}} = \frac{1}{2} \mathbb{E}_{z_0,t',\epsilon} \left[ \|D_\nu(z_0) - 1\|^2 + \left( D_\nu(f_\theta^{student}(z_{t'}, t')) - 0 \right)^2 \right] \tag{8} Ladv=21Ez0,t,ϵ[∣∣Dν(fθstudent(zt,t))1∣2]Ldiscriminator=21Ez0,t,ϵ[Dν(z0)12+(Dν(fθstudent(zt,t))0)2](8)
其中 ν \nu ν 表示判别器参数。我们选择这些特定的损失是因为它们在训练过程中表现出可靠性和稳定性,如我们的实验所示。在消融研究中,作者强调了所选择的对抗损失 L adv L_{\text{adv}} Ladv 的影响。在实践中,鉴别器的架构被设计为一个简单的卷积神经网络(CNN),其步幅为2,核大小为4,SiLU激活和组归一化。

5.4 分布匹配

  受Yin等人工作的启发,作者还提出引入分布匹配蒸馏(DMD)损失,以确保生成的样本紧密反映教师模型学习到的数据分布。具体来说,这涉及最小化学生模型的样本分布 p θ student p_\theta^{\text{student}} pθstudent 和教师模型学习到的数据分布 p ∅ teacher p_\emptyset^{\text{teacher}} pteacher 之间的Kullback-Leibler(KL)散度:

L DMD = D KL ( p θ s t u d e n t ∣ ∣ p ∅ t e a c h e r ) = E z 0 , t , ϵ [ − ( log ⁡ p ∅ t e a c h e r ( f θ s t u d e n t ( z t , t ) ) − log ⁡ p θ s t u d e n t ( f θ s t u d e n t ( z t , t ) ) ) ] (9) L_{\text{DMD}} = D_{\text{KL}}(p_\theta^{student}|| p_\emptyset^{teacher}) = \\ \mathbb{E}_{z_0,t,\epsilon} \left[ -\left( \log p_\empty^{teacher} \left( f_\theta^{student}(z_t, t) \right)- \log p_\theta^{student} \left( f_\theta^{student}(z_t, t) \right) \right) \right] \tag{9} LDMD=DKL(pθstudent∣∣pteacher)=Ez0,t,ϵ[(logpteacher(fθstudent(zt,t))logpθstudent(fθstudent(zt,t)))](9)
对KL散度关于学生模型参数 θ \theta θ 求导得到以下更新规则:
∇ θ L DMD = E z 0 , t , ϵ [ − ( s t e a c h e r ( f θ s t u d e n t ( z t , t ) ) − s s t u d e n t ( f θ s t u d e n t ( z t , t ) ) ) ∇ f θ s t u d e n t ( z t , t ) ] , \nabla_\theta L_{\text{DMD}} = \\ \mathbb{E}_{z_0,t,\epsilon} \left[ -\left( s^{teacher}\left( f_\theta^{student}(z_t, t) \right)- s^{student}\left( f_\theta^{student}(z_t, t) \right) \right) \nabla f_\theta^{student}(z_t, t) \right], θLDMD=Ez0,t,ϵ[(steacher(fθstudent(zt,t))sstudent(fθstudent(zt,t)))fθstudent(zt,t)]
其中 s teacher s^{\text{teacher}} steacher s student s^{\text{student}} sstudent 分别是教师和学生分布的得分函数。

  受Yin等人的启发,单步学生预测使用均匀采样的时间步长 t ′ ′ ∼ U ( [ 0 , 1 ] ) t'' \sim U([0, 1]) t′′U([0,1]) 和教师的噪声计划重新加噪。新的有噪声样本通过冻结的教师模型以获取教师分布的得分函数: s t e a c h e r ( f θ s t u d e n t ( z t ′ ′ , t ′ ′ ) ) = − ( ϵ ∅ t e a c h e r ( x t ′ ′ , t ′ ′ ) / σ ( t ′ ′ ) ) s^{teacher}(f_\theta^{student}(z_{t^{\prime\prime}}, t^{\prime\prime}))=-(\epsilon_\empty^{teacher}(x_{t^{\prime\prime}},t^{\prime\prime})/\sigma(t^{\prime\prime})) steacher(fθstudent(zt′′,t′′))=(ϵteacher(xt′′,t′′)/σ(t′′))。在我们的方法中,我们利用学生模型来获取学生分布的得分函数,而不是像Yin等人所提到的专用扩散模型。这一选择显著减少了可训练参数的数量和计算成本。

5.5 模型训练

  在追求鲁棒性和多样性的同时,我们还旨在设计一个可训练参数最少的模型,因为它涉及加载计算密集型函数(教师模型和学生模型)。为此,我们提出依赖参数高效方法LoRA并将其应用于我们的学生模型。通过这种方式,我们大幅减少了参数数量并加快了训练过程。

  简而言之,我们的学生模型被训练以最小化蒸馏损失(Eq. (6))、对抗性损失(Eq. (8))和分布匹配损失(Eq. (9))的加权组合:
L = L distil + λ adv L adv + λ DMD L DMD (10) L = L_{\text{distil}} + \lambda_{\text{adv}} L_{\text{adv}} + \lambda_{\text{DMD}} L_{\text{DMD}} \tag{10} L=Ldistil+λadvLadv+λDMDLDMD(10)
在这里插入图片描述

  训练过程详见算法1,并在下图中进行了说明。具体来说,首先从未知数据分布中随机选取一个样本 x 0 ∼ p ( x 0 ) x_0 \sim p(x_0) x0p(x0)。然后使用编码器 ϵ \epsilon ϵ 对该样本进行编码,得到相应的潜在样本 z 0 z_0 z0。根据第5.2节中详述的时间步长概率质量函数 π \pi π,绘制时间步长 t t t,并使用Eq. (1) 创建一个有噪声样本 z t z_t zt。然后使用教师模型 ϵ φ teacher \epsilon_\varphi^{\text{teacher}} ϵφteacher 和ODE求解器 Ψ \Psi Ψ 来求解PF-ODE,从而生成一个属于教师模型学习到的分布的合成样本 z ~ 0 teacher \tilde{z}_0^{\text{teacher}} z~0teacher。同时,学生模型 f θ s t u d e n t f_\theta^{student} fθstudent 被用来在单步内生成一个去噪样本 z ~ 0 student = f θ student ( z t , t ) \tilde{z}_0^{\text{student}} = f_\theta^{\text{student}}(z_t, t) z~0student=fθstudent(zt,t)。然后,根据Eq. (6) 计算蒸馏损失。接着,重新对单步学生预测 z ~ 0 student \tilde{z}_0^{\text{student}} z~0student 和输入的潜在样本 z 0 z_0 z0 进行加噪,并按第5.3节所述计算对抗性损失。最后,对于分布匹配,再次取单步学生预测 z ~ 0 student \tilde{z}_0^{\text{student}} z~0student,并使用均匀采样的时间步长 t ∼ U ( [ 0 , 1 ] ) t \sim U([0, 1]) tU([0,1]) 对其进行加噪。新的有噪声样本通过教师模型获取教师得分函数 s teacher s^{\text{teacher}} steacher,同时我们使用学生模型(而不是Yin等人的专用扩散模型)获取学生得分函数 s student s^{\text{student}} sstudent。然后按第5节所述计算分布匹配损失。
在这里插入图片描述

  总的来说,我们提出的方法仅依赖少量参数的训练。这是通过将LoRA应用于学生模型,利用冻结的教师模型进行对抗性方法,并直接使用学生去噪器而不是引入一个新的扩散模型来计算分布匹配损失的假分数实现的。这种方法不仅大大减少了参数数量,还加快了训练过程。

实验

  作者将所提出的方法与现有的蒸馏方法在文本到图像生成中的效果进行比较。在本节中,我们将我们的蒸馏方法应用于公开可用的SD1.5模型,并在COCO2014和COCO2017数据集上报告FID和CLIP得分。模型在LAION数据集上进行训练,我们选择美学评分高于6的样本,并使用CogVLM提示词生成合成图像。对于COCO2017,我们依赖于[45]中提出的评估方法,并从验证集中选择5000个提示来生成合成图像。对于COCO2014,采用【22】中提出的评估协议,从验证集中选择30000个提示词。然后,计算与各自验证集中真实图像的FID,COCO2017验证集包含5000张图像,COCO2014验证集包含40504张图像。模型在2个H100-80Gb GPU上进行20k次迭代训练,批量大小为4,学习率为1e-5,使用Adam优化器训练学生模型和判别器。使用第5.2节中详述的时间步长分布 π ( t ) \pi(t) π(t),其中 K = 32 K=32 K=32,每5000次迭代进行一次相移。从 λ adv = 0 \lambda_{\text{adv}}=0 λadv=0 λ DMD = 0 \lambda_{\text{DMD}}=0 λDMD=0开始,并在每次更改时间步长分布时逐步增加,最终值分别设为0.3和0.7。指导尺度 ω \omega ω U ( [ 3 , 13 ] ) U([3, 13]) U([3,13])中采样。学生模型的权重全部用教师模型的权重初始化。

  表1和表2给出了定量比较结果。本文方法在COCO2017和COCO2014上分别达到了22.6和12.27的FID,仅需要2个NFE(网络功能评估)即可达到少步数图像生成的SOTA(最新技术)。在COCO2017上,该方法在2和4个NFE下分别达到了0.306和0.311的CLIP得分。另外,该方法只需要训练2640万参数(相对于900M的教师参数)和仅26个H100 GPU小时的训练时间。这与许多竞争对手需要训练整个学生UNet架构(涉及数亿参数)的情况形成鲜明对比。在图4中提供了1、2和4个NFE下生成样本的视觉可视化。
在这里插入图片描述
在这里插入图片描述

  以上就是对本篇论文的解读,如有任何问题欢迎留言,批评指正!

  • 17
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

I松风水月

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值