【学习笔记】扩散模型的基本原理与训练方法

1. Diffusion前向过程

给定采样自数据集的真实图片 x 0 ∼ q ( x ) x_0 \sim q(x) x0q(x),并对其添加高斯噪声,共进行 T T T步,称该过程为 q q q过程,能够得到添加噪声后的图片分布 x 1 , x 2 , . . . , x T x_1,x_2,...,x_T x1,x2,...,xT

将加噪的过程看作一个马尔可夫过程,即 t t t时刻的状态只与 t − 1 t-1 t1时刻有关,设置超参数 β t ∈ ( 0 , 1 ) , t ∈ ( 1 , T ) \beta_t \in (0,1), t\in(1,T) βt(0,1),t(1,T) ,本质上该超参数即为每一时刻下添加的高斯分布的方差。至此可以将该前向 q q q过程写成以下形式,表示在 x t − 1 x_{t-1} xt1满足的分布前提下, x t x_{t} xt的方差为 β t \beta_t βt,均值则受到了前一个状态的影响。
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t|x_{t-1})=\mathcal N(x_t;\sqrt{1-\beta_t}x_{t-1},\beta_t \mathbf{I}) q(xtxt1)=N(xt;1βt xt1,βtI)
实际场景中, β \beta β是随着 t t t的增大而递增的,那么当 t t t很大时, β t \beta_t βt 趋近于1,则 x t x_t xt满足的分布接近于标准正态分布。【 β t \beta_t βt这一超参数的存在本质是描述方差,但却在均值上乘上了 1 − β t \sqrt{1-\beta_t} 1βt ,这能够使得均值在 t t t很大的时候趋向于0,整个分布则为标准正态分布。】

由于以上过程看作一个马尔可夫过程,因此根据乘法公式可以写出以下表达式:
q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) q(x_{1:T}|x_0)=\prod_{t=1}^{T}q(x_t|x_{t-1}) q(x1:Tx0)=t=1Tq(xtxt1)

重参数技巧(reparameterization trick)

以上前向的 q q q过程的表达形式是采样形式的,表示在某个分布中进行随机采样,这回导致过程有随机性,无法反向传播梯度,为了使该采样过程变得可导,使用冲参数技巧,引入一个固定的随机变量 ϵ \epsilon ϵ实现。

例如需要进行以下采样: z ∼ N ( z ; μ θ , σ θ 2 I ) z \sim \mathcal N (z;\mu_\theta,\sigma_{\theta}^{2}\mathbf{I}) zN(z;μθ,σθ2I),可以写成: z = μ θ + σ θ ⊙ ϵ z=\mu_{\theta}+\sigma_{\theta}\odot\epsilon z=μθ+σθϵ,这样的话 z z z依旧是一个随机变量,但对于 μ θ \mu_{\theta} μθ σ θ \sigma_{\theta} σθ等含有网络参数的参数,能够通过 z z z进行梯度的求取,随机性完全来自于固定不变服从标注正态分布的$\epsilon $中了。

那么对于原先的 q q q过程,可以使用重参数技巧重写 x t x_t xt,即:
x t = 1 − β t x t − 1 + β t z , z ∼ N ( 0 , I ) x_t=\sqrt{1-\beta_t}x_{t-1}+\sqrt{\beta_t}z,z \sim \mathcal N (0,\mathbf{I}) xt=1βt xt1+βt zzN(0,I)

使用统一的方法表示 x t x_t xt

由于上述扩散过程是一步步进行的,为了能够快速得到 x t x_t xt,考虑使用 x 0 x_0 x0 β \beta β进行统一表示。

首先假设 α t = 1 − β t \alpha_t=1-\beta_t αt=1βt,并记 α t ‾ = ∏ i = 1 T α i \overline{\alpha_t}=\prod_{i=1}^{T}\alpha_i αt=i=1Tαi

则对 x t x_t xt有以下推导过程:

x t = α t x t − 1 + 1 − α t z 1   = α t ( α t − 1 x t − 2 + 1 − α t − 1 z 2 ) + 1 − α t z 1   = α t α t − 1 x t − 2 + α t ( 1 − α t − 1 ) z 2 + 1 − α t z 1   = α t α t − 1 x t − 2 + 1 − α t α t − 1 z 2 ‾   . . .   = α t ‾ x 0 + 1 − α t ‾   z t ‾ \begin{aligned} %aligned命令对齐,在对齐的地方用"&" x_t &=\sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}z_1 \\\ &=\sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}z_2)+\sqrt{1-\alpha_t}z_1 \\\ &=\sqrt{\alpha_t \alpha_{t-1}}x_{t-2}+\sqrt{\alpha_t(1-\alpha_{t-1})}z_2+\sqrt{1-\alpha_t}z_1 \\\ &=\sqrt{\alpha_t \alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_t \alpha_{t-1}}\overline{z_2} \\\ & ... \\\ &=\sqrt{\overline{\alpha_t}}x_0+\sqrt{1-\overline{\alpha_t}}\ \overline{z_t} \end{aligned} xt     =αt xt1+1αt z1=αt (αt1 xt2+1αt1 z2)+1αt z1=αtαt1 xt2+αt(1αt1) z2+1αt z1=αtαt1 xt2+1αtαt1 z2...=αt x0+1αt  zt

其中对于 α t ( 1 − α t − 1 ) z 2 + 1 − α t z 1 = 1 − α t α t − 1 z 2 ‾ \sqrt{\alpha_t(1-\alpha_{t-1})}z_2+\sqrt{1-\alpha_t}z_1=\sqrt{1-\alpha_t \alpha_{t-1}}\overline{z_2} αt(1αt1) z2+1αt z1=1αtαt1 z2的推导过程,有以下分析:

α t ( 1 − α t − 1 ) z 2 ∼ N ( 0 , α t ( 1 − α t − 1 ) I )   1 − α t z 1 ∼ N ( 0 , ( 1 − α t ) I )   α t ( 1 − α t − 1 ) z 2 + 1 − α t z 1 ∼ N ( 0 , ( 1 − α t α t − 1 ) I )   \sqrt{\alpha_t(1-\alpha_{t-1})}z_2 \sim \mathcal N(0,\alpha_t(1-\alpha_{t-1})\mathbf{I}) \\\ \sqrt{1-\alpha_t}z_1 \sim \mathcal N (0,(1-\alpha_t)\mathbf{I}) \\\ \sqrt{\alpha_t(1-\alpha_{t-1})}z_2+\sqrt{1-\alpha_t}z_1 \sim \mathcal N(0,(1-\alpha_t \alpha_{t-1})\mathbf{I}) \\\ αt(1αt1) z2N(0,αt(1αt1)I) 1αt z1N(0,(1αt)I) αt(1αt1) z2+1αt z1N(0,(1αtαt1)I) 

对于得到的混合高斯分布 N ( 0 , ( 1 − α t α t − 1 ) I ) \mathcal N(0,(1-\alpha_t \alpha_{t-1})\mathbf{I}) N(0,(1αtαt1)I),可以表达为 1 − α t α t − 1 z 2 ‾ \sqrt{1-\alpha_t \alpha_{t-1}}\overline{z_2} 1αtαt1 z2,对于 z 2 ‾ \overline{z_2} z2依然服从的是标准高斯分布。

综上,对于 x t x_t xt的表达式形式为 x t = α t ‾ x 0 + 1 − α t ‾ z x_t=\sqrt{\overline{\alpha_t}}x_0+\sqrt{1-\overline{\alpha_t}}z xt=αt x0+1αt z,采样形式为 q ( x t ∣ x 0 ) ∼ N ( x t ; α t ‾ x 0 , ( 1 − α t ‾ ) I ) q(x_t|x_0) \sim \mathcal N(x_t;\sqrt{\overline{\alpha_t}}x_0,(1-\overline{\alpha_t})\mathbf{I}) q(xtx0)N(xt;αt x0,(1αt)I)。可以看成原始图片与高斯噪声的加权求和结果。

2. Diffusion逆向推断

该过程可以看成前向 q q q过程的逆过程,即去噪过程。目前在这里我们已知所有的前向过程 q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) q(xtxt1),想要从标准正态分布中逐步去噪得到最终的原图分布,即从 x T x_T xT得到 x 0 x_0 x0,那么就需要知道 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt)。但这是十分困难的,因此考虑使用深度神经网络来对该分布进行预测,即训练一个模型使其能够做到 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)。【即让神经网络去学习去噪的过程】

对于 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_{t}) pθ(xt1xt),写成正态分布的表达形式,即为下式,可以看到我们需要使网络根据 x t x_t xt t t t学习得到分布的均值与方差。

p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_{\theta}(x_{t-1}|x_t)=\mathcal N (x_{t-1};\mu_{\theta}(x_t,t),\Sigma_{\theta}(x_t,t)) pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))

虽然 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt)难以直接得到,但是可以引入已知的原图像分布 x 0 x_0 x0的先验知识,尝试得到 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0),可以进行以下过程推导:

q ( x t − 1 ∣ x t , x 0 ) = q ( x t − 1 , x t , x 0 ) q ( x t , x 0 )   = q ( x t − 1 , x t , x 0 ) q ( x t − 1 , x 0 ) q ( x t − 1 , x 0 ) q ( x t , x 0 )   = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 , x 0 ) q x 0 q ( x t , x 0 ) q x 0   = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) \begin{aligned} q(x_{t-1}|x_t,x_0)&=\frac{q(x_{t-1},x_{t},x_{0})}{q(x_t,x_0)} \\\ &=\frac{q(x_{t-1},x_t,x_0)}{q(x_{t-1},x_0)}\frac{q(x_{t-1},x_0)}{q(x_t,x_0)} \\\ &=q(x_t|x_{t-1},x_0)\frac{\frac{q(x_{t-1},x_0)}{q_{x_0}}}{\frac{q(x_t,x_0)}{q_{x_0}}}\\\ &=q(x_t|x_{t-1},x_0)\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)} \end{aligned} q(xt1xt,x0)   =q(xt,x0)q(xt1,xt,x0)=q(xt1,x0)q(xt1,xt,x0)q(xt,x0)q(xt1,x0)=q(xtxt1,x0)qx0q(xt,x0)qx0q(xt1,x0)=q(xtxt1,x0)q(xtx0)q(xt1x0)

由此将加入先验知识的逆向分布转换为正向分布的表达式。

对于几个前向过程进行表达式说明:

q ( x t − 1 ∣ x 0 ) = α t − 1 ‾ x 0 + 1 − α t − 1 ‾ z ∼ N ( α t − 1 ‾ x 0 , 1 − α t − 1 ‾ )   q ( x t ∣ x 0 ) = α t ‾ x 0 + 1 − α t ‾ z ∼ N ( α t ‾ x 0 , 1 − α t ‾ ) q(x_{t-1}|x_0)=\sqrt{\overline{\alpha_{t-1}}}x_0+\sqrt{1-\overline{\alpha_{t-1}}}z \sim \mathcal N (\sqrt{\overline{\alpha_{t-1}}}x_0,1-\overline{\alpha_{t-1}}) \\\ q(x_{t}|x_0)=\sqrt{\overline{\alpha_{t}}}x_0+\sqrt{1-\overline{\alpha_{t}}}z \sim \mathcal N (\sqrt{\overline{\alpha_{t}}}x_0,1-\overline{\alpha_{t}}) q(xt1x0)=αt1 x0+1αt1 zN(αt1 x0,1αt1) q(xtx0)=αt x0+1αt zN(αt x0,1αt)

由于 q q q过程是马尔可夫过程,则有:

q ( x t ∣ x t − 1 , x 0 ) = q ( x t ∣ x t − 1 ) = α t x t − 1 + 1 − α t z ∼ N ( α t x t − 1 , 1 − α t ) q(x_t|x_{t-1},x_0)=q(x_t|x_{t-1})=\sqrt{\alpha_{t}}x_{t-1}+\sqrt{1-\alpha_{t}}z \sim \mathcal N (\sqrt{\alpha_{t}}x_{t-1},1-\alpha_{t}) q(xtxt1,x0)=q(xtxt1)=αt xt1+1αt zN(αt xt1,1αt)

而对于高斯分布,我们能够使用概率密度函数的形式进行表达,即: N ( μ , σ 2 ) ∝ e x p ( − ( x − μ ) 2 2 σ 2 ) \mathcal N(\mu,\sigma^2)\propto exp(-\frac{(x-\mu)^2}{2\sigma^2}) N(μ,σ2)exp(2σ2(xμ)2),则考虑使用概率密度函数描述 q ( x t − 1 ∣ x 0 ) 、 q ( x t ∣ x 0 ) 、 q ( x t ∣ x t − 1 , x 0 ) q(x_{t-1}|x_0)、q(x_t|x_0)、q(x_t|x_{t-1},x_0) q(xt1x0)q(xtx0)q(xtxt1,x0),如下表达式所示:

q ( x t − 1 ∣ x 0 ) ∝ e x p ( − ( x t − 1 − α t − 1 ‾ x 0 ) 2 2 ( 1 − α t − 1 ‾ ) )   q ( x t ∣ x 0 ) ∝ e x p ( − ( x t − α t ‾ x 0 ) 2 2 ( 1 − α t ‾ ) )   q ( x t ∣ x t − 1 , x 0 ) ∝ e x p ( − ( x t − α t x t − 1 ) 2 2 ( 1 − α t ) ) q(x_{t-1}|x_0)\propto exp(-\frac{(x_{t-1}-\sqrt{\overline{\alpha_{t-1}}}x_0)^2}{2(1-\overline{\alpha_{t-1}})}) \\\ q(x_{t}|x_0)\propto exp(-\frac{(x_{t}-\sqrt{\overline{\alpha_{t}}}x_0)^2}{2(1-\overline{\alpha_{t}})}) \\\ q(x_t|x_{t-1},x_0)\propto exp(-\frac{(x_{t}-\sqrt{\alpha_{t}}x_{t-1})^2}{2(1-\alpha_{t})}) q(xt1x0)exp(2(1αt1)(xt1αt1 x0)2) q(xtx0)exp(2(1αt)(xtαt x0)2) q(xtxt1,x0)exp(2(1αt)(xtαt xt1)2)

则最终可以表达 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0),如下推导所示。由于 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)是关于 x t − 1 x_{t-1} xt1的表达式,则需要进行同类相合并

q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 )   ∝ e x p ( − ( x t − 1 − α t − 1 ‾ x 0 ) 2 2 ( 1 − α t − 1 ‾ ) ) + e x p ( − ( x t − α t x t − 1 ) 2 2 ( 1 − α t ) ) − e x p ( − ( x t − α t ‾ x 0 ) 2 2 ( 1 − α t ‾ ) )   = e x p ( − 1 2 ( ( x t − 1 − α t − 1 ‾ x 0 ) 2 1 − α t − 1 ‾ + ( x t − α t x t − 1 ) 2 1 − α t − ( x t − α t ‾ x 0 ) 2 1 − α t ‾ ) )   = e x p ( − 1 2 ( ( α t 1 − α t + 1 1 − α t − 1 ‾ ) x t − 1 2 − ( 2 α t 1 − α t x t + 2 α t − 1 ‾ 1 − α t − 1 ‾ x 0 ) x t − 1 + C ( x t , x 0 ) ) ) \begin{aligned} q(x_{t-1}|x_t,x_0) &=q(x_t|x_{t-1},x_0)\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)} \\\ &\propto exp(-\frac{(x_{t-1}-\sqrt{\overline{\alpha_{t-1}}}x_0)^2}{2(1-\overline{\alpha_{t-1}})}) + exp(-\frac{(x_{t}-\sqrt{\alpha_{t}}x_{t-1})^2}{2(1-\alpha_{t})}) - exp(-\frac{(x_{t}-\sqrt{\overline{\alpha_{t}}}x_0)^2}{2(1-\overline{\alpha_{t}})}) \\\ &=exp(-\frac{1}{2}(\frac{(x_{t-1}-\sqrt{\overline{\alpha_{t-1}}}x_0)^2}{1-\overline{\alpha_{t-1}}}+\frac{(x_{t}-\sqrt{\alpha_{t}}x_{t-1})^2}{1-\alpha_{t}}-\frac{(x_{t}-\sqrt{\overline{\alpha_{t}}}x_0)^2}{1-\overline{\alpha_{t}}})) \\\ &=exp(-\frac{1}{2}((\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\overline{\alpha_{t-1}}})x_{t-1}^2-(\frac{2\sqrt{\alpha_t}}{1-\alpha_t}x_t+\frac{2\sqrt{\overline{\alpha_{t-1}}}}{1-\overline{\alpha_{t-1}}}x_0)x_{t-1}+C(x_t,x_0))) \end{aligned} q(xt1xt,x0)   =q(xtxt1,x0)q(xtx0)q(xt1x0)exp(2(1αt1)(xt1αt1 x0)2)+exp(2(1αt)(xtαt xt1)2)exp(2(1αt)(xtαt x0)2)=exp(21(1αt1(xt1αt1 x0)2+1αt(xtαt xt1)21αt(xtαt x0)2))=exp(21((1αtαt+1αt11)xt12(1αt2αt xt+1αt12αt1 x0)xt1+C(xt,x0)))

其中 C ( x t , x 0 ) C(x_t,x_0) C(xt,x0)不含 x t − 1 x_{t-1} xt1

考虑到对于概率密度的表达式,可以进行展开操作,即 N ( μ , σ 2 ) ∝ e x p ( − ( x − μ ) 2 2 σ 2 ) = e x p ( − 1 2 ( 1 σ 2 x 2 − 2 μ σ 2 x + μ 2 σ 2 ) ) \mathcal N(\mu,\sigma^2)\propto exp(-\frac{(x-\mu)^2}{2\sigma^2})=exp(-\frac{1}{2}(\frac{1}{\sigma^2}x^2-\frac{2\mu}{\sigma^2}x+\frac{\mu^2}{\sigma^2})) N(μ,σ2)exp(2σ2(xμ)2)=exp(21(σ21x2σ22μx+σ2μ2)),与 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)的表达式一一对应可知。

1 σ 2 = α t 1 − α t + 1 1 − α t − 1 ‾   σ 2 = ( 1 − α t ) ( 1 − α t − 1 ‾ ) α t ( 1 − α t − 1 ‾ ) + ( 1 − α t ) = β t ( 1 − α t − 1 ‾ ) α t ( 1 − α t − 1 ‾ ) + β t   = 1 − α t − 1 ‾ 1 − α t ‾ β t \frac{1}{\sigma^2}=\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\overline{\alpha_{t-1}}} \\\ \sigma^2 =\frac{(1-\alpha_t)(1-\overline{\alpha_{t-1}})}{\alpha_t(1-\overline{\alpha_{t-1}})+(1-\alpha_t)}=\frac{\beta_t(1-\overline{\alpha_{t-1}})}{\alpha_t(1-\overline{\alpha_{t-1}})+\beta_t} \\\ =\frac{1-\overline{\alpha_{t-1}}}{1-\overline{\alpha_t}}\beta_t σ21=1αtαt+1αt11 σ2=αt(1αt1)+(1αt)(1αt)(1αt1)=αt(1αt1)+βtβt(1αt1) =1αt1αt1βt

对于均值,有求得表达式:

2 μ σ 2 = 2 α t 1 − α t x t + 2 α t − 1 ‾ − α t − 1 ‾ x 0   μ = α t ( 1 − α t − 1 ‾ ) 1 − α t ‾ x t + α t − 1 ‾ β t 1 − α t ‾ x 0 \frac{2\mu}{\sigma^2}=\frac{2\sqrt{\alpha_t}}{1-\alpha_t}x_t+\frac{2\sqrt{\overline{\alpha_{t-1}}}}{-\overline{\alpha_{t-1}}}x_0 \\\ \mu=\frac{\sqrt{\alpha_t}(1-\overline{\alpha_{t-1}})}{1-\overline{\alpha_{t}}}x_t+\frac{\sqrt{\overline{\alpha_{t-1}}}\beta_t}{1-\overline{\alpha_t}}x_0 σ22μ=1αt2αt xt+αt12αt1 x0 μ=1αtαt (1αt1)xt+1αtαt1 βtx0

至此,我们在加入 x 0 x_0 x0的先验知识后,能够描述出分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)的均值和方差的表达式,即

μ = α t ( 1 − α t − 1 ‾ ) 1 − α t ‾ x t + α t − 1 ‾ β t 1 − α t ‾ x 0   σ 2 = 1 − α t − 1 ‾ 1 − α t ‾ β t \mu=\frac{\sqrt{\alpha_t}(1-\overline{\alpha_{t-1}})}{1-\overline{\alpha_{t}}}x_t+\frac{\sqrt{\overline{\alpha_{t-1}}}\beta_t}{1-\overline{\alpha_t}}x_0 \\\ \sigma^2=\frac{1-\overline{\alpha_{t-1}}}{1-\overline{\alpha_t}}\beta_t μ=1αtαt (1αt1)xt+1αtαt1 βtx0 σ2=1αt1αt1βt

但实际上我们需要通过神经网络来训练出一个分布 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt),该分布的均值和方差均是含参数的,即上述的 μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t) Σ θ ( x t , t ) \Sigma_{\theta}(x_t,t) Σθ(xt,t)。我们考虑使用 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t,x_0) q(xt1xt,x0)来近似估计 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt),这显然是可行的,因为 x t x_t xt可以由 x 0 x_0 x0来表示,同理也可以使用 x t x_t xt来表示 x 0 x_0 x0。综上可得以下关于均值和方差的表达式:

μ θ ( x t , t ) = α t ( 1 − α t − 1 ‾ ) 1 − α t ‾ x t + α t − 1 ‾ β t 1 − α t ‾ x 0   Σ θ ( x t , t ) = 1 − α t − 1 ‾ 1 − α t ‾ β t \mu_{\theta}(x_t,t)=\frac{\sqrt{\alpha_t}(1-\overline{\alpha_{t-1}})}{1-\overline{\alpha_{t}}}x_t+\frac{\sqrt{\overline{\alpha_{t-1}}}\beta_t}{1-\overline{\alpha_t}}x_0 \\\ \Sigma_{\theta}(x_t,t)=\frac{1-\overline{\alpha_{t-1}}}{1-\overline{\alpha_t}}\beta_t μθ(xt,t)=1αtαt (1αt1)xt+1αtαt1 βtx0 Σθ(xt,t)=1αt1αt1βt

对于其中出现的先验知识 x 0 x_0 x0,考虑使用 x t x_t xt来表达,即有:
x 0 = x t − 1 − α t ‾ z α t ‾ x_0=\frac{x_t-\sqrt{1-\overline{\alpha_t}}z}{\sqrt{\overline{\alpha_t}}} x0=αt xt1αt z
代入均值 μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t)的表达式,有以下推导过程:

μ θ ( x t , t ) = α t ( 1 − α t − 1 ‾ ) 1 − α t ‾ x t + α t − 1 ‾ β t 1 − α t ‾ x 0   = α t ( 1 − α t − 1 ‾ ) 1 − α t ‾ x t + α t − 1 ‾ β t 1 − α t ‾ x t − 1 − α t ‾ z α t ‾   = 1 α t ( x t − 1 − α t 1 − α t ‾ z ) \begin{aligned} \mu_{\theta}(x_t,t)&=\frac{\sqrt{\alpha_t}(1-\overline{\alpha_{t-1}})}{1-\overline{\alpha_{t}}}x_t+\frac{\sqrt{\overline{\alpha_{t-1}}}\beta_t}{1-\overline{\alpha_t}}x_0 \\\ &=\frac{\sqrt{\alpha_t}(1-\overline{\alpha_{t-1}})}{1-\overline{\alpha_{t}}}x_t+\frac{\sqrt{\overline{\alpha_{t-1}}}\beta_t}{1-\overline{\alpha_t}}\frac{x_t-\sqrt{1-\overline{\alpha_t}}z}{\sqrt{\overline{\alpha_t}}} \\\ &=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{1-\alpha_t}{\sqrt{1-\overline{\alpha_t}}}z) \end{aligned} μθ(xt,t)  =1αtαt (1αt1)xt+1αtαt1 βtx0=1αtαt (1αt1)xt+1αtαt1 βtαt xt1αt z=αt 1(xt1αt 1αtz)

观察表达式可以发现, α t \alpha_t αt x t x_t xt α t ‾ \overline{\alpha_t} αt均为已知量,模型需要确定 μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t),本质上是个噪声预测的过程,即对于上式中的 z z z,需要交给神经网络去预测,记为 z θ ( x t , t ) z_{\theta}(x_t,t) zθ(xt,t)

综上,反向去噪过程可以概括为:

  1. 根据 x t x_t xt t t t预测高斯噪声 z θ ( x t , t ) z_{\theta}(x_t,t) zθ(xt,t),则能够得到模型预测的去噪后的分布均值 μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t),其实就是去噪后的图像。
  2. 得到方差 Σ θ ( x t , t ) \Sigma_{\theta}(x_t,t) Σθ(xt,t),在DDPM中方差是untrained的,即 Σ θ ( x t , t ) = 1 − α t − 1 ‾ 1 − α t ‾ β t \Sigma_{\theta}(x_t,t)=\frac{1-\overline{\alpha_{t-1}}}{1-\overline{\alpha_t}}\beta_t Σθ(xt,t)=1αt1αt1βt,但对于方差也可以进行训练估计。
  3. p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_{\theta}(x_{t-1}|x_t)=\mathcal N (x_{t-1};\mu_{\theta}(x_t,t),\Sigma_{\theta}(x_t,t)) pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))进行去噪,获得 x t − 1 x_{t-1} xt1,通过重参数技巧。

3. 训练方式与策略

常规的训练流程
  1. 从数据集中采样得到 x 0 x_0 x0 x 0 ∼ q ( x ) x_0 \sim q(x) x0q(x),并在 1... T 1...T 1...T中随机采样一个 t t t
  2. 从标准高斯分布中采样一个噪声 z ∼ N ( 0 , I ) z\sim \mathcal N (0,\mathbf{I}) zN(0,I)
  3. 根据重参数技巧得到 x t = α t ‾ x 0 + 1 − α t ‾ z x_t=\sqrt{\overline{\alpha_t}}x_0+\sqrt{1-\overline{\alpha_t}}z xt=αt x0+1αt z
  4. 训练UNet,输入 x t x_t xt t t t,模型输出 z θ ( x t , t ) z_{\theta}(x_t,t) zθ(xt,t),将 z θ ( x t , t ) z_{\theta}(x_t,t) zθ(xt,t) z z z做Loss。
推理流程
  1. 从标准高斯分布中采样得到一个噪声 x t ∼ N ( 0 , I ) x_t \sim \mathcal N(0,\mathbf{I}) xtN(0,I)
  2. T T T到1遍历变量 t t t,如果 t = = 1 t==1 t==1 z = 0 z=0 z=0,否则采样噪声 z ∼ N ( 0 , I ) z\sim\mathcal N(0,\mathbf{I}) zN(0,I)
  3. UNet推理得到 z θ ( x t , t ) z_\theta(x_t,t) zθ(xt,t),进行去噪 x t − 1 = 1 α t ‾ ( x t − 1 − α t 1 − α t ‾ z θ ( x t , t ) ) + Σ θ ( x t , t ) z x_{t-1}=\frac{1}{\sqrt{\overline{\alpha_t}}}(x_t-\frac{1-\alpha_t}{\sqrt{1-\overline{\alpha_t}}}z_{\theta}(x_t,t))+\Sigma_{\theta}(x_t,t)z xt1=αt 1(xt1αt 1αtzθ(xt,t))+Σθ(xt,t)z
  4. 得到 x 0 x_0 x0

4. Diffusion的应用拓展

Diffusion做分割(SegDiff)

在这里插入图片描述

要点:将待分割图像作为condition image进行特征的提取,在进入UNet之前进行特征的融合,扩散生成的为mask。

训练 μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t)时,需要加入图像的特征,即 μ θ ( x t , t , I ) \mu_{\theta}(x_t,t,I) μθ(xt,t,I),推理时同理。

Diffusion做分割(SegDiff)

[外链图片转存中…(img-fszHBj2E-1715260213416)]

要点:将待分割图像作为condition image进行特征的提取,在进入UNet之前进行特征的融合,扩散生成的为mask。

训练 μ θ ( x t , t ) \mu_{\theta}(x_t,t) μθ(xt,t)时,需要加入图像的特征,即 μ θ ( x t , t , I ) \mu_{\theta}(x_t,t,I) μθ(xt,t,I),推理时同理。

  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值