Systematically Understanding Diffusion Models: Starting from the Allegory of Plato's Cave (Part 2)
Variational Diffusion Models
Suppose we impose three restrictions on the HVAE model:
- the latent variables have the same dimensionality as the data;
- the encoder at each level is not learned, but is instead fixed in advance as a linear Gaussian model;
- the latent distribution at the final level (the $T$-th level) is a standard Gaussian.
An HVAE with these added conditions is a Variational Diffusion Model (VDM).
By the first condition, we can unify the notation: let $x_0$ denote the true data sample, and let $x_t,\ t \in [1, T]$ denote the latent variable at level $t$. The posterior can then be rewritten as:
\begin{equation} q(x_{1:T}|x_0) = \prod_{t=1}^T q(x_t|x_{t-1}) \end{equation}
By the second condition, we set the mean of each Gaussian encoder to $\mu_t(x_t)=\sqrt{\alpha_t}\, x_{t-1}$ and its variance to $\Sigma_t(x_t) = (1-\alpha_t) I$. The encoder can then be written as:
\begin{equation} q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{\alpha_t}\, x_{t-1}, (1-\alpha_t) I) \end{equation}
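As a concrete sketch (mine, not from the original text), one encoder step can be sampled with the reparameterization $x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\, \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$:

```python
import numpy as np

def forward_step(x_prev, alpha_t, rng):
    # q(x_t | x_{t-1}) = N(sqrt(alpha_t) * x_{t-1}, (1 - alpha_t) * I),
    # sampled via reparameterization:
    #   x_t = sqrt(alpha_t) * x_{t-1} + sqrt(1 - alpha_t) * eps
    eps = rng.standard_normal(x_prev.shape)
    return np.sqrt(alpha_t) * x_prev + np.sqrt(1.0 - alpha_t) * eps

rng = np.random.default_rng(0)
x0 = rng.standard_normal(4)  # toy "data"; the latents keep this same dimensionality
x1 = forward_step(x0, alpha_t=0.99, rng=rng)
```

Note that the step has no learnable parameters, matching the second condition.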
By the third condition, the values of $\alpha_t$ must follow a schedule that makes the latent distribution at the final level, $p(x_T)$, a standard Gaussian. The joint distribution of the VDM can then be rewritten as
\begin{align} p(x_{0:T}) &= p(x_T)\prod_{t=1}^{T}p_\theta(x_{t-1}|x_t) \\ &\text{where} \quad p(x_T) = \mathcal{N}(x_T; 0, I) \end{align}
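To see how the third condition constrains the $\alpha_t$, here is a small numerical check. The schedule below (a DDPM-style linear schedule for $\beta_t = 1-\alpha_t$) is an illustrative assumption, not something fixed by the text: composing the encoder steps scales $x_0$ by $\prod_t \sqrt{\alpha_t}$, so any schedule that drives the cumulative product $\bar\alpha_T = \prod_{t=1}^T \alpha_t$ to nearly zero leaves $x_T$ approximately standard Gaussian.

```python
import numpy as np

# Illustrative linear schedule for beta_t = 1 - alpha_t (an assumption);
# any schedule works as long as the cumulative product ends near zero.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)  # \bar{alpha}_t = prod_{s<=t} alpha_s

# Composing the encoder steps scales x_0 by sqrt(alpha_bar_t) and injects noise
# of variance (1 - alpha_bar_t), so alpha_bar_T ~ 0 means x_T is (almost) N(0, I).
print(alpha_bar[-1])  # nearly zero (~4e-5 with this schedule)
```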
With an image as input, this amounts to repeatedly adding noise to the image until the output is pure Gaussian noise. Note that, since the encoding process just adds Gaussian noise according to a fixed procedure, the encoder distribution $q(x_t|x_{t-1})$ no longer carries parameters $\phi$. In other words, for a VDM we only need to model $p_\theta(x_{t-1}|x_t)$, and it is what we use to generate new data. Concretely, once training is done, we sample Gaussian noise from $p(x_T)$ and then run the denoising steps $p_\theta(x_{t-1}|x_t)$ one at a time to generate a new $x_0$.
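The generation procedure just described can be sketched as an ancestral-sampling loop. Here `denoise_mean` is a hypothetical stand-in for the learned network behind $p_\theta(x_{t-1}|x_t)$, and the fixed step variance is an illustrative simplification (in practice it depends on the schedule):

```python
import numpy as np

def sample(denoise_mean, T, dim, rng):
    # Start from pure noise: x_T ~ p(x_T) = N(0, I).
    x = rng.standard_normal(dim)
    for t in range(T, 0, -1):
        mean = denoise_mean(x, t)        # mean of p_theta(x_{t-1} | x_t)
        sigma = 0.1 if t > 1 else 0.0    # fixed step noise; none on the final step
        x = mean + sigma * rng.standard_normal(dim)
    return x                             # the generated sample x_0

rng = np.random.default_rng(1)
toy_mean = lambda x, t: 0.9 * x          # placeholder "denoiser", illustration only
x0_gen = sample(toy_mean, T=10, dim=4, rng=rng)
```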
Similarly, we can optimize the VDM by maximizing the ELBO:
\begin{align}
\log p(x) &= \log \int p(x_{0:T})\, dx_{1:T} \\
&= \log \int \frac{p(x_{0:T})\, q(x_{1:T}|x_0)}{q(x_{1:T}|x_0)}\, dx_{1:T} \\
&= \log \mathbb{E}_{q(x_{1:T}|x_0)}\left[\frac{p(x_{0:T})}{q(x_{1:T}|x_0)}\right] \\
&\geq \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_{0:T})}{q(x_{1:T}|x_0)}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)\prod_{t=1}^{T}p_\theta(x_{t-1}|x_t)}{\prod_{t=1}^T q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)\, p_\theta(x_0|x_1)\prod_{t=2}^{T}p_\theta(x_{t-1}|x_t)}{q(x_T|x_{T-1}) \prod_{t=1}^{T-1} q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)\, p_\theta(x_0|x_1)\prod_{t=1}^{T-1}p_\theta(x_{t}|x_{t+1})}{q(x_T|x_{T-1}) \prod_{t=1}^{T-1} q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)\, p_\theta(x_0|x_1)}{q(x_T|x_{T-1})}\right] + \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\prod_{t=1}^{T-1}\frac{p_\theta(x_{t}|x_{t+1})}{q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p_\theta(x_0|x_1)\right] + \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)}{q(x_T|x_{T-1})}\right] + \mathbb{E}_{q(x_{1:T}|x_0)}\left[\sum_{t=1}^{T-1}\log\frac{p_\theta(x_{t}|x_{t+1})}{q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log p_\theta(x_0|x_1)\right] + \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p(x_T)}{q(x_T|x_{T-1})}\right] + \sum_{t=1}^{T-1}\mathbb{E}_{q(x_{1:T}|x_0)}\left[\log\frac{p_\theta(x_{t}|x_{t+1})}{q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1}|x_0)}\left[\log p_\theta(x_0|x_1)\right] + \mathbb{E}_{q(x_{T-1}, x_T|x_0)}\left[\log\frac{p(x_T)}{q(x_T|x_{T-1})}\right] + \sum_{t=1}^{T-1}\mathbb{E}_{q(x_{t-1}, x_t, x_{t+1}|x_0)}\left[\log\frac{p_\theta(x_{t}|x_{t+1})}{q(x_t|x_{t-1})}\right] \\
&= \mathbb{E}_{q(x_{1}|x_0)}\left[\log p_\theta(x_0|x_1)\right] - \mathbb{E}_{q(x_{T-1}|x_0)}\left[D_{KL}(q(x_T|x_{T-1}) \ \|\ p(x_T))\right] \\
&\quad - \sum_{t=1}^{T-1}\mathbb{E}_{q(x_{t-1}, x_{t+1}|x_0)}\left[D_{KL}(q(x_t|x_{t-1})\ \|\ p_\theta(x_{t}|x_{t+1}))\right]
\end{align}
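Note that the second term (prior matching) and the third term (consistency) are KL divergences between Gaussians, assuming, as is standard, that $p_\theta(x_{t-1}|x_t)$ is also parameterized as a Gaussian. Each therefore has a closed form and needs no Monte Carlo estimate of the KL itself. A sketch for the isotropic case (the function name is mine):

```python
import numpy as np

def kl_isotropic_gaussians(mu1, var1, mu2, var2):
    # D_KL( N(mu1, var1*I) || N(mu2, var2*I) ) for d-dimensional isotropic Gaussians:
    #   0.5 * [ d*log(var2/var1) + (d*var1 + ||mu1 - mu2||^2) / var2 - d ]
    d = mu1.shape[0]
    return 0.5 * (d * np.log(var2 / var1)
                  + (d * var1 + np.sum((mu1 - mu2) ** 2)) / var2
                  - d)

kl_same = kl_isotropic_gaussians(np.zeros(3), 1.0, np.zeros(3), 1.0)  # 0 for identical distributions
```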