- 前面两篇文章给出了 DDPM 的两种推导,“DDPM = 拆楼 + 建楼” 更为直白易懂,但无法做更多的理论延伸和定量理解,“DDPM = 自回归式 VAE” 理论分析上更加完备一些,但稍显形式化,启发性不足。下面再分享 DDPM 的一种推导,它主要利用到了贝叶斯定理来简化计算,整个过程的 “推敲” 味道颇浓,很有启发性。不仅如此,它还跟 DDIM 模型有着紧密的联系
请贝叶斯
- 利用贝叶斯公式,理论上我们想要获得如下生成过程
p
(
x
t
−
1
∣
x
t
)
p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)
p(xt−1∣xt) 的表示
p ( x t − 1 ∣ x t ) = p ( x t ∣ x t − 1 ) p ( x t − 1 ) p ( x t ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) = \frac{p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1})}{p(\boldsymbol{x}_t)} p(xt−1∣xt)=p(xt)p(xt∣xt−1)p(xt−1)然而,我们并不知道 p ( x t − 1 ) , p ( x t ) p(\boldsymbol{x}_{t-1}),p(\boldsymbol{x}_t) p(xt−1),p(xt), p ( x t − 1 ) , p ( x t ) p(\boldsymbol{x}_{t-1}),p(\boldsymbol{x}_t) p(xt−1),p(xt) 的表达式,所以此路不通。但我们可以退而求其次,在给定 x 0 \boldsymbol{x}_0 x0 的条件下使用贝叶斯定理:
p ( x t − 1 ∣ x t , x 0 ) = p ( x t ∣ x t − 1 ) p ( x t − 1 ∣ x 0 ) p ( x t ∣ x 0 ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) = \frac{p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)}{p(\boldsymbol{x}_t|\boldsymbol{x}_0)} p(xt−1∣xt,x0)=p(xt∣x0)p(xt∣xt−1)p(xt−1∣x0)其中 p ( x t ∣ x t − 1 ) = N ( x t ; α t x t − 1 , β t 2 I ) p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})=\mathcal{N}(\boldsymbol{x}_t;\alpha_t \boldsymbol{x}_{t-1}, \beta_t^2 \boldsymbol{I}) p(xt∣xt−1)=N(xt;αtxt−1,βt2I), p ( x t − 1 ∣ x 0 ) = N ( x t − 1 ; α ˉ t − 1 x 0 , β ˉ t − 1 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_{t-1};\bar{\alpha}_{t-1} \boldsymbol{x}_0, \bar{\beta}_{t-1}^2 \boldsymbol{I}) p(xt−1∣x0)=N(xt−1;αˉt−1x0,βˉt−12I), p ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , β ˉ t 2 I ) p(\boldsymbol{x}_{t}|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_{t};\bar{\alpha}_{t} \boldsymbol{x}_0, \bar{\beta}_{t}^2 \boldsymbol{I}) p(xt∣x0)=N(xt;αˉtx0,βˉt2I). 代入可得指数部分除掉 − 1 / 2 −1/2 −1/2 因子外,结果是:
∥ x t − α t x t − 1 ∥ 2 β t 2 + ∥ x t − 1 − α ˉ t − 1 x 0 ∥ 2 β ˉ t − 1 2 − ∥ x t − α ˉ t x 0 ∥ 2 β ˉ t 2 \frac{\Vert \boldsymbol{x}_t - \alpha_t \boldsymbol{x}_{t-1}\Vert^2}{\beta_t^2} + \frac{\Vert \boldsymbol{x}_{t-1} - \bar{\alpha}_{t-1}\boldsymbol{x}_0\Vert^2}{\bar{\beta}_{t-1}^2} - \frac{\Vert \boldsymbol{x}_t - \bar{\alpha}_t \boldsymbol{x}_0\Vert^2}{\bar{\beta}_t^2} βt2∥xt−αtxt−1∥2+βˉt−12∥xt−1−αˉt−1x0∥2−βˉt2∥xt−αˉtx0∥2它关于 x t − 1 \boldsymbol{x}_{t-1} xt−1 是二次的,因此最终的分布必然也是正态分布,我们只需要求出其均值和协方差。不难看出,展开式中 ∥ x t − 1 ∥ 2 \Vert \boldsymbol{x}_{t-1}\Vert^2 ∥xt−1∥2 项的系数是
α t 2 β t 2 + 1 β ˉ t − 1 2 = α t 2 β ˉ t − 1 2 + β t 2 β ˉ t − 1 2 β t 2 = α t 2 ( 1 − α ˉ t − 1 2 ) + β t 2 β ˉ t − 1 2 β t 2 = 1 − α ˉ t 2 β ˉ t − 1 2 β t 2 = β ˉ t 2 β ˉ t − 1 2 β t 2 \frac{\alpha_t^2}{\beta_t^2} + \frac{1}{\bar{\beta}_{t-1}^2} = \frac{\alpha_t^2\bar{\beta}_{t-1}^2 + \beta_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2} = \frac{\alpha_t^2(1-\bar{\alpha}_{t-1}^2) + \beta_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2} = \frac{1-\bar{\alpha}_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2} = \frac{\bar{\beta}_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2} βt2αt2+βˉt−121=βˉt−12βt2αt2βˉt−12+βt2=βˉt−12βt2αt2(1−αˉt−12)+βt2=βˉt−12βt21−αˉt2=βˉt−12βt2βˉt2所以整理好的结果必然是 β ˉ t 2 β ˉ t − 1 2 β t 2 ∥ x t − 1 − μ ~ ( x t , x 0 ) ∥ 2 \frac{\bar{\beta}_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2}\Vert \boldsymbol{x}_{t-1} - \tilde{\boldsymbol{\mu}}(\boldsymbol{x}_t, \boldsymbol{x}_0)\Vert^2 βˉt−12βt2βˉt2∥xt−1−μ~(xt,x0)∥2 的形式 (协方差矩阵必然是对角矩阵。此外,由于二次项系数都相同,因此协方差矩阵必为单位矩阵的倍数),这意味着协方差矩阵是 β ˉ t − 1 2 β t 2 β ˉ t 2 I \frac{\bar{\beta}_{t-1}^2 \beta_t^2}{\bar{\beta}_t^2}\boldsymbol{I} βˉt2βˉt−12βt2I。另一边,把一次项系数拿出来是 − 2 ( α t β t 2 x t + α ˉ t − 1 β ˉ t − 1 2 x 0 ) -2\left(\frac{\alpha_t}{\beta_t^2}\boldsymbol{x}_t + \frac{\bar{\alpha}_{t-1}}{\bar{\beta}_{t-1}^2}\boldsymbol{x}_0 \right) −2(βt2αtxt+βˉt−12αˉt−1x0),除以 − 2 β ˉ t 2 β ˉ t − 1 2 β t 2 \frac{-2\bar{\beta}_t^2}{\bar{\beta}_{t-1}^2 \beta_t^2} βˉt−12βt2−2βˉt2 后便可以得到
μ ~ ( x t , x 0 ) = α t β ˉ t − 1 2 β ˉ t 2 x t + α ˉ t − 1 β t 2 β ˉ t 2 x 0 \tilde{\boldsymbol{\mu}}(\boldsymbol{x}_t, \boldsymbol{x}_0)=\frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t + \frac{\bar{\alpha}_{t-1}\beta_t^2}{\bar{\beta}_t^2}\boldsymbol{x}_0 μ~(xt,x0)=βˉt2αtβˉt−12xt+βˉt2αˉt−1βt2x0最终得到下式,它可以借助原图像完成对当前图像的去噪:
p ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; α t β ˉ t − 1 2 β ˉ t 2 x t + α ˉ t − 1 β t 2 β ˉ t 2 x 0 , β ˉ t − 1 2 β t 2 β ˉ t 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) = \mathcal{N}\left(\boldsymbol{x}_{t-1};\frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t + \frac{\bar{\alpha}_{t-1}\beta_t^2}{\bar{\beta}_t^2}\boldsymbol{x}_0,\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} \boldsymbol{I}\right) p(xt−1∣xt,x0)=N(xt−1;βˉt2αtβˉt−12xt+βˉt2αˉt−1βt2x0,βˉt2βˉt−12βt2I)
去噪过程
- 下面我们需要在不借助原图像
x
0
\boldsymbol{x}_0
x0 的前提下完成去噪。一个 “异想天开” 的想法是用
μ
ˉ
(
x
t
)
\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)
μˉ(xt) 来预估
x
0
\boldsymbol{x}_0
x0,损失函数为
∥
x
0
−
μ
ˉ
(
x
t
)
∥
2
\Vert \boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\Vert^2
∥x0−μˉ(xt)∥2,这实际上在训练一个去噪模型,这也就是 DDPM 的第一个 “D” 的含义 (Denoising). 由于
x
0
=
1
α
ˉ
t
(
x
t
−
β
ˉ
t
ε
)
\boldsymbol{x}_0 = \frac{1}{\bar{\alpha}_t}\left(\boldsymbol{x}_t - \bar{\beta}_t \boldsymbol{\varepsilon}\right)
x0=αˉt1(xt−βˉtε),因此将
μ
ˉ
(
x
t
)
\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)
μˉ(xt) 参数化为
μ ˉ ( x t ) = 1 α ˉ t ( x t − β ˉ t ϵ θ ( x t , t ) ) \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) = \frac{1}{\bar{\alpha}_t}\left(\boldsymbol{x}_t - \bar{\beta}_t \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right) μˉ(xt)=αˉt1(xt−βˉtϵθ(xt,t))此时损失函数变为
∥ x 0 − μ ˉ ( x t ) ∥ 2 = β ˉ t 2 α ˉ t 2 ∥ ε − ϵ θ ( α ˉ t x 0 + β ˉ t ε , t ) ∥ 2 \Vert \boldsymbol{x}_0 - \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)\Vert^2 = \frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\left\Vert\boldsymbol{\varepsilon} - \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bar{\alpha}_t \boldsymbol{x}_0 + \bar{\beta}_t \boldsymbol{\varepsilon}, t)\right\Vert^2 ∥x0−μˉ(xt)∥2=αˉt2βˉt2∥∥ε−ϵθ(αˉtx0+βˉtε,t)∥∥2省去前面的系数,就得到 DDPM 原论文所用的损失函数了 (提示:出于推导的流畅性考虑,这里的 ϵ θ \boldsymbol{\epsilon}_{\boldsymbol{\theta}} ϵθ 跟前两个视角介绍不一样,反而跟 DDPM 原论文一致)。可以发现,这里是直接得出了从 x t \boldsymbol{x}_t xt 到 x 0 \boldsymbol{x}_0 x0 的去噪过程,而不是像之前两个视角那样,通过 x t \boldsymbol{x}_t xt 到 x t − 1 \boldsymbol{x}_{t-1} xt−1 的去噪过程再加上积分变换来推导,相比之下这里的推导可谓更加一步到位了 - 训练完成后,我们就认为
p ( x t − 1 ∣ x t ) ≈ p ( x t − 1 ∣ x t , x 0 = μ ˉ ( x t ) ) = N ( x t − 1 ; α t β ˉ t − 1 2 β ˉ t 2 x t + α ˉ t − 1 β t 2 β ˉ t 2 μ ˉ ( x t ) , β ˉ t − 1 2 β t 2 β ˉ t 2 I ) = N ( x t − 1 ; 1 α t ( x t − β t 2 β ˉ t ϵ θ ( x t , t ) ) , β ˉ t − 1 2 β t 2 β ˉ t 2 I ) \begin{aligned} p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) &\approx p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0=\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t)) \\&= \mathcal{N}\left(\boldsymbol{x}_{t-1}; \frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t + \frac{\bar{\alpha}_{t-1}\beta_t^2}{\bar{\beta}_t^2}\bar{\boldsymbol{\mu}}(\boldsymbol{x}_t),\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} \boldsymbol{I}\right) \\&= \mathcal{N}\left(\boldsymbol{x}_{t-1}; \frac{1}{\alpha_t}\left(\boldsymbol{x}_t - \frac{\beta_t^2}{\bar{\beta}_t}\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\right),\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} \boldsymbol{I}\right) \end{aligned} p(xt−1∣xt)≈p(xt−1∣xt,x0=μˉ(xt))=N(xt−1;βˉt2αtβˉt−12xt+βˉt2αˉt−1βt2μˉ(xt),βˉt2βˉt−12βt2I)=N(xt−1;αt1(xt−βˉtβt2ϵθ(xt,t)),βˉt2βˉt−12βt2I)这就是反向的采样过程所用的分布,连同采样过程所用的方差也一并确定下来了
预估修正
- 不知道读者有没有留意到一个有趣的地方:我们要做的事情,就是想将 x T \boldsymbol{x}_T xT 慢慢地变为 x 0 \boldsymbol{x}_0 x0,而我们在借用 p ( x t − 1 ∣ x t , x 0 ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) p(xt−1∣xt,x0) 近似 p ( x t − 1 ∣ x t ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) p(xt−1∣xt) 时,却包含了 “用 μ ˉ ( x t ) \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) μˉ(xt) 来预估 x 0 \boldsymbol{x}_0 x0” 这一步,要是能预估准的话,那就直接一步到位了,还需要逐步采样吗?
- 真实情况是,“用 μ ˉ ( x t ) \bar{\boldsymbol{\mu}}(\boldsymbol{x}_t) μˉ(xt) 来预估 x 0 \boldsymbol{x}_0 x0” 当然不会太准的,至少开始的相当多步内不会太准。它仅仅起到了一个前瞻性的预估作用,然后我们只用 p ( x t − 1 ∣ x t ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) p(xt−1∣xt) 来推进一小步,这就是很多数值算法中的 “预估-修正” 思想,即我们用一个粗糙的解往前推很多步,然后利用这个粗糙的结果将最终结果推进一小步,以此来逐步获得更为精细的解
Random Sample - 方差选取
- (1) 假设整个数据集只有一个样本,不失一般性,假设该样本为
0
\boldsymbol{0}
0,此时
p
~
(
x
0
)
\tilde{p}(\boldsymbol{x}_0)
p~(x0) 为狄拉克分布
δ
(
x
0
)
\delta(\boldsymbol{x}_0)
δ(x0),可以直接算出
p
(
x
t
)
=
p
(
x
t
∣
0
)
p(\boldsymbol{x}_t)=p(\boldsymbol{x}_t|\boldsymbol{0})
p(xt)=p(xt∣0)。代入下式
p ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; α t β ˉ t − 1 2 β ˉ t 2 x t + α ˉ t − 1 β t 2 β ˉ t 2 x 0 , β ˉ t − 1 2 β t 2 β ˉ t 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) = \mathcal{N}\left(\boldsymbol{x}_{t-1};\frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t + \frac{\bar{\alpha}_{t-1}\beta_t^2}{\bar{\beta}_t^2}\boldsymbol{x}_0,\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} \boldsymbol{I}\right) p(xt−1∣xt,x0)=N(xt−1;βˉt2αtβˉt−12xt+βˉt2αˉt−1βt2x0,βˉt2βˉt−12βt2I)有
p ( x t − 1 ∣ x t ) = p ( x t − 1 ∣ x t , x 0 = 0 ) = N ( x t − 1 ; α t β ˉ t − 1 2 β ˉ t 2 x t , β ˉ t − 1 2 β t 2 β ˉ t 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) = p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0=\boldsymbol{0}) = \mathcal{N}\left(\boldsymbol{x}_{t-1};\frac{\alpha_t\bar{\beta}_{t-1}^2}{\bar{\beta}_t^2}\boldsymbol{x}_t,\frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} \boldsymbol{I}\right) p(xt−1∣xt)=p(xt−1∣xt,x0=0)=N(xt−1;βˉt2αtβˉt−12xt,βˉt2βˉt−12βt2I)我们主要关心其方差为 β ˉ t − 1 2 β t 2 β ˉ t 2 \frac{\bar{\beta}_{t-1}^2\beta_t^2}{\bar{\beta}_t^2} βˉt2βˉt−12βt2,这便是采样方差的选择之一 - (2) 假设数据集服从标准正态分布,即
p
~
(
x
0
)
=
N
(
x
0
;
0
,
I
)
\tilde{p}(\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_0;\boldsymbol{0},\boldsymbol{I})
p~(x0)=N(x0;0,I)。由于
x
t
=
α
ˉ
t
x
0
+
β
ˉ
t
ε
,
ε
∼
N
(
0
,
I
)
\boldsymbol{x}_t = \bar{\alpha}_t \boldsymbol{x}_0 + \bar{\beta}_t \boldsymbol{\varepsilon},\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})
xt=αˉtx0+βˉtε,ε∼N(0,I),
x
0
∼
N
(
0
,
I
)
\boldsymbol{x}_0\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I})
x0∼N(0,I),所以由正态分布的叠加性,
x
t
\boldsymbol{x}_t
xt 正好也服从标准正态分布。现在有
p
(
x
t
∣
x
t
−
1
)
=
N
(
x
t
;
α
t
x
t
−
1
,
β
t
2
I
)
p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})=\mathcal{N}(\boldsymbol{x}_t;\alpha_t \boldsymbol{x}_{t-1}, \beta_t^2 \boldsymbol{I})
p(xt∣xt−1)=N(xt;αtxt−1,βt2I),
p
(
x
t
−
1
∣
x
0
)
=
N
(
x
t
−
1
;
0
,
I
)
p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_{t-1};0, \boldsymbol{I})
p(xt−1∣x0)=N(xt−1;0,I),
p
(
x
t
∣
x
0
)
=
N
(
x
t
;
0
,
I
)
p(\boldsymbol{x}_{t}|\boldsymbol{x}_0)=\mathcal{N}(\boldsymbol{x}_{t};0, \boldsymbol{I})
p(xt∣x0)=N(xt;0,I). 将标准正态分布的概率密度代入
p
(
x
t
−
1
∣
x
t
,
x
0
)
=
p
(
x
t
∣
x
t
−
1
)
p
(
x
t
−
1
∣
x
0
)
p
(
x
t
∣
x
0
)
p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{x}_0) = \frac{p(\boldsymbol{x}_t|\boldsymbol{x}_{t-1})p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_0)}{p(\boldsymbol{x}_t|\boldsymbol{x}_0)}
p(xt−1∣xt,x0)=p(xt∣x0)p(xt∣xt−1)p(xt−1∣x0), 结果的指数部分除掉
−
1
/
2
−1/2
−1/2 因子外,结果是:
∥ x t − α t x t − 1 ∥ 2 β t 2 + ∥ x t − 1 ∥ 2 − ∥ x t ∥ 2 \frac{\Vert \boldsymbol{x}_t - \alpha_t \boldsymbol{x}_{t-1}\Vert^2}{\beta_t^2} + \Vert \boldsymbol{x}_{t-1}\Vert^2 - \Vert \boldsymbol{x}_t\Vert^2 βt2∥xt−αtxt−1∥2+∥xt−1∥2−∥xt∥2跟推导 p ( x t − 1 ∣ x t , x 0 ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0) p(xt−1∣xt,x0) 的过程类似,可以得到上述指数对应于
p ( x t − 1 ∣ x t ) = N ( x t − 1 ; α t x t , β t 2 I ) p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) = \mathcal{N}\left(\boldsymbol{x}_{t-1};\alpha_t\boldsymbol{x}_t,\beta_t^2 \boldsymbol{I}\right) p(xt−1∣xt)=N(xt−1;αtxt,βt2I)我们同样主要关心其方差为 β t 2 \beta_t^2 βt2,这便是采样方差的另一个选择
References
- 苏剑林. (Jul. 19, 2022). 《生成扩散模型漫谈(三):DDPM = 贝叶斯 + 去噪 》[Blog post]. Retrieved from https://kexue.fm/archives/9164