Understanding Diffusion Models Systematically: Starting from Plato's Allegory of the Cave (Part 2)
Variational Diffusion Models
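Suppose we impose three restrictions on the HVAE model:
- the latent variables have the same dimensionality as the data;
- the encoder at each level is not learned but is a predefined linear Gaussian model;
- the distribution of the latent variable at the final level (level $T$) is a standard Gaussian.
An HVAE with these restrictions is a Variational Diffusion Model (VDM).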
The first restriction lets us unify the notation: $x_0$ denotes a real data sample, and $x_t$, $t \in [1, T]$, denotes the latent variable at level $t$. The posterior can then be rewritten as:
$$\begin{equation} q(x_{1:T}|x_0) = \prod_{t=1}^T q(x_t|x_{t-1}) \end{equation}$$
By the second restriction, we set the mean of the Gaussian encoder to $\mu_t(x_{t-1}) = \sqrt{\alpha_t}\, x_{t-1}$ and its variance to $\Sigma_t(x_{t-1}) = (1-\alpha_t) I$. The encoder can then be written as:
$$\begin{equation} q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{\alpha_t}\, x_{t-1}, (1-\alpha_t) I) \end{equation}$$
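By the third restriction, the values of $\alpha_t$ must follow a schedule that makes the distribution of the final latent, $p(x_T)$, a standard Gaussian. The joint distribution of the VDM can then be rewritten as: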
$$\begin{align} p(x_{0:T}) &= p(x_T)\prod_{t=1}^{T}p_\theta(x_{t-1}|x_t) \\ &\text{where} \\ &p(x_T) = \mathcal{N}(x_T; 0, I) \end{align}$$
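If the input is an image, this amounts to repeatedly adding noise to it until the output is pure Gaussian noise. Note that because encoding simply adds Gaussian noise according to a fixed procedure, the encoder distribution $q(x_t|x_{t-1})$ no longer has parameters $\phi$. In other words, for a VDM we only need to care about $p_\theta(x_{t-1}|x_t)$, which we use to generate new data: after training, we sample Gaussian noise from $p(x_T)$ and then iteratively apply the denoising step $p_\theta(x_{t-1}|x_t)$ to produce a new $x_0$.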
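A minimal NumPy sketch of this forward (noising) process, under assumptions that are not in the text: the schedule uses $\alpha_t = 1 - \beta_t$ with $\beta_t$ linearly spaced from $10^{-4}$ to $0.02$ and $T = 1000$ (chosen only for illustration), and `noise_to_xT` is a hypothetical helper name. It applies $q(x_t|x_{t-1})$ step by step and checks that $x_T$ looks like a standard Gaussian.

```python
import numpy as np

def noise_to_xT(x0, alphas, rng):
    """Apply q(x_t | x_{t-1}) = N(sqrt(alpha_t) x_{t-1}, (1 - alpha_t) I) for t = 1..T; return x_T."""
    x = x0
    for a in alphas:
        eps = rng.standard_normal(x.shape)          # fresh Gaussian noise at every step
        x = np.sqrt(a) * x + np.sqrt(1.0 - a) * eps
    return x

# Assumed schedule (illustrative only): linearly spaced beta_t, alpha_t = 1 - beta_t
T = 1000
alphas = 1.0 - np.linspace(1e-4, 0.02, T)

rng = np.random.default_rng(0)
x0 = np.full(10_000, 3.0)     # toy "data": 10,000 independent copies of the scalar 3.0
xT = noise_to_xT(x0, alphas, rng)
print(xT.mean(), xT.std())    # should be close to the N(0, 1) values, about 0 and 1
```

Generation runs in the opposite direction: start from $x_T \sim \mathcal{N}(0, I)$ and repeatedly apply the learned $p_\theta(x_{t-1}|x_t)$, which this sketch does not model.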
Similarly, we can optimize the VDM by maximizing the ELBO:
$$\begin{align}
\log p(x) &= \log \int p(x_{0:T})\, dx_{1:T} \\
&= \log \int \frac{p(x_{0:T})q(x_{1:T}|x_0)}{q(x_{1:T}|x_0)}dx_{1:T} \\
&= \log \mathbb{E}_{q(x_{1:T}|x_0)}\left[\frac{p(x_{0:T})}{q(x_{1:T}|x_0)}\right] \\
&\geq \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_{0:T})}{q(x_{1:T}|x_0)}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)\prod_{t=1}^{T}p_\theta(x_{t-1}|x_t)}{\prod_{t=1}^T q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)p_\theta(x_0|x_1)\prod_{t=2}^{T}p_\theta(x_{t-1}|x_t)}{q(x_T|x_{T-1}) \prod_{t=1}^{T-1} q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)p_\theta(x_0|x_1)\prod_{t=1}^{T-1}p_\theta(x_{t}|x_{t+1})}{q(x_T|x_{T-1}) \prod_{t=1}^{T-1} q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)p_\theta(x_0|x_1)}{q(x_T|x_{T-1})}\right] \\
&\quad+ \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\prod_{t=1}^{T-1}\frac{p_\theta(x_{t}|x_{t+1})}{q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p_\theta(x_0|x_1)\right] + \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)}{q(x_T|x_{T-1})}\right] \\
&\quad+\mathbb{E}_{q(x_{1:T}|x_0)}\left[\sum_{t=1}^{T-1}\log\frac{p_\theta(x_{t}|x_{t+1})}{q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p_\theta(x_0|x_1)\right] + \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)}{q(x_T|x_{T-1})}\right] \\
&\quad+\sum_{t=1}^{T-1}\mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p_\theta(x_{t}|x_{t+1})}{q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1}|x_0)}\left[\log p_\theta(x_0|x_1)\right] + \mathbb{E}_{q(x_{T-1}, x_T|x_0)}\left[\log\frac{p(x_T)}{q(x_T|x_{T-1})}\right] \\
&\quad+\sum_{t=1}^{T-1}\mathbb{E}_{q(x_{t-1}, x_t, x_{t+1}|x_0)}\left[\log\frac{p_\theta(x_{t}|x_{t+1})}{q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1}|x_0)}\left[\log p_\theta(x_0|x_1)\right] \\
&\quad- \mathbb{E}_{q(x_{T-1}|x_0)}\left[D_{KL}(q(x_T|x_{T-1})\ \|\ p(x_T))\right] \\
&\quad-\sum_{t=1}^{T-1}\mathbb{E}_{q(x_{t-1}, x_{t+1}|x_0)}\left[D_{KL}(q(x_t|x_{t-1})\ \|\ p_\theta(x_{t}|x_{t+1}))\right]
\end{align}$$
- $\mathbb{E}_{q(x_{1}|x_0)}\left[\log p_\theta(x_0|x_1)\right]$ is the reconstruction term: the log-probability of the original data sample $x_0$ predicted from $x_1$;
- $\mathbb{E}_{q(x_{T-1}|x_0)}\left[D_{KL}(q(x_T|x_{T-1})\ \|\ p(x_T))\right]$ is the prior matching term, which requires the final latent distribution to match the Gaussian prior; note that this term has no trainable parameters;
- $\mathbb{E}_{q(x_{t-1}, x_{t+1}|x_0)}\left[D_{KL}(q(x_t|x_{t-1})\ \|\ p_\theta(x_{t}|x_{t+1}))\right]$ is the consistency term. It asks that the distribution at $x_t$ be consistent from both directions: a denoising step from the noisier $x_{t+1}$ should agree with the corresponding noising step from the cleaner $x_{t-1}$. We minimize it by training $p_\theta(x_{t}|x_{t+1})$ to match the Gaussian $q(x_t|x_{t-1})$.
Most of the cost of optimizing a VDM lies in the third term, because it has to be optimized at every intermediate step $t$.
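However, in the derivation above the third term takes an expectation over two random variables $\{x_{t-1}, x_{t+1}\}$ at once, and its Monte Carlo estimate generally has higher variance than one over a single variable. Next we derive an alternative ELBO in which every term is an expectation over only one variable. The key step is to use $q(x_t|x_{t-1})=q(x_t|x_{t-1}, x_0)$, which holds because in a Markov process the extra conditioning on $x_0$ is redundant. We then apply Bayes' rule: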
$$\begin{equation} q(x_t|x_{t-1}, x_0) = \frac{q(x_{t-1}|x_{t}, x_0)q(x_t|x_0)}{q(x_{t-1}|x_0)} \end{equation}$$
Substituting this into the derivation above, we obtain:
$$\begin{align*}
\log p(x) &\geq \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_{0:T})}{q(x_{1:T}|x_0)}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)\prod_{t=1}^{T}p_\theta(x_{t-1}|x_t)}{\prod_{t=1}^T q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)p_\theta(x_0|x_1)\prod_{t=2}^{T}p_\theta(x_{t-1}|x_t)}{q(x_1|x_0) \prod_{t=2}^{T} q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)p_\theta(x_0|x_1)\prod_{t=2}^{T}p_\theta(x_{t-1}|x_t)}{q(x_1|x_0) \prod_{t=2}^{T} q(x_t|x_{t-1}, x_0)}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)p_\theta(x_0|x_1)}{q(x_1|x_0)}+ \log \prod_{t=2}^{T}\frac{p_\theta(x_{t-1}|x_t)}{q(x_t|x_{t-1}, x_0)}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)p_\theta(x_0|x_1)}{q(x_1|x_0)}+ \log \prod_{t=2}^{T}\frac{p_\theta(x_{t-1}|x_t)}{\frac{q(x_{t-1}|x_{t}, x_0)q(x_t|x_0)}{q(x_{t-1}|x_0)}}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)p_\theta(x_0|x_1)}{q(x_1|x_0)}+ \log \frac{q(x_1|x_0)}{q(x_T|x_0)}+ \log \prod_{t=2}^{T}\frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_{t}, x_0)}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)p_\theta(x_0|x_1)}{q(x_T|x_0)}+ \log \prod_{t=2}^{T}\frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_{t}, x_0)}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p_\theta(x_0|x_1)\right] + \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log \frac{p(x_T)}{q(x_T|x_0)}\right] + \sum_{t=2}^T\mathbb{E}_{q(x_{1:T}|x_0)}\left[\log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_{t}, x_0)}\right] \\
&= \mathbb{E}_{q(x_{1}|x_0)}\left[\log p_\theta(x_0|x_1)\right] + \mathbb{E}_{q(x_{T}|x_0)}\left[\log \frac{p(x_T)}{q(x_T|x_0)}\right] + \sum_{t=2}^T\mathbb{E}_{q(x_{t-1}, x_t|x_0)}\left[\log \frac{p_\theta(x_{t-1}|x_t)}{q(x_{t-1}|x_{t}, x_0)}\right] \\
&= \mathbb{E}_{q(x_{1}|x_0)}\left[\log p_\theta(x_0|x_1)\right] - D_{KL}(q(x_{T}|x_0)\ \|\ p(x_T)) - \sum_{t=2}^T\mathbb{E}_{q(x_t|x_0)}\left[D_{KL}(q(x_{t-1}|x_{t}, x_0)\ \|\ p_\theta(x_{t-1}|x_t))\right]
\end{align*}$$
- $\mathbb{E}_{q(x_{1}|x_0)}\left[\log p_\theta(x_0|x_1)\right]$ is again the reconstruction term;
- $D_{KL}(q(x_{T}|x_0)\ \|\ p(x_T))$ still measures how far the final noised distribution is from the standard Gaussian prior; it has no trainable parameters and, under our assumptions, should be zero (in practice, approximately zero for a large $T$);
- $\mathbb{E}_{q(x_t|x_0)}\left[D_{KL}(q(x_{t-1}|x_{t}, x_0)\ \|\ p_\theta(x_{t-1}|x_t))\right]$ is the denoising matching term: we train the learned denoising step $p_\theta(x_{t-1}|x_t)$ to match the ground-truth denoising step $q(x_{t-1}|x_{t}, x_0)$.
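To see why the prior matching term can be treated as (approximately) zero, here is a small sketch that evaluates $D_{KL}(q(x_T|x_0)\ \|\ \mathcal{N}(0, I))$ in closed form, using the per-dimension KL $\frac{1}{2}(\sigma^2 + \mu^2 - 1 - \log\sigma^2)$ between $\mathcal{N}(\mu,\sigma^2)$ and $\mathcal{N}(0,1)$, together with the closed form $q(x_T|x_0) = \mathcal{N}(\sqrt{\bar\alpha_T}x_0, (1-\bar\alpha_T)I)$ derived later in this section. The linear schedule, the toy `x0`, and the helper name `prior_kl` are illustrative assumptions. Note that no $\theta$ appears anywhere.

```python
import numpy as np

def prior_kl(x0, alpha_bar_T):
    """D_KL(N(sqrt(ab) x0, (1 - ab) I) || N(0, I)), summed over all dimensions."""
    var = 1.0 - alpha_bar_T
    mean_sq = alpha_bar_T * x0 ** 2
    return 0.5 * np.sum(var + mean_sq - 1.0 - np.log(var))

# Assumed schedule (illustrative only): linear beta_t from 1e-4 to 0.02, T = 1000
T = 1000
alphas = 1.0 - np.linspace(1e-4, 0.02, T)
alpha_bar_T = np.prod(alphas)

x0 = np.linspace(-1.0, 1.0, 3 * 32 * 32)   # toy "image" with pixel values in [-1, 1]
kl = prior_kl(x0, alpha_bar_T)
print(alpha_bar_T, kl)                      # both are small for this schedule
```

With a shorter or flatter schedule $\bar\alpha_T$ would be larger and the term would grow, which is why the third restriction constrains how $\alpha_t$ is chosen.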
The burden of optimizing the VDM still falls mainly on minimizing the third term.
Next, we discuss how to optimize $D_{KL}(q(x_{t-1}|x_{t}, x_0)\ \|\ p_\theta(x_{t-1}|x_t))$.
Recall the second restriction of the VDM: the encoder is a linear Gaussian, that is,
$$\begin{equation} q(x_t|x_{t-1}, x_0)=q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{\alpha_t}\, x_{t-1}, (1-\alpha_t) I) \end{equation}$$
Applying the reparameterization trick introduced earlier, we get:
$$\begin{equation} x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\, \epsilon, \quad \text{with} \ \epsilon \sim \mathcal{N}(\epsilon; 0, I) \end{equation}$$
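Now suppose we have $2T$ random noise variables $\{\epsilon^*_t, \epsilon_t\}_{t=0}^{T-1} \sim \mathcal{N}(\epsilon; 0, I)$. Writing $\bar{\alpha}_t = \prod_{i=1}^{t}\alpha_i$, and using the fact that the sum of two independent Gaussians $\mathcal{N}(0, \sigma_1^2 I)$ and $\mathcal{N}(0, \sigma_2^2 I)$ is distributed as $\mathcal{N}(0, (\sigma_1^2+\sigma_2^2) I)$ (which lets us merge two noise terms into one), we obtain the following recursion: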
$$\begin{align}
x_t &= \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\, \epsilon^*_{t-1} \\
&= \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_{t-1}}\, \epsilon^*_{t-2}\right) + \sqrt{1-\alpha_t}\, \epsilon^*_{t-1} \\
&= \sqrt{\alpha_t\alpha_{t-1}}\, x_{t-2} + \sqrt{\alpha_t- \alpha_t\alpha_{t-1}}\, \epsilon^*_{t-2} + \sqrt{1-\alpha_t}\, \epsilon^*_{t-1} \\
&= \sqrt{\alpha_t\alpha_{t-1}}\, x_{t-2} + \sqrt{\alpha_t- \alpha_t\alpha_{t-1} + 1-\alpha_t}\, \epsilon_{t-2} \\
&= \sqrt{\alpha_t\alpha_{t-1}}\, x_{t-2} + \sqrt{1- \alpha_t\alpha_{t-1}}\, \epsilon_{t-2} \\
&= \dots \\
&= \sqrt{\prod_{i=1}^{t}\alpha_i}\, x_{0} + \sqrt{1- \prod_{i=1}^{t}\alpha_i}\, \epsilon_{0} \\
&= \sqrt{\bar{\alpha}_t}\, x_{0} + \sqrt{1- \bar{\alpha}_t}\, \epsilon_{0} \\
&\sim \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}\, x_{0}, (1- \bar{\alpha}_t)I)
\end{align}$$
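This closed form can be checked numerically. The sketch below (same assumed linear schedule as an illustration; `iterate` and `q_sample` are hypothetical helper names) draws $x_t$ once by applying the one-step transition $t$ times and once directly from $\mathcal{N}(\sqrt{\bar\alpha_t}x_0, (1-\bar\alpha_t)I)$, then compares the empirical means and standard deviations.

```python
import numpy as np

T = 1000
alphas = 1.0 - np.linspace(1e-4, 0.02, T)   # assumed schedule, illustrative only
alpha_bars = np.cumprod(alphas)             # alpha_bar_t = prod_{i=1}^t alpha_i, stored at index t-1

def iterate(x0, t, rng):
    """Sample x_t by applying x <- sqrt(alpha_i) x + sqrt(1 - alpha_i) eps for i = 1..t."""
    x = x0
    for a in alphas[:t]:
        x = np.sqrt(a) * x + np.sqrt(1.0 - a) * rng.standard_normal(x.shape)
    return x

def q_sample(x0, t, rng):
    """Sample x_t in one shot from q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I)."""
    ab = alpha_bars[t - 1]
    return np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * rng.standard_normal(x0.shape)

rng = np.random.default_rng(0)
x0 = np.full(200_000, 2.0)
t = 300
x_iter, x_direct = iterate(x0, t, rng), q_sample(x0, t, rng)
print(x_iter.mean(), x_direct.mean(), 2.0 * np.sqrt(alpha_bars[t - 1]))  # three nearly equal means
print(x_iter.std(), x_direct.std(), np.sqrt(1.0 - alpha_bars[t - 1]))   # three nearly equal stds
```

The one-shot sampler is what makes training practical: $x_t$ for any $t$ costs one Gaussian draw instead of $t$ sequential steps.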
Having obtained the explicit form of $q(x_t|x_0)$, we can apply Bayes' rule again to write $q(x_{t-1}|x_t, x_0)$ explicitly:
$$\begin{align*}
q(x_{t-1}|x_{t}, x_0) &= \frac{q(x_{t}|x_{t-1}, x_0)q(x_{t-1}|x_0)}{q(x_{t}|x_0)} \\
&= \frac{\mathcal{N}(x_t; \sqrt{\alpha_{t}}\, x_{t-1}, (1- \alpha_{t})I)\,\mathcal{N}(x_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\, x_{0}, (1- \bar{\alpha}_{t-1})I)}{\mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}\, x_{0}, (1- \bar{\alpha}_t)I)} \\
&\propto \exp \left\{ - \left[ \frac{(x_{t}-\sqrt{\alpha_{t}}\, x_{t-1})^2}{2(1-\alpha_t)} + \frac{(x_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\, x_{0})^2}{2(1- \bar{\alpha}_{t-1})} - \frac{(x_t - \sqrt{\bar{\alpha}_t}\, x_{0})^2}{2(1- \bar{\alpha}_t)} \right] \right\} \\
&\propto \exp \left\{-\frac{1}{2}\left( \frac{1}{\frac{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}{1- \bar{\alpha}_{t}}}\right)\left[x_{t-1}^2-2\,\frac{\sqrt{\alpha_{t}}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)x_0}{1- \bar{\alpha}_{t}}\, x_{t-1} \right] \right\} \\
&\propto \mathcal{N}\left(x_{t-1}; \frac{\sqrt{\alpha_{t}}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)x_0}{1- \bar{\alpha}_{t}}, \frac{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}{1- \bar{\alpha}_{t}}I\right)
\end{align*}$$
From the derivation above, $q(x_{t-1}|x_{t}, x_0)$ is a Gaussian, with mean denoted by:
$$\begin{equation} \mu_q(x_t, x_0) = \frac{\sqrt{\alpha_{t}}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)x_0}{1- \bar{\alpha}_{t}} \end{equation}$$
and variance denoted by:
$$\begin{equation} \Sigma_q(t) = \sigma_q^2(t)I = \frac{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}{1- \bar{\alpha}_{t}}I \end{equation}$$
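As a sanity check of $\mu_q$ and $\sigma_q^2$, the sketch below (assumed linear schedule; `posterior_params` and `gauss_logpdf` are hypothetical helper names) evaluates $q(x_t|x_{t-1})\,q(x_{t-1}|x_0)$ on a fine 1-D grid over $x_{t-1}$, normalizes it numerically (the denominator $q(x_t|x_0)$ does not depend on $x_{t-1}$, so normalization takes its place), and compares the resulting mean and variance with the closed forms.

```python
import numpy as np

T = 1000
alphas = 1.0 - np.linspace(1e-4, 0.02, T)   # assumed schedule, illustrative only
alpha_bars = np.cumprod(alphas)

def posterior_params(xt, x0, t):
    """Mean mu_q(x_t, x_0) and variance sigma_q^2(t) of q(x_{t-1} | x_t, x_0), for t in 2..T."""
    a, ab, ab_prev = alphas[t - 1], alpha_bars[t - 1], alpha_bars[t - 2]
    mean = (np.sqrt(a) * (1 - ab_prev) * xt + np.sqrt(ab_prev) * (1 - a) * x0) / (1 - ab)
    var = (1 - a) * (1 - ab_prev) / (1 - ab)
    return mean, var

def gauss_logpdf(x, mu, var):
    """Log-density of the 1-D Gaussian N(mu, var) at x."""
    return -0.5 * (x - mu) ** 2 / var - 0.5 * np.log(2 * np.pi * var)

# Brute force in 1-D: q(x_{t-1} | x_t, x_0) is proportional to q(x_t | x_{t-1}) q(x_{t-1} | x_0)
t, x0, xt = 50, 1.5, 0.7
grid = np.linspace(-10.0, 10.0, 200_001)                 # candidate values of x_{t-1}
log_w = (gauss_logpdf(xt, np.sqrt(alphas[t - 1]) * grid, 1 - alphas[t - 1])
         + gauss_logpdf(grid, np.sqrt(alpha_bars[t - 2]) * x0, 1 - alpha_bars[t - 2]))
w = np.exp(log_w - log_w.max())
w /= w.sum()                                             # numerical normalization
grid_mean = np.sum(w * grid)
grid_var = np.sum(w * (grid - grid_mean) ** 2)

mean, var = posterior_params(xt, x0, t)
print(grid_mean, mean)   # should agree closely
print(grid_var, var)     # should agree closely
```

Note that `var` depends only on `t`, not on `x0` or `xt`, which is what allows the learned model below to reuse it.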
To make the denoising step $p_\theta(x_{t-1}|x_t)$ match the ground-truth denoising step $q(x_{t-1}|x_{t}, x_0)$ as closely as possible, we can model it as a Gaussian as well. Specifically, since all $\alpha_t$ are fixed in advance, the variance $\Sigma_q(t)$ depends only on $t$, so we can directly set the variance of $p_\theta(x_{t-1}|x_t)$ to $\Sigma_q(t) = \sigma_q^2(t)I$; we denote its mean by $\mu_\theta(x_t,t)$. We then obtain:
$$\begin{align*}
&\arg\min_\theta D_{KL}(q(x_{t-1}|x_{t}, x_0)\ \|\ p_\theta(x_{t-1}|x_t)) \\
=\ &\arg\min_\theta D_{KL}(\mathcal{N}(x_{t-1};\mu_q, \Sigma_q(t))\ \|\ \mathcal{N}(x_{t-1};\mu_\theta, \Sigma_q(t))) \\
=\ &\arg\min_\theta\frac{1}{2}\left[\log \frac{|\Sigma_q(t)|}{|\Sigma_q(t)|} - d + \text{tr}\left(\Sigma_q(t)^{-1} \Sigma_q(t)\right)+(\mu_{\theta}-\mu_{q})^\top \Sigma_q(t)^{-1} (\mu_{\theta}-\mu_{q})\right] \\
=\ &\arg\min_\theta\frac{1}{2\sigma_q^2(t)}\left\|\mu_{\theta}-\mu_{q}\right\|^2_2
\end{align*}$$
Here we used the KL divergence between two $d$-dimensional Gaussians:
$$D_{KL}(\mathcal{N}(x;\mu_x, \Sigma_x)\ \|\ \mathcal{N}(y;\mu_y, \Sigma_y)) = \frac{1}{2}\left[\log \frac{|\Sigma_y|}{|\Sigma_x|}-d+\text{tr}(\Sigma_y^{-1}\Sigma_x)+(\mu_y-\mu_x)^\top\Sigma_y^{-1}(\mu_y-\mu_x)\right]$$
Since both distributions share the covariance $\Sigma_q(t)$, the log-determinant ratio is $\log 1 = 0$ and $-d + \text{tr}(I_d) = 0$, so only the quadratic term $\frac{1}{2\sigma_q^2(t)}\|\mu_\theta-\mu_q\|_2^2$ remains.
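A quick numerical check of this simplification: the sketch implements the general Gaussian KL formula above with NumPy (`kl_gauss` is a hypothetical helper name) and compares it with $\|\mu_\theta-\mu_q\|_2^2 / (2\sigma_q^2(t))$ when both Gaussians share the covariance $\sigma_q^2(t) I$. The dimension, the means, and the value of $\sigma_q^2(t)$ are arbitrary.

```python
import numpy as np

def kl_gauss(mu_x, cov_x, mu_y, cov_y):
    """D_KL(N(mu_x, cov_x) || N(mu_y, cov_y)) for d-dimensional Gaussians, via the closed form."""
    d = mu_x.shape[0]
    cov_y_inv = np.linalg.inv(cov_y)
    diff = mu_y - mu_x
    _, logdet_x = np.linalg.slogdet(cov_x)
    _, logdet_y = np.linalg.slogdet(cov_y)
    return 0.5 * (logdet_y - logdet_x - d + np.trace(cov_y_inv @ cov_x) + diff @ cov_y_inv @ diff)

rng = np.random.default_rng(0)
d, sigma_q2 = 5, 0.3                       # arbitrary dimension and shared variance
mu_q, mu_theta = rng.standard_normal(d), rng.standard_normal(d)
cov = sigma_q2 * np.eye(d)                 # both distributions use Sigma_q(t) = sigma_q^2(t) I

kl = kl_gauss(mu_q, cov, mu_theta, cov)
shortcut = np.sum((mu_theta - mu_q) ** 2) / (2 * sigma_q2)
print(kl, shortcut)                        # equal up to floating-point error
```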
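In other words, our objective is equivalent to optimizing $\mu_\theta(x_t, t)$ to match $\mu_q(x_t, x_0)$. Since $\mu_\theta(x_t, t)$ cannot depend on $x_0$, we give it the same form as $\mu_q(x_t, x_0)$ in (36) and replace the unknown $x_0$ with a neural network prediction $\hat{x}_\theta(x_t, t)$ of the original sample: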
$$\begin{equation} \mu_\theta(x_t, t) = \frac{\sqrt{\alpha_{t}}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\hat{x}_\theta(x_t, t)}{1- \bar{\alpha}_{t}} \end{equation}$$
With this explicit form, we can rewrite the objective further:
$$\begin{align*}
&\arg\min_\theta D_{KL}(q(x_{t-1}|x_{t}, x_0)\ \|\ p_\theta(x_{t-1}|x_t)) \\
=\ &\arg\min_\theta D_{KL}(\mathcal{N}(x_{t-1};\mu_q, \Sigma_q(t))\ \|\ \mathcal{N}(x_{t-1};\mu_\theta, \Sigma_q(t))) \\
=\ &\arg\min_\theta\frac{1}{2\sigma_q^2(t)}\left\|\mu_{\theta}-\mu_{q}\right\|^2_2 \\
=\ &\arg\min_\theta\frac{1}{2\sigma_q^2(t)}\left\|\frac{\sqrt{\alpha_{t}}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)\hat{x}_\theta(x_t, t)}{1- \bar{\alpha}_{t}}-\frac{\sqrt{\alpha_{t}}(1- \bar{\alpha}_{t-1})x_t+\sqrt{\bar{\alpha}_{t-1}}(1-\alpha_t)x_0}{1- \bar{\alpha}_{t}}\right\|^2_2 \\
=\ &\arg\min_\theta\frac{1}{2\sigma_q^2(t)}\frac{\bar{\alpha}_{t-1}(1-\alpha_t)^2}{(1- \bar{\alpha}_{t})^2}\left\|\hat{x}_\theta(x_t, t)-x_0\right\|_2^2
\end{align*}$$
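The last step rests on the $x_t$ terms cancelling, so that $\mu_\theta - \mu_q = \frac{\sqrt{\bar\alpha_{t-1}}(1-\alpha_t)}{1-\bar\alpha_t}\left(\hat x_\theta(x_t,t) - x_0\right)$. The sketch below checks this numerically under the assumed linear schedule; a random perturbation of $x_0$ stands in for the network output $\hat{x}_\theta(x_t, t)$ (no actual network is involved), and `posterior_mean` and `sigma_q2` are hypothetical helper names.

```python
import numpy as np

T = 1000
alphas = 1.0 - np.linspace(1e-4, 0.02, T)   # assumed schedule, illustrative only
alpha_bars = np.cumprod(alphas)

def posterior_mean(xt, x0, t):
    """mu_q(x_t, x_0); passing x_hat instead of x0 gives mu_theta(x_t, t)."""
    a, ab, ab_prev = alphas[t - 1], alpha_bars[t - 1], alpha_bars[t - 2]
    return (np.sqrt(a) * (1 - ab_prev) * xt + np.sqrt(ab_prev) * (1 - a) * x0) / (1 - ab)

def sigma_q2(t):
    """sigma_q^2(t) = (1 - alpha_t)(1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)."""
    a, ab, ab_prev = alphas[t - 1], alpha_bars[t - 1], alpha_bars[t - 2]
    return (1 - a) * (1 - ab_prev) / (1 - ab)

rng = np.random.default_rng(1)
t = 400
x0, xt = rng.standard_normal(8), rng.standard_normal(8)
x_hat = x0 + 0.1 * rng.standard_normal(8)      # stand-in for a network prediction, not a real model

lhs = np.sum((posterior_mean(xt, x_hat, t) - posterior_mean(xt, x0, t)) ** 2) / (2 * sigma_q2(t))
a, ab, ab_prev = alphas[t - 1], alpha_bars[t - 1], alpha_bars[t - 2]
rhs = ab_prev * (1 - a) ** 2 / (2 * sigma_q2(t) * (1 - ab) ** 2) * np.sum((x_hat - x0) ** 2)
print(lhs, rhs)   # equal up to floating-point error
```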
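Optimizing a VDM therefore boils down to training a denoising neural network that recovers the original sample from a version corrupted by an arbitrary amount of noise. Moreover, the sum over steps in the third term of the ELBO can be replaced, up to the constant factor $T-1$, by an expectation over uniformly sampled time steps: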
$$\begin{equation} \arg\min_\theta \mathbb{E}_{t\sim U\{2, T\}}\left[\mathbb{E}_{q(x_t|x_0)}\left[D_{KL}(q(x_{t-1}|x_{t}, x_0)\ \|\ p_\theta(x_{t-1}|x_t))\right]\right] \end{equation}$$
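This objective can be optimized stochastically: at each training step we randomly sample $t$ (and $x_t \sim q(x_t|x_0)$) and minimize the resulting term.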
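A sketch of one Monte Carlo evaluation of this objective on a minibatch, under stated assumptions: NumPy only, so there are no gradients (in practice the loss would be written in an autodiff framework and minimized with respect to $\theta$); the same assumed linear schedule; one $t \sim U\{2, T\}$ per example; and `naive_denoiser` is a purely hypothetical placeholder for $\hat{x}_\theta(x_t, t)$ that just divides out $\sqrt{\bar\alpha_t}$. The per-step weight is the factor $\frac{1}{2\sigma_q^2(t)}\frac{\bar{\alpha}_{t-1}(1-\alpha_t)^2}{(1-\bar{\alpha}_t)^2}$ derived above.

```python
import numpy as np

T = 1000
alphas = 1.0 - np.linspace(1e-4, 0.02, T)   # assumed schedule, illustrative only
alpha_bars = np.cumprod(alphas)

def loss_weight(t):
    """1/(2 sigma_q^2(t)) * alpha_bar_{t-1} (1 - alpha_t)^2 / (1 - alpha_bar_t)^2; t may be an array."""
    a, ab, ab_prev = alphas[t - 1], alpha_bars[t - 1], alpha_bars[t - 2]
    sigma_q2 = (1 - a) * (1 - ab_prev) / (1 - ab)
    return ab_prev * (1 - a) ** 2 / (2 * sigma_q2 * (1 - ab) ** 2)

def naive_denoiser(xt, t):
    """Hypothetical stand-in for x_hat_theta(x_t, t): rescales x_t and ignores the noise."""
    return xt / np.sqrt(alpha_bars[t - 1])[:, None]

def denoising_loss(x0, denoiser, rng):
    """Monte Carlo estimate of E_{t ~ U{2,T}} E_{q(x_t|x_0)}[weight(t) * ||x_hat - x_0||^2]."""
    n = x0.shape[0]
    t = rng.integers(2, T + 1, size=n)                     # t ~ U{2, ..., T}, one per example
    ab = alpha_bars[t - 1][:, None]
    xt = np.sqrt(ab) * x0 + np.sqrt(1 - ab) * rng.standard_normal(x0.shape)  # x_t ~ q(x_t | x_0)
    sq_err = np.sum((denoiser(xt, t) - x0) ** 2, axis=1)
    return np.mean(loss_weight(t) * sq_err)

rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(64, 16))   # toy minibatch: 64 samples of dimension 16
loss = denoising_loss(x0, naive_denoiser, rng)
print(loss)
```

Multiplying this expectation by $T-1$ recovers the sum over $t = 2, \dots, T$ in the ELBO; the constant factor does not change the minimizer.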
Learning Diffusion Noise Parameters
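As mentioned above, the noise parameter $\alpha_t$ at each step is fixed in advance. Could we make it learnable instead? A direct approach is to model it with a neural network $\hat{\alpha}_\eta(t)$. This is inefficient, however, because computing $\bar{\alpha}_t$ requires calling the network many times at every step (once for each $\alpha_i$, $i \le t$). Although we could precompute the required values and cache them, here we introduce a method that simplifies the problem by further rewriting the objective: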
$$\begin{align}
&\frac{1}{2\sigma_q^2(t)}\frac{\bar{\alpha}_{t-1}(1-\alpha_t)^2}{(1- \bar{\alpha}_{t})^2}\left\|\hat{x}_\theta(x_t, t)-x_0\right\|_2^2 \\
=\ &\frac{1}{2\frac{(1-\alpha_t)(1- \bar{\alpha}_{t-1})}{1- \bar{\alpha}_{t}}}\frac{\bar{\alpha}_{t-1}(1-\alpha_t)^2}{(1- \bar{\alpha}_{t})^2}\left\|\hat{x}_\theta(x_t, t)-x_0\right\|_2^2 \\
=\ &\frac{1}{2}\frac{\bar{\alpha}_{t-1}(1-\alpha_t)}{(1- \bar{\alpha}_{t-1})(1- \bar{\alpha}_{t})}\left\|\hat{x}_\theta(x_t, t)-x_0\right\|_2^2 \\
=\ &\frac{1}{2}\frac{\bar{\alpha}_{t-1}(1-\bar\alpha_t)-\bar{\alpha}_{t}(1-\bar\alpha_{t-1})}{(1- \bar{\alpha}_{t-1})(1- \bar{\alpha}_{t})}\left\|\hat{x}_\theta(x_t, t)-x_0\right\|_2^2 \\
=\ &\frac{1}{2}\left(\frac{\bar{\alpha}_{t-1}}{1- \bar{\alpha}_{t-1}}-\frac{\bar{\alpha}_{t}}{1-\bar{\alpha}_{t}}\right)\left\|\hat{x}_\theta(x_t, t)-x_0\right\|_2^2
\end{align}$$
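The fourth line uses $\bar\alpha_t = \alpha_t\bar\alpha_{t-1}$, so that $\bar{\alpha}_{t-1}(1-\alpha_t) = \bar\alpha_{t-1} - \bar\alpha_t = \bar{\alpha}_{t-1}(1-\bar\alpha_t)-\bar{\alpha}_{t}(1-\bar\alpha_{t-1})$. Since $q(x_t|x_0)$ is the Gaussian $\mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}x_{0}, (1- \bar{\alpha}_t)I)$, and the signal-to-noise ratio (SNR) is defined as $\text{SNR}=\frac{\mu^2}{\sigma^2}$, the SNR at step $t$ can be written as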
$$\begin{equation} \text{SNR}(t) = \frac{\bar{\alpha}_t}{1- \bar{\alpha}_t} \end{equation}$$
The objective then simplifies to:
$$\begin{equation} \frac{1}{2}\left(\text{SNR}(t-1)-\text{SNR}(t)\right)\left\|\hat{x}_\theta(x_t, t)-x_0\right\|_2^2 \end{equation}$$
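A short numerical check of this identity under the assumed linear schedule (an illustration, not part of the derivation): it computes the weight both from the original expression and as $\frac{1}{2}(\text{SNR}(t-1)-\text{SNR}(t))$ for every $t = 2, \dots, T$, and confirms that the SNR decreases monotonically in $t$.

```python
import numpy as np

T = 1000
alphas = 1.0 - np.linspace(1e-4, 0.02, T)   # assumed schedule, illustrative only
alpha_bars = np.cumprod(alphas)
snr = alpha_bars / (1 - alpha_bars)          # SNR(t) for t = 1..T, stored at index t-1

t = np.arange(2, T + 1)
a, ab, ab_prev = alphas[t - 1], alpha_bars[t - 1], alpha_bars[t - 2]
sigma_q2 = (1 - a) * (1 - ab_prev) / (1 - ab)
weight_direct = ab_prev * (1 - a) ** 2 / (2 * sigma_q2 * (1 - ab) ** 2)
weight_snr = 0.5 * (snr[t - 2] - snr[t - 1])

print(np.allclose(weight_direct, weight_snr))   # True
print(np.all(np.diff(snr) < 0))                  # True: SNR decreases with t
```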
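In a VDM, we require the SNR to decrease monotonically with the step $t$ (gradually adding noise lowers the SNR), until at step $T$ the latent becomes standard Gaussian noise. We therefore model the SNR as: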
$$\begin{equation} \text{SNR}(t) = \exp(-\omega_\eta(t)) \end{equation}$$
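where $\omega_\eta(t)$ is a monotonically increasing neural network with parameters $\eta$, which guarantees that the SNR decreases with $t$. When optimizing (39), we therefore also optimize $\eta$ jointly. Note that the form of (47) lets us derive: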
$$\begin{align} \bar{\alpha}_t &= \text{sigmoid}(-\omega_\eta(t)) \\ 1-\bar{\alpha}_t &= \text{sigmoid}(\omega_\eta(t)) \end{align}$$
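These two expressions make it easy to add whatever amount of noise is needed to $x_0$ according to (35), $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon_0$, to obtain $x_t$.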
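A minimal sketch of this parameterization. It assumes a hypothetical monotone form $\omega_\eta(t) = \eta_0 + \text{softplus}(\eta_1)\, t/T$ in place of a real neural network (any function increasing in $t$ would do), with arbitrary illustrative values for $\eta$. It derives $\bar\alpha_t$ and $1-\bar\alpha_t$ from the two sigmoids, checks $\text{SNR}(t) = \exp(-\omega_\eta(t))$, and samples $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$.

```python
import numpy as np

def softplus(z):
    return np.logaddexp(0.0, z)              # log(1 + e^z), numerically stable

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def omega(t, eta, T):
    """Hypothetical omega_eta(t) = eta_0 + softplus(eta_1) * t / T; increasing in t for any eta."""
    return eta[0] + softplus(eta[1]) * t / T

T = 1000
eta = np.array([-6.0, 18.0])                 # arbitrary illustrative parameter values
t = np.arange(1, T + 1)
w = omega(t, eta, T)
alpha_bar = sigmoid(-w)                      # alpha_bar_t = sigmoid(-omega_eta(t))
one_minus_alpha_bar = sigmoid(w)             # 1 - alpha_bar_t = sigmoid(omega_eta(t))
snr = alpha_bar / one_minus_alpha_bar

def q_sample(x0, step, rng):
    """x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps under the learned schedule."""
    return (np.sqrt(alpha_bar[step - 1]) * x0
            + np.sqrt(one_minus_alpha_bar[step - 1]) * rng.standard_normal(x0.shape))

rng = np.random.default_rng(0)
x_500 = q_sample(np.ones(4), 500, rng)
print(np.allclose(snr, np.exp(-w)), alpha_bar[0], alpha_bar[-1])
```

Because $\eta$ enters only through $\omega_\eta(t)$, obtaining $\bar\alpha_t$ for any $t$ costs a single evaluation, instead of the product of $t$ network outputs needed with $\hat\alpha_\eta(t)$.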