Diffusion Model(3):训练目标以及训练过程


观看本文之前建议先观看以下两篇文章:

Training Loss
训练目标

​ 首先回顾一下我们的问题,我们在逆向降噪过程中由于没办法得到 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}) q(xt1xt),因此我们定义了一个 需要学习的模型模型 p θ ( x t − 1 ∣ x t ) p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t) pθ(xt1xt)来对其进行近似,并且在训练阶段我们可以利用后验 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}\vert \mathbf{x}_t,\mathbf{x}_0) q(xt1xt,x0)来对 p θ p_\theta pθ进行优化。

​ 那么现在的问题是我们如何 p θ p_\theta pθ优化得到理想的 μ θ \boldsymbol{\mu}_\theta μθ Σ θ \boldsymbol{\Sigma}_\theta Σθ?类似于 VAE ,我们可以最小化在真实数据期望下,模型预测分布的负对数似然,即最小化预测 p d a t a = q ( x 0 ) p_{\mathrm{data}}=q({\mathbf{x}_0}) pdata=q(x0) p θ ( x 0 ) p_{\theta}(\mathbf{x}_0) pθ(x0)的交叉熵:
L = E x 0 ∼ q ( x 0 ) [ − log ⁡ p θ ( x 0 ) ] \begin{equation} \mathcal{L}=\mathbb{E}_{\mathbf{x}_{0} \sim q\left(\mathbf{x}_{0}\right)}\left[-\log p_{\theta}\left(\mathbf{x}_{0}\right)\right] \end{equation} L=Ex0q(x0)[logpθ(x0)]
​ 但是,我们没法得到 p θ ( x 0 ) p_\theta(\mathbf{x}_0) pθ(x0)的表达式,因此公式1的交叉熵是没法计算的。那么可以借助公式Diffusion Model(2):前向扩散过程和逆向降噪过程
2-6
进行一些数学推导。将公式1中的 p θ ( x 0 ) p_\theta(\mathbf{x}_0) pθ(x0)转化为已知的项:
L = − E q ( x 0 ) log ⁡ p θ ( x 0 ) = − E q ( x 0 ) log ⁡ ( ∫ p θ ( x 0 : T ) d x 1 : T ) = − E q ( x 0 ) log ⁡ ( ∫ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) q ( x 1 : T ∣ x 0 ) d x 1 : T ) = − E q ( x 0 ) log ⁡ ( E q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) q ( x 1 : T ∣ x 0 ) ) ≤ − E q ( x 0 : T ) log ⁡ p θ ( x 0 : T ) q ( x 1 : T ∣ x 0 ) = E q ( x 0 : T ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] = L V L B \begin{equation} \begin{aligned} \mathcal{L} &=-\mathbb{E}_{q\left(\mathbf{x}_{0}\right)} \log p_{\theta}\left(\mathbf{x}_{0}\right) \\ &=-\mathbb{E}_{q\left(\mathbf{x}_{0}\right)} \log \left(\int p_{\theta}\left(\mathbf{x}_{0: T}\right) d \mathbf{x}_{1: T}\right) \\ &=-\mathbb{E}_{q\left(\mathbf{x}_{0}\right)} \log \left(\int q\left(\mathbf{x}_{1: T} \vert \mathbf{x}_{0}\right) \frac{p_{\theta}\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \vert \mathbf{x}_{0}\right)} d \mathbf{x}_{1: T}\right) \\ &=-\mathbb{E}_{q\left(\mathbf{x}_{0}\right)} \log \left(\mathbb{E}_{q\left(\mathbf{x}_{1: T} \vert \mathbf{x}_{0}\right)} \frac{p_{\theta}\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \vert \mathbf{x}_{0}\right)}\right) \\ & \leq-\mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)} \log \frac{p_{\theta}\left(\mathbf{x}_{0: T}\right)}{q\left(\mathbf{x}_{1: T} \vert \mathbf{x}_{0}\right)} \\ &=\mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \vert \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0: T}\right)}\right]=\mathcal{L}_{\mathrm{VLB}} \end{aligned} \end{equation} L=Eq(x0)logpθ(x0)=Eq(x0)log(pθ(x0:T)dx1:T)=Eq(x0)log(q(x1:Tx0)q(x1:Tx0)pθ(x0:T)dx1:T)=Eq(x0)log(Eq(x1:Tx0)q(x1:Tx0)pθ(x0:T))Eq(x0:T)logq(x1:Tx0)pθ(x0:T)=Eq(x0:T)[logpθ(x0:T)q(x1:Tx0)]=LVLB
​ 上式中 q ( x 0 ) q(\mathbf{x}_0) q(x0)是真实的数据分布,而 p θ ( x 0 ) p_\theta(\mathbf{x}_0) pθ(x0)是模型,从第四行到第五行使用了Jensen不等式 log ⁡ E [ f ( x ) ] ≤ E [ log ⁡ f ( x ) ] \log \mathbb{E}[f(x)] \leq \mathbb{E}[\log f(x)] logE[f(x)]E[logf(x)]并结合了对 q ( x 0 ) q(\mathbf{x}_0) q(x0)的期望和对 q ( x 1 : T ∣ x 0 ) q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) q(x1:Tx0)的期望。

​ 为了最小化这个损失,结合公式2可以将其转化为最小化其上界 L V L B \mathcal{L}_{\mathrm{VLB}} LVLB
L V L B = E q ( x 0 : T ) [ log ⁡ q ( x 1 : T ∣ x 0 ) p θ ( x 0 : T ) ] = E q [ log ⁡ ∏ t = 1 T q ( x t ∣ x t − 1 ) p θ ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 1 T log ⁡ q ( x t ∣ x t − 1 ) p θ ( x t − 1 ∣ x t ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t ∣ x t − 1 ) p θ ( x t − 1 ∣ x t ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ ( q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) ⋅ q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) + ∑ t = 2 T log ⁡ q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ − log ⁡ p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) + log ⁡ q ( x T ∣ x 0 ) q ( x 1 ∣ x 0 ) + log ⁡ q ( x 1 ∣ x 0 ) p θ ( x 0 ∣ x 1 ) ] = E q [ log ⁡ q ( x T ∣ x 0 ) p θ ( x T ) + ∑ t = 2 T log ⁡ q ( x t − 1 ∣ x t , x 0 ) p θ ( x t − 1 ∣ x t ) − log ⁡ p θ ( x 0 ∣ x 1 ) ] = E q [ − log ⁡ p θ ( x 0 ∣ x 1 ) ⏟ L 0 ] + ∑ t = 2 T D K L ( q ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) ⏟ L t − 1 + D K L ( q ( x T ∣ x 0 ) ∥ p θ ( x T ) ) ⏟ L T \begin{equation} \begin{array}{l} \mathcal{L}_{\mathrm{VLB}}=\mathbb{E}_{q\left(\mathbf{x}_{0: T}\right)}\left[\log \frac{q\left(\mathbf{x}_{1: T} \vert \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0: T}\right)}\right]\\ =\mathbb{E}_{q}\left[\log \frac{\prod_{t=1}^{T} q\left(\mathbf{x}_{t} \vert \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{T}\right) \prod_{t=1}^{T} p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right)}\right]\\ =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=1}^{T} \log \frac{q\left(\mathbf{x}_{t} \vert \mathbf{x}_{t-1}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right)}\right]\\ =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{\color{blue}q(\mathbf{x}_{t} \vert \mathbf{x}_{t-1})}{p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right)}+\log \frac{q\left(\mathbf{x}_{1} \vert \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \vert \mathbf{x}_{1}\right)}\right]\\ =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \left(\frac{\color{blue}q(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}, \mathbf{x}_{0})}{p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right)} \cdot \frac{\color{blue}q(\mathbf{x}_{t} \vert \mathbf{x}_{0})}{\color{blue}q(\mathbf{x}_{t-1} \vert \mathbf{x}_{0})}\right)+\log \frac{q\left(\mathbf{x}_{1} \vert \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \vert \mathbf{x}_{1}\right)}\right]\\ =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right)}+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t} \vert \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{0}\right)}+\log \frac{q\left(\mathbf{x}_{1} \vert \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \vert \mathbf{x}_{1}\right)}\right]\\ =\mathbb{E}_{q}\left[-\log p_{\theta}\left(\mathbf{x}_{T}\right)+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right)}+\log \frac{q\left(\mathbf{x}_{T} \vert \mathbf{x}_{0}\right)}{q\left(\mathbf{x}_{1} \vert \mathbf{x}_{0}\right)}+\log \frac{q\left(\mathbf{x}_{1} \vert \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{0} \vert \mathbf{x}_{1}\right)}\right]\\ =\mathbb{E}_{q}\left[\log \frac{q\left(\mathbf{x}_{T} \vert \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{T}\right)}+\sum_{t=2}^{T} \log \frac{q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}, \mathbf{x}_{0}\right)}{p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right)}-\log p_{\theta}\left(\mathbf{x}_{0} \vert \mathbf{x}_{1}\right)\right]\\ =\mathbb{E}_{q}[\underbrace{-\log p_{\theta}\left(\mathbf{x}_{0} \vert \mathbf{x}_{1}\right)}_{L_{0}}]+\sum_{t=2}^{T} \underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}, \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}\right)\right)}_{L_{t-1}}+\underbrace{D_{\mathrm{KL}}\left(q\left(\mathbf{x}_{T} \vert \mathbf{x}_{0}\right) \| p_{\theta}\left(\mathbf{x}_{T}\right)\right)}_{L_{T}} \end{array} \end{equation} LVLB=Eq(x0:T)[logpθ(x0:T)q(x1:Tx0)]=Eq[logpθ(xT)t=1Tpθ(xt1xt)t=1Tq(xtxt1)]=Eq[logpθ(xT)+t=1Tlogpθ(xt1xt)q(xtxt1)]=Eq[logpθ(xT)+t=2Tlogpθ(xt1xt)q(xtxt1)+logpθ(x0x1)q(x1x0)]=Eq[logpθ(xT)+t=2Tlog(pθ(xt1xt)q(xt1xt,x0)q(xt1x0)q(xtx0))+logpθ(x0x1)q(x1x0)]=Eq[logpθ(xT)+t=2Tlogpθ(xt1xt)q(xt1xt,x0)+t=2Tlogq(xt1x0)q(xtx0)+logpθ(x0x1)q(x1x0)]=Eq[logpθ(xT)+t=2Tlogpθ(xt1xt)q(xt1xt,x0)+logq(x1x0)q(xTx0)+logpθ(x0x1)q(x1x0)]=Eq[logpθ(xT)q(xTx0)+t=2Tlogpθ(xt1xt)q(xt1xt,x0)logpθ(x0x1)]=Eq[L0 logpθ(x0x1)]+t=2TLt1 DKL(q(xt1xt,x0)pθ(xt1xt))+LT DKL(q(xTx0)pθ(xT))
​ 上述式子中:

  • 从第三行到第四行,是将t=1的情况与总的求和拆开
  • 从第四行到第五行,使用了前向过程的马尔科夫过程结合贝叶斯公式Diffusion Model(1):预备知识1-8

q ( x t ∣ x t − 1 ) = q ( x t ∣ x t − 1 , x 0 ) = q ( x t , x t − 1 ∣ x 0 ) q ( x t − 1 ∣ x 0 ) = q ( x t − 1 ∣ x t , x 0 ) ⋅ q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) \begin{equation} \begin{aligned} q(\mathbf{x}_{t} \vert \mathbf{x}_{t-1}) &= q(\mathbf{x}_{t} \vert \mathbf{x}_{t-1}, \mathbf{x}_0) = \frac{q(\mathbf{x}_t, \mathbf{x}_{t-1} \vert \mathbf{x}_0)}{q(\mathbf{x}_{t-1} \vert \mathbf{x}_0)} = {\color{red}q(\mathbf{x}_{t-1} \vert \mathbf{x}_{t}, \mathbf{x}_0)} \cdot \frac{q(\mathbf{x}_t \vert \mathbf{x}_0)}{q(\mathbf{x}_{t-1}|\mathbf{x}_0)} \end{aligned} \end{equation} q(xtxt1)=q(xtxt1,x0)=q(xt1x0)q(xt,xt1x0)=q(xt1xt,x0)q(xt1x0)q(xtx0)

  • 从第六行到第七行,将对数和转化为了乘积的形式,然后消去相同的分子和分母
  • 第七行到第八行,首先将最后两项的求和转化为乘积然后消除 q ( x 1 ∣ x 0 ) q(\mathbf{x}_1 \vert \mathbf{x}_0) q(x1x0),然后将第一项放在分母上,将原来的分母变为减法
  • 从第八行到第九行,使用了KL散度的公式 D K L ( q ( x ) ∣ ∣ p ( x ) ) = E q [ log ⁡ q ( x ) p ( x ) ] D_{KL}(q(\mathbf{x}) || p(\mathbf{x}))=\mathbb{E}_{q}[\log\frac{q(\mathbf{x})}{p(\mathbf{x})}] DKL(q(x)∣∣p(x))=Eq[logp(x)q(x)]

​ 我们可以得到 L V L B \mathcal{L}_{VLB} LVLB实际上是由一个熵( L 0 L_0 L0)以及多个KL散度( ( L t , t ∈ 1 , 2 , 3 , . . . , T ) (L_t,t\in{1,2,3,...,T}) (Lt,t1,2,3,...,T))构成。其中最后一项 L T L_T LT中的 x t \mathbf{x}_t xt x 0 \mathbf{x}_0 x0分别是数据分布和先验分布,都是固定的,因此它是一个常数,在最小化时可以忽略。转而去研究 L 0 和 L t , t ∈ 1 , 2 , 3 , . . . , T − 1 L_0和L_t, t\in{1,2,3,...,T-1} L0Lt,t1,2,3,...,T1

L t L_t Lt的计算

​ 首先来考虑公式2中比较复杂的 L t L_t Lt

​ 我们的模型 p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t) = \mathcal{N}\left(\mathbf{x}_{t-1}; \mu_{\theta}(\mathbf{x}_t, t), \Sigma_\theta(\mathbf{x}_t, t) \right) pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))

​ 根据公式Diffusion Model(2):前向扩散过程和逆向降噪过程
2-11
,我们知道,对于平均值 μ θ ( x t , t ) \mu_{\theta}(\mathbf{x}_t,t) μθ(xt,t)在给定 x t \mathbf{x}_t xt t t t的情况下,我们期望它接近于 μ ~ t ( x t , x 0 ) = 1 α t ( x t − β t ( 1 − α ˉ t ) z t ) \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0)=\frac{1}{{\sqrt{\alpha_t}}}\big(\mathbf{x}_t - \frac{\beta_t}{\sqrt{(1-\bar{\alpha}_{t})}} z_t \big) μ~t(xt,x0)=αt 1(xt(1αˉt) βtzt)

​ 因此我们可以通过重参数化,通过学习高斯噪声 z θ ( x t , t ) z_\theta(\mathbf{x}_t,t) zθ(xt,t)来接近于 z t z_t zt
μ θ ( x t , t ) = 1 α t ( x t − β t 1 − α ˉ t z θ ( x t , t ) ) \begin{equation} \boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right)=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \mathbf{z}_{\theta}\left(\mathbf{x}_{t}, t\right)\right) \end{equation} μθ(xt,t)=αt 1(xt1αˉt βtzθ(xt,t))
​ 这个过程可以归结为我们的模型实际上学会了估计真正的逆向过程中的噪声

​ 对于方差 Σ θ ( x t , t ) \Sigma_\theta(\mathbf{x}_t, t) Σθ(xt,t)的处理,DDPM (Ho et al 2020)将其设置成了 σ t 2 I \sigma_t^2\mathbf{I} σt2I,其中 σ t 2 = β t \sigma_t^2=\beta_t σt2=βt或者 σ t 2 = β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \sigma_t^2=\tilde{\beta}_t= \frac{1-\bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t σt2=β~t=1αˉt1αˉt1βt。当然此项也可以通过模型学习,此处只是为了简化,并不是唯一的。

​ 现在,结合KL散度的公式Diffusion Model(1):预备知识1-10我们可以写出模型 p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) = 1 α t ( x t − β t 1 − α ˉ t z θ ( x t , t ) ) , Σ θ ( x t , t ) = σ t 2 I ) p_\theta(\mathbf{x}_{t-1} \vert \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; {\color{green}\mu_{\theta}(\mathbf{x}_t, t)=\frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}z_\theta(\mathbf{x}_t,t))}, \Sigma_\theta(\mathbf{x}_t, t) = \sigma^2_t \mathbf{I}) pθ(xt1xt)=N(xt1;μθ(xt,t)=αt 1(xt1αˉt βtzθ(xt,t)),Σθ(xt,t)=σt2I)关于 q ( x t − 1 ∣ x t ) = q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ t ( x t , x 0 ) = 1 α t ( x t − β t 1 − α ˉ t z t ) , Σ ~ t = β ˉ t I ) q(\mathbf{x}_{t-1} \vert \mathbf{x}_t) =q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; {\color{blue}\tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0)=\frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}z_t)}, \tilde{\Sigma}_t =\bar{\beta}_t\mathbf{I}) q(xt1xt)=q(xt1xt,x0)=N(xt1;μ~t(xt,x0)=αt 1(xt1αˉt βtzt),Σ~t=βˉtI)的KL散度:
L t = E x 0 , z [ 1 2 ∥ Σ θ ( x t , t ) ∥ 2 2 ∥ μ ~ t ( x t , x 0 ) − μ θ ( x t , t ) ∥ 2 ] = E x 0 , z [ 1 2 ∥ Σ θ ∥ 2 2 ∥ 1 α t ( x t − β t 1 − α ˉ t z t ) − 1 α t ( x t − β t 1 − α ˉ t z θ ( x t , t ) ) ∥ 2 ] = E x 0 , z [ β t 2 2 α t ( 1 − α ˉ t ) ∥ Σ θ ∥ 2 2 ∥ z t − z θ ( x t , t ) ∥ 2 ] = E x 0 , z [ β t 2 2 α t ( 1 − α ˉ t ) ∥ Σ θ ∥ 2 2 ∥ z t − z θ ( α ˉ t x 0 + 1 − α ˉ t z t , t ) ∥ 2 ] \begin{equation} \begin{array}{l} L_{t}=\mathbb{E}_{\mathbf{x}_{0}, \mathbf{z}}\left[\frac{1}{2\left\|\boldsymbol{\Sigma}_{\theta}\left(\mathbf{x}_{t}, t\right)\right\|_{2}^{2}}\left\| \textcolor{blue}{ \tilde{\boldsymbol{\mu}}_{t}\left(\mathbf{x}_{t}, \mathbf{x}_{0}\right)}-\textcolor{green}{\boldsymbol{\mu}_{\theta}\left(\mathbf{x}_{t}, t\right)}\right\|^{2}\right]\\ =\mathbb{E}_{\mathbf{x}_{0}, \mathbf{z}}\left[\frac{1}{2\left\|\mathbf{\Sigma}_{\theta}\right\|_{2}^{2}}\left\|\textcolor{blue}{\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \mathbf{z}_{t}\right)}-\textcolor{green}{\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{\beta_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \mathbf{z}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)}\right\|^{2}\right]\\ =\mathbb{E}_{\mathbf{x}_{0}, \mathbf{z}}\left[\frac{\beta_{t}^{2}}{2 \alpha_{t}\left(1-\bar{\alpha}_{t}\right)\left\|\mathbf{\Sigma}_{\theta}\right\|_{2}^{2}}\left\|\mathbf{z}_{t}-\mathbf{z}_{\theta}\left(\mathbf{x}_{t}, t\right)\right\|^{2}\right]\\ =\mathbb{E}_{\mathbf{x}_{0}, \mathbf{z}}\left[\frac{\beta_{t}^{2}}{2 \alpha_{t}\left(1-\bar{\alpha}_{t}\right)\left\|\mathbf{\Sigma}_{\theta}\right\|_{2}^{2}}\left\|\mathbf{z}_{t}-\mathbf{z}_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \mathbf{z}_{t}, t\right)\right\|^{2}\right] \end{array} \end{equation} Lt=Ex0,z[2Σθ(xt,t)221μ~t(xt,x0)μθ(xt,t)2]=Ex0,z[2Σθ221 αt 1(xt1αˉt βtzt)αt 1(xt1αˉt βtzθ(xt,t)) 2]=Ex0,z[2αt(1αˉt)Σθ22βt2ztzθ(xt,t)2]=Ex0,z[2αt(1αˉt)Σθ22βt2 ztzθ(αˉt x0+1αˉt zt,t) 2]
​ 我们可以发现, L t L_t Lt的训练目标实际上是在使用MSE最小化两个高斯噪声 z t z_t zt z θ ( x t , t ) z_\theta(\mathbf{x}_t,t) zθ(xt,t)

​ 在训练中发现上述带有加权( β t 2 2 α t ( 1 − α ˉ t ) ∥ Σ θ ∥ 2 2 \frac{\beta_{t}^{2}}{2 \alpha_{t}\left(1-\bar{\alpha}_{t}\right)\left\|\mathbf{\Sigma}_{\theta}\right\|_{2}^{2}} 2αt(1αˉt)Σθ22βt2)的MSE Loss不太稳定,因此DDPM (Ho et al 2020)使用了不带权重项的简化损失。
L t simple  = E x 0 , z t [ ∥ z t − z θ ( α ˉ t x 0 + 1 − α ˉ t z t , t ) ∥ 2 ] L simple  = L t simple  + C \begin{equation} L_{t}^{\text {simple }}=\mathbb{E}_{\mathbf{x}_{0}, \mathbf{z}_{t}}\left[\left\|\mathbf{z}_{t}-\mathbf{z}_{\theta}\left(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \mathbf{z}_{t}, t\right)\right\|^{2}\right]\\ L_{\text {simple }}=L_{t}^{\text {simple }} + C \end{equation} Ltsimple =Ex0,zt[ ztzθ(αˉt x0+1αˉt zt,t) 2]Lsimple =Ltsimple +C
​ 其中的 C C C是一项不依赖于 θ \theta θ的常数。

L 0 L_0 L0的计算

​ 然后来考虑 L 0 L_0 L0的计算。

​ 已知 L 0 = − E x 0 , x 1 log ⁡ ( p θ ( x 0 ∣ x 1 ) ) L_0=-\mathbb{E}_{\mathbf{x}_0,\mathbf{x}_1}\log(p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)) L0=Ex0,x1log(pθ(x0x1)),而 p θ ( x 0 ∣ x 1 ) = N ( μ θ ( x 1 ) , 1 , σ 1 2 I ) p_\theta(\mathbf{x}_0 \vert \mathbf{x}_1)=\mathcal{N}(\mu_\theta(\mathbf{x}_1),1,\sigma_1^2\mathbf{I}) pθ(x0x1)=N(μθ(x1),1,σ12I)。因此 L 0 L_0 L0实际上是一个多元高斯分布的负对数似然期望,即其熵。多元高斯分布的熵仅与其协方差有关,即 L 0 L_0 L0仅与 σ 1 2 I \sigma_1^2\mathbf{I} σ12I有关, L 0 L_0 L0是一个常数。

​ 然而,论文DDPM指出,一般而言, x 0 \mathbf{x}_0 x0的分布实际上是离散的,而不是连续的。比如图片数据,像素值取值必须是整数,归一化到 [ − 1 , 1 ] [-1,1] [1,1]后,依然是离散的点。Diffusion前向过程的第一步实际上是为离散数据添加噪声。那么,逆Diffusion的最后一步,即从 x 1 \mathbf{x}_1 x1 x 0 \mathbf{x}_0 x0,也不能被简单地看作从 N ( μ θ ( x 1 , 1 ) , σ 1 2 I ) \mathcal{N}(\boldsymbol{\mu}_\theta(\mathbf{x}_1, 1), \sigma_1^2\mathbf{I}) N(μθ(x1,1),σ12I)中采样,而是在从 N ( μ θ ( x 1 , 1 ) , σ 1 2 I ) \mathcal{N}(\boldsymbol{\mu}_\theta(\mathbf{x}_1, 1), \sigma_1^2\mathbf{I}) N(μθ(x1,1),σ12I)采样的基础上再加上离散化操作。 L 0 L_0 L0也不再是一个常数,而是一个与 μ θ ( x 1 , 1 ) \mu_\theta(\mathbf{x}_1, 1) μθ(x1,1)相关的积分,其具体表达式可以参考DDPM (Ho et al 2020)的Sec3.3。在忽略 σ 1 2 \sigma_1^2 σ12和边缘效应后, L 0 L_0 L0的取值可以被 N ( μ θ ( x 1 , 1 ) , σ 1 2 I ) \mathcal{N}(\boldsymbol{\mu}_\theta(\mathbf{x}_1, 1), \sigma_1^2\mathbf{I}) N(μθ(x1,1),σ12I)的密度函数与离散时的分块大小(bin width)相乘所拟合。

​ 另外值得一提的是,逆Diffusion的最后一步,DDPM直接取 μ θ ( x 1 , 1 ) \mu_\theta(\mathbf{x}_1, 1) μθ(x1,1)作为 x 0 \mathbf{x}_0 x0

在这里插入图片描述

The overall training and sampling algorithms

在这里插入图片描述

训练过程

​ 训练时,分别从 q ( x 0 ) q(\mathbf{x}_0) q(x0) U n i f o r m ( 1 , . . . , T ) Uniform({1,...,T}) Uniform(1,...,T) N ( 0 , I ) \mathcal{N}(\mathbf{0},\textbf{I}) N(0,I)中采样得到 x 0 \mathbf{x}_0 x0 t t t ϵ \epsilon ϵ(这里的 ϵ \epsilon ϵ就是前面说的 z z z),利用公式Diffusion Model(2):前向扩散过程和逆向降噪过程
2-4
计算得到 x t = α ˉ t x 0 + 1 − α ˉ t z \mathbf{x}_t=\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}} \mathbf{z} xt=αˉt x0+1αˉt z,将 x t \mathbf{x}_t xt t t t送入网络,预测得到一个噪声。结合公式7最小化预测噪声和真实采样的 ϵ \epsilon ϵ之间的距离。重复这一过程直到网络收敛。

采样过程

​ 采样时,需要从 x T \mathbf{x}_T xT一步一步的变回 x 0 \mathbf{x}_0 x0,其中的每一步都包含三个操作:

  • x t \mathbf{x}_t xt t t t送入网络,预测得到噪声 ϵ \epsilon ϵ
  • 利用估计的噪声 ϵ \epsilon ϵ x t \mathbf{x}_t xt,计算 μ θ = 1 α t ( x t − β t 1 − α ˉ t ϵ ) \mu_{\theta}=\frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon) μθ=αt 1(xt1αˉt βtϵ)
  • 如果 t > 1 t>1 t>1,需要从 N ( μ θ , σ t 2 I ) \mathcal{N}(\mu_\theta, \sigma_t^2\mathbf{I}) N(μθ,σt2I)中采样得到 x t − 1 \mathbf{x}_{t-1} xt1,利用重参数化技巧,可以将采样过程转换为首先采样 z ∈ N ( 0 , I ) z\in\mathcal{N}(\mathbf{0},\textbf{I}) zN(0,I),然后计算 x t − 1 = μ θ + σ t z x_{t-1}=\mu_\theta+\sigma_tz xt1=μθ+σtz。如果 t = 1 t=1 t=1,直接令 x 0 = μ θ \mathbf{x}_0=\mu_\theta x0=μθ
网络中的参数选择

​ 在前向扩散过程中,需要确定的超参数有 β t \beta_t βt以及总的步数 T T T。在前向扩散过程中,其值通常会增加。

  • DDPM (Ho et al 2020)中,前向扩散过程中的方差 β t \beta_t βt被设置为从 β 1 = 1 0 − 4 \beta_1=10^{-4} β1=104 β T = 0.02 \beta_T=0.02 βT=0.02线性增加。与 $[-1,1] $之间的归一化图像像素值相比,它们相对较小。
  • Nichol & Dhariwal (2021) 中,提出了可以使用基于余弦的方差表。调度函数的选择可以是任意的,只要它在训练过程的中间提供一个接近线性的下降,以及在 t = 0 t = 0 t=0 t = T t = T t=T附近的细微变化。

β t = clip ⁡ ( 1 − α ˉ t α ˉ t − 1 , 0.999 ) α ˉ t = f ( t ) f ( 0 )  where  f ( t ) = cos ⁡ ( t / T + s 1 + s ⋅ π 2 ) \beta_{t}=\operatorname{clip}\left(1-\frac{\bar{\alpha}_{t}}{\bar{\alpha}_{t-1}}, 0.999\right) \quad \bar{\alpha}_{t}=\frac{f(t)}{f(0)} \quad \text { where } f(t)=\cos \left(\frac{t / T+s}{1+s} \cdot \frac{\pi}{2}\right) βt=clip(1αˉt1αˉt,0.999)αˉt=f(0)f(t) where f(t)=cos(1+st/T+s2π)

在这里插入图片描述

​ 注意此图的纵坐标是 α ˉ t = 1 − β t \bar{\alpha}_t=1-\beta_t αˉt=1βt而我们讨论的是 β t \beta_t βt的一个取值。

​ 对于逆向降噪过程中,我们需要选择一个网络来预测一个噪声来从 x t \mathbf{x}_t xt得到 x t − 1 \mathbf{x}_{t-1} xt1(参照采样的过程)。在这个过程中唯一的要求是我们需要保证输入和输出的维度是一样的。比如输入是 1 × 256 × 256 1\times256\times256 1×256×256那么输出也需要是 1 × 256 × 256 1\times256\times256 1×256×256。因此此处可以选择U-Net。
在这里插入图片描述
References:

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值