渐进蒸馏和v-prediction
TL;DR:比较早期的用蒸馏的思想来做扩散模型采样加速的方法,通过渐进地对预训练的扩散模型进行蒸馏,学生模型一步学习教师模型两步的去噪结果,不断降低采样步数。并提出一种新的参数化形式 v \mathbf{v} v-prediction 来解决渐进蒸馏过程中信噪比太低时误差影响较大的问题。
渐进蒸馏
在一开始,我们有一个预训练的原始扩散模型作为初始教师模型。我们首先将学生模型初始化为一个结构、参数都与教师模型一模一样的扩散模型。然后,不断采样干净图像数据,加噪声,训练学生模型的去噪能力。由于我们要进行蒸馏,所以这里学生模型的预测目标不是干净的图片 x \mathbf{x} x,而是要学生模型单步(DDIM)预测出教师模型两步(DDIM)的去噪结果 x ~ \tilde{\mathbf{x}} x~。
具体来说,我们这里考虑的是连续时间步 t ∈ [ 0 , 1 ] t\in[0,1] t∈[0,1],目标步数(即学生模型的步数)为 N N N,从而步长是 1 / N 1/N 1/N,在时刻 t t t 是要去噪从 z t \mathbf{z}_t zt 到 z t − 1 / N \mathbf{z}_{t-1/N} zt−1/N。这样教师模型的步数是 2 N 2N 2N,每一步是从 z t \mathbf{z}_{t} zt 到 z t − 0.5 / N \mathbf{z}_{t-0.5/N} zt−0.5/N。我们这里连续运行教师模型两步,即从 z t \mathbf{z}_t zt 到 z t − 0.5 / N \mathbf{z}_{t-0.5/N} zt−0.5/N 再到 z t − 1 / N \mathbf{z}_{t-1/N} zt−1/N,我们的学生模型训练目标就是要一步直接从 z t \mathbf{z}_t zt 预测出教师模型的两步去噪的结果 z t − 1 / N \mathbf{z}_{t-1/N} zt−1/N。
在收敛之后,我们将当前的学生模型作为下一轮的噪声模型,再将自身进行拷贝重新初始化一个新的学生模型,重复上述步骤。循环往复,即可通过渐进蒸馏不断降低模型的采样步数。
下面是渐进蒸馏的算法流程,对比了标准的扩散模型训练流程,主要就是将模型的预测目标从上一步的加噪结果改换成了教师模型的两步去噪结果,并渐进式地迭代这一过程。
参数化形式和训练损失
自从 DDPM 以来,扩散模型的参数化形式一般都是 ϵ \epsilon ϵ-prediction,即预测噪声,再根据噪声计算出数据 x \mathbf{x} x。相当于间接地预测 x \mathbf{x} x: x ^ θ ( z t ) = 1 α t ( z t − σ t ϵ ^ θ ( z t ) ) \hat{\mathbf{x}}_\theta(\mathbf{z}_t)=\frac{1}{\alpha_t}(\mathbf{z}_t-\sigma_t\hat\epsilon_\theta(\mathbf{z}_t)) x^θ(zt)=αt1(zt−σtϵ^θ(zt))。
在常规的扩散模型训练以及渐进蒸馏训练的早期(步数还比较多时),噪声预测的参数化形式工作得很好。因为这时信噪比 α t 2 / σ t 2 \alpha_t^2/\sigma_t^2 αt2/σt2 在一个比较宽的范围内。当随着渐进蒸馏的进行,步数越来越少,信噪比越来越低以至于接近于 0,此时 α t \alpha_t αt 接近于 0。根据上式, α t \alpha_t αt 在间接预测 x ^ θ ( z t ) \hat{\mathbf{x}}_\theta(\mathbf{z}_t) x^θ(zt) 公式的分母上,因此此时网络输出预测噪声 ϵ ^ θ ( z t ) \hat{\epsilon}_\theta(\mathbf{z}_t) ϵ^θ(zt) 都会噪声 x \mathbf{x} x 的巨大变化,从而导致训练不稳定。并且渐进蒸馏后期步数较少,无法通过后面的步数进行修正。
最终,如果我们将模型蒸馏到只剩下一个采样步,那么模型的输入就只是纯噪声 ϵ \epsilon ϵ,此时信噪比为零,即 α t = 0 , σ t = 1 \alpha_t = 0, \sigma_t = 1 αt=0,σt=1。在这种极端情况下, ϵ \epsilon ϵ 预测和 x \mathbf{x} x 预测之间的联系完全中断:观测数据 z t = ϵ z_t = \epsilon zt=ϵ 不再包含 x \mathbf{x} x 的信息,并且 ϵ \epsilon ϵ 的预测 ϵ ^ θ ( z t ) \hat{\epsilon}_{\theta}(\mathbf{z}_t) ϵ^θ(zt) 也无法再间接地预测 x \mathbf{x} x。在损失函数中,加权函数 w ( λ t ) w(\lambda_t) w(λt) 在此时的权重也成了 0。
为了解决这一问题,作者尝试了直接预测
x
\mathbf{x}
x、同时分别预测
x
\mathbf{x}
x 和
ϵ
\epsilon
ϵ 后合并出
x
^
\hat{\mathbf{x}}
x^,还提出了一种新的参数化形式
v
\mathbf{v}
v-prediction:
v
≡
α
t
ϵ
−
σ
t
x
\mathbf{v}\equiv \alpha_t\epsilon-\sigma_t\mathbf{x}
v≡αtϵ−σtx
从而:
x
^
=
α
t
z
t
−
σ
t
v
^
θ
(
z
t
)
\hat{\mathbf{x}}=\alpha_t\mathbf{z}_t-\sigma_t\hat{\mathbf{v}}_\theta(\mathbf{z}_t)
x^=αtzt−σtv^θ(zt)
实验显示,这三种方式在渐进蒸馏训练中都表现得不错,并在在常规扩散模型的训练中效果也很好。
下面对作者设计的 v \mathbf{v} v-prediction 进行推导:
DDPM 的加噪公式:
z
t
=
α
t
x
+
σ
t
ϵ
\mathbf{z}_t=\alpha_t\mathbf{x}+\sigma_t\epsilon
zt=αtx+σtϵ
令
ϕ
t
=
arctan
(
σ
t
/
α
t
)
\phi_t=\arctan(\sigma_t/\alpha_t)
ϕt=arctan(σt/αt),则有
α
t
=
cos
(
ϕ
)
,
σ
t
=
sin
(
ϕ
)
\alpha_t=\cos(\phi),\sigma_t=\sin(\phi)
αt=cos(ϕ),σt=sin(ϕ),从而:
z
ϕ
=
cos
(
ϕ
)
x
+
sin
(
ϕ
)
ϵ
\mathbf{z}_\phi=\cos(\phi)\mathbf{x}+\sin(\phi)\epsilon
zϕ=cos(ϕ)x+sin(ϕ)ϵ
定义
z
ϕ
z_\phi
zϕ 的 “速度” 为其关于
ϕ
\phi
ϕ 的导数:
v
ϕ
≡
d
z
ϕ
d
ϕ
=
d
cos
(
ϕ
)
d
ϕ
x
+
d
sin
ϕ
d
ϕ
ϵ
=
sin
(
ϕ
)
x
−
cos
(
ϕ
)
ϵ
\mathbf{v}_\phi\equiv\frac{d\mathbf{z}_\phi}{d\phi}=\frac{d\cos(\phi)}{d\phi}\mathbf{x}+\frac{d\sin{\phi}}{d\phi}\epsilon=\sin(\phi)\mathbf{x}-\cos(\phi)\epsilon
vϕ≡dϕdzϕ=dϕdcos(ϕ)x+dϕdsinϕϵ=sin(ϕ)x−cos(ϕ)ϵ
这里就是上面
v
\mathbf{v}
v 的定义
v
≡
α
t
ϵ
−
σ
t
x
\mathbf{v}\equiv \alpha_t\epsilon-\sigma_t\mathbf{x}
v≡αtϵ−σtx。稍微进行变换,有:
sin
(
ϕ
)
=
cos
(
ϕ
)
ϵ
−
v
ϕ
=
cos
(
ϕ
)
sin
(
ϕ
)
(
z
−
cos
(
ϕ
)
x
)
−
v
ϕ
sin
2
(
ϕ
)
x
=
cos
(
ϕ
)
z
−
cos
2
(
ϕ
)
x
−
sin
(
ϕ
)
v
ϕ
sin
2
(
ϕ
)
x
+
cos
2
(
ϕ
)
x
=
cos
(
ϕ
)
z
−
sin
(
ϕ
)
v
ϕ
x
=
cos
(
ϕ
)
z
−
sin
(
ϕ
)
v
ϕ
\begin{align} \sin(\phi)&=\cos(\phi)\epsilon-\mathbf{v}_\phi\\ &=\frac{\cos(\phi)}{\sin(\phi)}(\mathbf{z}-\cos(\phi)\mathbf{x})-\mathbf{v}_\phi\\ \sin^2(\phi)\mathbf{x}&=\cos(\phi)\mathbf{z}-\cos^2(\phi)\mathbf{x}-\sin(\phi)\mathbf{v}_\phi\\ \sin^2(\phi)\mathbf{x}+\cos^2(\phi)\mathbf{x}&=\cos(\phi)\mathbf{z}-\sin(\phi)\mathbf{v}_\phi\\ \mathbf{x}&=\cos(\phi)\mathbf{z}-\sin(\phi)\mathbf{v}_\phi \end{align}
sin(ϕ)sin2(ϕ)xsin2(ϕ)x+cos2(ϕ)xx=cos(ϕ)ϵ−vϕ=sin(ϕ)cos(ϕ)(z−cos(ϕ)x)−vϕ=cos(ϕ)z−cos2(ϕ)x−sin(ϕ)vϕ=cos(ϕ)z−sin(ϕ)vϕ=cos(ϕ)z−sin(ϕ)vϕ
这里就是上面的第二个公式
x
^
=
α
t
z
t
−
σ
t
v
^
θ
(
z
t
)
\hat{\mathbf{x}}=\alpha_t\mathbf{z}_t-\sigma_t\hat{\mathbf{v}}_\theta(\mathbf{z}_t)
x^=αtzt−σtv^θ(zt)。这个推导过程可以参考下图来理解。
总结
早期提出的渐进蒸馏是一种比较直觉的扩散模型步数蒸馏方法,其提出的 v-prediction 在后来也有广泛的应用。