深度学习(生成式模型)——Classifier Guidance Diffusion

前言

几乎所有的生成式模型,发展到后期都需要引入"控制"的概念,可控制的生成式模型才能更好应用于实际场景。本文将总结《Diffusion Models Beat GANs on Image Synthesis》中提出的Classifier Guidance Diffusion(即条件扩散模型),其往Diffusion Model中引入了控制的概念,可以控制DDPM、DDIM生成指定类别(条件)的图片。

问题建模

本章节所有符号定义与DDPM一致,在条件 y y y下的Diffusion Model的前向与反向过程可以定义为
q ^ ( x t + 1 ∣ x t , y ) q ^ ( x t ∣ x t + 1 , y ) \begin{aligned} \hat q(x_{t+1}|x_{t},y)\\ \hat q(x_t|x_{t+1},y) \end{aligned} q^(xt+1xt,y)q^(xtxt+1,y)
只要求出上述两个概率密度函数,我们即可按条件生成图像。

我们利用 q ^ \hat q q^表示条件扩散模型的概率密度函数, q q q表示扩散模型的概率密度函数。

条件扩散模型的前向过程

对于前向过程,作者定义了以下等式
q ^ ( x 0 ) = q ( x 0 ) q ^ ( x t + 1 ∣ x t , y ) = q ( x t + 1 ∣ x t ) q ^ ( x 1 : T ∣ x 0 , y ) = ∏ t = 1 T q ^ ( x t ∣ x t − 1 , y ) \begin{aligned} \hat q(x_0)&=q(x_0)\\ \hat q(x_{t+1}|x_t,y)&=q(x_{t+1}|x_t)\\ \hat q(x_{1:T}|x_0,y)&=\prod_{t=1}^T\hat q(x_t|x_{t-1},y) \end{aligned} q^(x0)q^(xt+1xt,y)q^(x1:Tx0,y)=q(x0)=q(xt+1xt)=t=1Tq^(xtxt1,y)

基于上述第二行定义,可知基于条件 y y y的diffusion model的前向过程与普通的diffusion model一致,即 q ^ ( x t + 1 ∣ x t ) = q ( x t + 1 ∣ x t ) \hat q(x_{t+1}|x_t)=q(x_{t+1}|x_t) q^(xt+1xt)=q(xt+1xt)。即加噪过程与条件 y y y无关,这种定义也是合理的。

条件扩散模型的反向过程

对于反向过程,我们有
q ^ ( x t ∣ x t + 1 , y ) = q ^ ( x t , x t + 1 , y ) q ^ ( x t + 1 , y ) = q ^ ( x t , x t + 1 , y ) q ^ ( y ∣ x t + 1 ) q ^ ( x t + 1 ) = q ^ ( x t , y ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) = q ^ ( y ∣ x t , x t + 1 ) q ^ ( x t ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) (1.0) \begin{aligned} \hat q(x_t|x_{t+1},y)&=\frac{\hat q(x_t,x_{t+1},y)}{\hat q(x_{t+1},y)}\\ &=\frac{\hat q(x_t,x_{t+1},y)}{\hat q(y|x_{t+1})\hat q(x_{t+1})}\\ &=\frac{\hat q(x_t,y|x_{t+1})}{\hat q(y|x_{t+1})}\\ &=\frac{\hat q(y|x_t,x_{t+1})\hat q(x_{t}|x_{t+1})}{\hat q(y|x_{t+1})} \end{aligned}\tag{1.0} q^(xtxt+1,y)=q^(xt+1,y)q^(xt,xt+1,y)=q^(yxt+1)q^(xt+1)q^(xt,xt+1,y)=q^(yxt+1)q^(xt,yxt+1)=q^(yxt+1)q^(yxt,xt+1)q^(xtxt+1)(1.0)

已知条件扩散模型的前向过程与扩散模型一致,则有

q ^ ( x 1 : T ∣ x 0 ) = q ( x 1 : T ∣ x 0 ) \hat q(x_{1:T}|x_0)=q(x_{1:T}|x_0) q^(x1:Tx0)=q(x1:Tx0)

进而有
q ^ ( x t ) = ∫ q ^ ( x 0 , . . . , x t ) d x 0 : t − 1 = ∫ q ^ ( x 0 ) q ^ ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = ∫ q ( x 0 ) q ( x 1 : t ∣ x 0 ) d x 0 : t − 1 = q ( x t ) \begin{aligned} \hat q(x_{t})&=\int \hat q(x_0,...,x_t) dx_{0:t-1}\\ &=\int \hat q(x_0)\hat q(x_{1:t}|x_0)dx_{0:t-1}\\ &=\int q(x_0)q(x_{1:t}|x_0)dx_{0:t-1}\\ &=q(x_t) \end{aligned} q^(xt)=q^(x0,...,xt)dx0:t1=q^(x0)q^(x1:tx0)dx0:t1=q(x0)q(x1:tx0)dx0:t1=q(xt)

对于 q ^ ( x t ∣ x t + 1 ) \hat q(x_t|x_{t+1}) q^(xtxt+1),则有
q ^ ( x t ∣ x t + 1 ) = q ^ ( x t , x t + 1 ) q ^ ( x t + 1 ) = q ^ ( x t + 1 ∣ x t ) q ^ ( x t ) q ^ ( x t + 1 ) = q ( x t + 1 ∣ x t ) q ( x t ) q ( x t + 1 ) = q ( x t ∣ x t + 1 ) \begin{aligned} \hat q(x_t|x_{t+1})&=\frac{\hat q(x_t,x_{t+1})}{\hat q(x_{t+1})}\\ &=\frac{\hat q(x_{t+1}|x_t)\hat q(x_{t})}{\hat q(x_{t+1})}\\ &=\frac{q(x_{t+1}|x_t)q(x_{t})}{q(x_{t+1})}\\ &=q(x_{t}|x_{t+1}) \end{aligned} q^(xtxt+1)=q^(xt+1)q^(xt,xt+1)=q^(xt+1)q^(xt+1xt)q^(xt)=q(xt+1)q(xt+1xt)q(xt)=q(xtxt+1)

对于 q ^ ( y ∣ x t , x x t + 1 ) \hat q(y|x_t,x_{x_{t+1}}) q^(yxt,xxt+1),我们有
q ^ ( y ∣ x t , x x t + 1 ) = q ^ ( x t + 1 ∣ x t , y ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( x t + 1 ∣ x t ) q ^ ( y ∣ x t ) q ^ ( x t + 1 ∣ x t ) = q ^ ( y ∣ x t ) \begin{aligned} \hat q(y|x_t,x_{x_{t+1}})&=\frac{\hat q(x_{t+1}|x_t,y)\hat q(y|x_t)}{\hat q(x_{t+1}|x_t)}\\ &=\frac{\hat q(x_{t+1}|x_t)\hat q(y|x_t)}{\hat q(x_{t+1}|x_t)}\\ &=\hat q(y|x_t) \end{aligned} q^(yxt,xxt+1)=q^(xt+1xt)q^(xt+1xt,y)q^(yxt)=q^(xt+1xt)q^(xt+1xt)q^(yxt)=q^(yxt)

因此式1.0为

q ^ ( x t ∣ x t + 1 , y ) = q ^ ( y ∣ x t , x t + 1 ) q ^ ( x t ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) = q ^ ( y ∣ x t ) q ( x t ∣ x t + 1 ) q ^ ( y ∣ x t + 1 ) \begin{aligned} \hat q(x_t|x_{t+1},y)&=\frac{\hat q(y|x_t,x_{t+1})\hat q(x_{t}|x_{t+1})}{\hat q(y|x_{t+1})}\\ &=\frac{\hat q(y|x_t)q(x_{t}|x_{t+1})}{\hat q(y|x_{t+1})} \end{aligned} q^(xtxt+1,y)=q^(yxt+1)q^(yxt,xt+1)q^(xtxt+1)=q^(yxt+1)q^(yxt)q(xtxt+1)

由于在反向过程中, x t + 1 x_{t+1} xt+1是已知的,因此 q ^ ( y ∣ x t + 1 ) \hat q(y|x_{t+1}) q^(yxt+1)也可看成已知值,设其倒数为 Z Z Z,则有

q ^ ( x t ∣ x t + 1 , y ) = Z q ^ ( y ∣ x t ) q ( x t ∣ x t + 1 ) \begin{aligned} \hat q(x_t|x_{t+1},y) = Z\hat q(y|x_t)q(x_{t}|x_{t+1}) \end{aligned} q^(xtxt+1,y)=Zq^(yxt)q(xtxt+1)

取log可得
log ⁡ q ^ ( x t ∣ x t + 1 , y ) = log ⁡ Z + log ⁡ q ^ ( y ∣ x t ) + log ⁡ q ^ ( x t ∣ x t + 1 ) (1.1) \begin{aligned} \log \hat q(x_{t}|x_{t+1},y)=\log Z+\log \hat q(y|x_t)+\log \hat q(x_t|x_{t+1})\tag{1.1} \end{aligned} logq^(xtxt+1,y)=logZ+logq^(yxt)+logq^(xtxt+1)(1.1)

q ^ ( x t ∣ x t + 1 ) = N ( μ t , ∑ t 2 ) \hat q(x_t|x_{t+1})=\mathcal N(\mu_t,\sum_t^2) q^(xtxt+1)=N(μt,t2),则有
log ⁡ q ^ ( x t ∣ x t + 1 ) = − 1 2 ( x t − μ t ) T ( ∑ t ) − 1 ( x t − μ t ) + C (1.2) \log \hat q(x_{t}|x_{t+1})=-\frac{1}{2}(x_t-\mu_t)^T({\sum}_t)^{-1}(x_t-\mu_t)+C\tag{1.2} logq^(xtxt+1)=21(xtμt)T(t)1(xtμt)+C(1.2)

对于 log ⁡ q ^ ( y ∣ x t ) \log \hat q(y|x_t) logq^(yxt),在 x t = μ t x_t=\mu_t xt=μt处做泰勒展开,则有

log ⁡ q ^ ( y ∣ x t ) ≈ log ⁡ q ^ ( y ∣ x t ) ∣ x t = μ t + ( x t − μ t ) ∇ x t log ⁡ q ^ ( y ∣ x t ) ∣ x t = μ t = C 1 + ( x t − μ t ) g (1.3) \begin{aligned} \log \hat q(y|x_t) &\approx \log \hat q(y|x_t)|_{x_t=\mu_t}+(x_t-\mu_t)\nabla_{x_t}\log\hat q(y|x_t)|_{x_t=\mu_t}\\ &=C_1+(x_t-\mu_t)g \end{aligned}\tag{1.3} logq^(yxt)logq^(yxt)xt=μt+(xtμt)xtlogq^(yxt)xt=μt=C1+(xtμt)g(1.3)
其中 g = ∇ x t log ⁡ q ^ ( y ∣ x t ) ∣ x t = μ t g=\nabla_{x_t}\log\hat q(y|x_t)|_{x_t=\mu_t} g=xtlogq^(yxt)xt=μt,结合式1.1、1.2、1.3,有

log ⁡ q ^ ( x t ∣ x t + 1 , y ) ≈ C 1 + ( x t − μ t ) g + log ⁡ Z − 1 2 ( x t − μ t ) T ( ∑ t ) − 1 ( x t − μ t ) + C = ( x t − μ t ) g − 1 2 ( x t − μ t ) T ( ∑ t ) − 1 ( x t − μ t ) + C 2 = − 1 2 ( x t − μ t − ∑ t g ) T ( ∑ t ) − 1 ( x t − μ t − ∑ t g ) + C 3 \begin{aligned} \log \hat q(x_{t}|x_{t+1},y)&\approx C_1+(x_t-\mu_t)g+\log Z-\frac{1}{2}(x_t-\mu_t)^T(\sum{_t})^{-1}(x_t-\mu_t)+C\\ &=(x_t-\mu_t)g-\frac{1}{2}(x_t-\mu_t)^T(\sum{_t})^{-1}(x_t-\mu_t)+C_2\\ &=-\frac{1}{2}(x_t-\mu_t-\sum{_t} g)^T(\sum{_t})^{-1}(x_t-\mu_t-\sum{_t}g)+C_3 \end{aligned} logq^(xtxt+1,y)C1+(xtμt)g+logZ21(xtμt)T(t)1(xtμt)+C=(xtμt)g21(xtμt)T(t)1(xtμt)+C2=21(xtμttg)T(t)1(xtμttg)+C3

最终有

q ^ ( x t ∣ x t + 1 , y ) ≈ N ( μ t + ∑ t g , ( ∑ t ) 2 ) g = ∇ x t log ⁡ q ^ ( y ∣ x t ) ∣ x t = μ t (1.4) \begin{aligned} \hat q(x_t|x_{t+1},y)\approx \mathcal N(\mu_t+{\sum}_{t}g,({\sum}_t)^2)\\ g=\nabla_{x_t}\log\hat q(y|x_t)|_{x_t=\mu_t} \end{aligned}\tag{1.4} q^(xtxt+1,y)N(μt+tg,(t)2)g=xtlogq^(yxt)xt=μt(1.4)

为了获得 ∇ x t log ⁡ q ^ ( y ∣ x t ) \nabla_{x_t}\log\hat q(y|x_t) xtlogq^(yxt),Classifier Guidance Diffusion在训练好的Diffusion model的基础上额外训练了一个分类头。

假设 x t ≈ μ t x_t \approx\mu_t xtμt,则Classifier Guidance Diffusion的反向过程为:
在这里插入图片描述

其中 p ϕ ( y ∣ x t ) = q ^ ( y ∣ x t ) p_ \phi(y|x_t)=\hat q(y|x_t) pϕ(yxt)=q^(yxt) s s s为一个超参数。

式1.4有个问题,当方差 ∑ \sum 取值为0时, ∑ ∇ x t log ⁡ q ^ ( y ∣ x t ) {\sum}\nabla_{x_t}\log\hat q(y|x_t) xtlogq^(yxt)取值将为0,无法控制生成指定条件的图像。因此式1.4不适用于DDIM等确定性采样的扩散模型

在推导DDIM的采样公式前,我们先了解一下用Tweedie方法做参数估计的流程。

Tweedie方法主要用于指数族概率分布的参数估计,而高斯分布属于指数族概率分布,自然也适用。假设有一批样本 z z z,则利用样本 z z z估计高斯分布 N ( Z ; μ , ∑ 2 ) \mathcal N(Z;\mu,{\sum}^2) N(Z;μ,2)的均值 μ \mu μ的公式为

E [ μ ∣ z ] = z + ∑ 2 ∇ z log ⁡ p ( z ) (1.5) E[\mu|z]=z+{\sum}^2\nabla_z\log p(z)\tag{1.5} E[μz]=z+2zlogp(z)(1.5)

已知DDPM、DDIM的前向过程有

q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) (1.6) q(x_t|x_0)=\mathcal N(x_t;\sqrt{\bar \alpha_t}x_0,(1-\bar\alpha_t)\mathcal I)\tag{1.6} q(xtx0)=N(xt;αˉt x0,(1αˉt)I)(1.6)

结合式1.5、1.6可得

α ˉ t x 0 = x t + ( 1 − α ˉ t ) ∇ x t log ⁡ p ( x t ) \begin{aligned} \sqrt{\bar \alpha_t}x_0=x_t+(1-\bar\alpha_t)\nabla_{x_t}\log p(x_t) \end{aligned} αˉt x0=xt+(1αˉt)xtlogp(xt)
进而有
x t = α ˉ t x 0 − ( 1 − α ˉ t ) ∇ x t log ⁡ p ( x t ) (1.7) x_t=\sqrt{\bar \alpha_t}x_0-(1-\bar\alpha_t)\nabla_{x_t}\log p(x_t)\tag{1.7} xt=αˉt x0(1αˉt)xtlogp(xt)(1.7)
ϵ t \epsilon_t ϵt服从标准正态分布,则从式1.6可知

x t = α ˉ t x 0 + 1 − α ˉ t ϵ t (1.8) x_t=\sqrt{\bar \alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon_t\tag{1.8} xt=αˉt x0+1αˉt ϵt(1.8)

结合式1.7、1.8,则有

∇ x t log ⁡ p ( x t ) = − 1 1 − α ˉ t ϵ t (1.9) \nabla_{x_t}\log p(x_t)=-\frac{1}{\sqrt{1-\bar\alpha_t}}\epsilon_t\tag{1.9} xtlogp(xt)=1αˉt 1ϵt(1.9)

已知DDIM的采样公式为

x t − 1 = α ˉ t − 1 x t − 1 − α ˉ t ϵ θ ( x t ) α ˉ t + 1 − α ˉ t − δ t 2 ϵ θ ( x t ) (2.0) x_{t-1}=\sqrt{\bar \alpha_{t-1}}\frac{x_t-\sqrt{1-\bar \alpha_t}\epsilon_\theta(x_t)}{\sqrt{\bar\alpha_t}}+\sqrt{1-\bar\alpha_{t}-\delta_t^2}\epsilon_\theta(x_t)\tag{2.0} xt1=αˉt1 αˉt xt1αˉt ϵθ(xt)+1αˉtδt2 ϵθ(xt)(2.0)

结合式1.9、2.0可将DDIM的采样公式转变为

x t − 1 = α ˉ t − 1 x t − 1 − α ˉ t ( − 1 − α ˉ t ∇ x t log ⁡ p ( x t ) ) α ˉ t + 1 − α ˉ t − δ t 2 ( − 1 − α ˉ t ∇ x t log ⁡ p ( x t ) ) (2.1) x_{t-1}=\sqrt{\bar \alpha_{t-1}}\frac{x_t-\sqrt{1-\bar \alpha_t}(-\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(x_t))}{\sqrt{\bar\alpha_t}}+\sqrt{1-\bar\alpha_{t}-\delta_t^2}(-\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(x_t))\tag{2.1} xt1=αˉt1 αˉt xt1αˉt (1αˉt xtlogp(xt))+1αˉtδt2 (1αˉt xtlogp(xt))(2.1)

我们只需要将其中的 ∇ x t log ⁡ p ( x t ) \nabla_{x_t}\log p(x_t) xtlogp(xt)替换为 ∇ x t log ⁡ p ( x t ∣ y ) \nabla_{x_t}\log p(x_t|y) xtlogp(xty),即可引入条件 y y y来控制DDIM的生成过程,利用贝叶斯定理,我们有

log ⁡ p ( x t ∣ y ) = log ⁡ p ( y ∣ x t ) + log ⁡ p ( x t ) − log ⁡ p ( y ) ∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) − ∇ x t log ⁡ p ( y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) = ∇ x t log ⁡ p ( y ∣ x t ) − 1 1 − α ˉ t ϵ t (2.2) \begin{aligned} \log p(x_t|y)&=\log p(y|x_t)+\log p(x_t)-\log p(y)\\ \nabla_{x_t}\log p(x_t|y)&=\nabla_{x_t}\log p(y|x_t)+\nabla_{x_t}\log p(x_t)-\nabla_{x_t}\log p(y)\\ &=\nabla_{x_t}\log p(y|x_t)+\nabla_{x_t}\log p(x_t)\\ &=\nabla_{x_t}\log p(y|x_t)-\frac{1}{\sqrt{1-\bar\alpha_t}}\epsilon_t \end{aligned}\tag{2.2} logp(xty)xtlogp(xty)=logp(yxt)+logp(xt)logp(y)=xtlogp(yxt)+xtlogp(xt)xtlogp(y)=xtlogp(yxt)+xtlogp(xt)=xtlogp(yxt)1αˉt 1ϵt(2.2)
则有

− 1 − α ˉ t ∇ x t log ⁡ p ( x t ∣ y ) = ϵ t − 1 − α ˉ t ∇ x t log ⁡ p ( y ∣ x t ) (2.3) -\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(x_t|y)=\epsilon_t-\sqrt{1-\bar\alpha_t}\nabla_{x_t}\log p(y|x_t)\tag{2.3} 1αˉt xtlogp(xty)=ϵt1αˉt xtlogp(yxt)(2.3)

至此,我们可以得到DDIM的采样流程为
在这里插入图片描述
对于DDIM等确定性采样的扩散模型,其应在训练好的Diffusion model的基础上额外训练了一个分类头,从而转变为Classifier Guidance Diffusion。

条件扩散模型的训练目标

注意到 q ^ ( x t ∣ x t + 1 ) = q ( x t ∣ x t + 1 ) \hat q(x_t|x_{t+1})=q(x_t|x_{t+1}) q^(xtxt+1)=q(xtxt+1),并且上述的推导过程并没有改变 q ( x t ∣ x t + 1 ) 、 q ( x t + 1 ∣ x t ) q(x_t|x_{t+1})、q(x_{t+1}|x_t) q(xtxt+1)q(xt+1xt)的形式,因此Classifier Guidance Diffusion的训练目标与DDPM、DDIM是一致的,都可以拟合训练数据。

  • 6
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
classifier-free diffusion guidance(无分类器扩散引导)是一种新兴的技术,用于在无需提前训练分类器的情况下进行目标导航。 传统的目标导航技术通常需要使用先验知识和已经训练好的分类器来辨别和识别目标。然而,这些方法存在许多限制和缺点,如对精确的先验知识的需求以及对大量标记数据的依赖。 相比之下,classifier-free diffusion guidance 可以在目标未知的情况下进行导航,避免了先验知识和训练好的分类器的依赖。它的主要思想是利用传感器和环境反馈信息,通过推测和逐步调整来实现导航。 在这种方法中,机器人通过感知环境中的信息,例如物体的形状、颜色、纹理等特征,获取关于目标位置的信息。然后,它将这些信息与先验的环境模型进行比较,并尝试找到与目标最相似的区域。 为了进一步提高导航的准确性,机器人还可以利用扩散算法来调整自己的位置和方向。通过比较当前位置的特征与目标位置的特征,机器人可以根据这些差异进行调整,逐渐接近目标。 需要注意的是,classifier-free diffusion guidance还处于研究阶段,目前还存在许多挑战和问题。例如,对于复杂的环境和多个目标,算法的性能可能会下降。然而,随着技术的发展,我们可以预见classifier-free diffusion guidance将会在未来的目标导航中发挥重要的作用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值