DiffusionModel-DDPM推导+代码理解

参考了这位up的视频所记录的笔记!推荐大家都去看原视频!
大白话AI
https://github.com/wangjia184/diffusion_model
论文:https://arxiv.org/abs/2006.11239

为了加深自己对扩散模型的理解,写这篇博客

散模型通diffusion model常指的是一种特别的生成模型,这种模型模拟物理中的扩散过程来生成数据。它们通过模拟数据由有序状态向无序状态的转变(正向过程),然后再逆向模拟这一过程(反向过程),来学习如何生成类似于真实数据的新样本。在这一过程中,模型会逐渐学会如何构造和重建数据中的复杂结构,这允许模型捕获和表达高维数据分布。这类模型在图像生成、自然语言处理等领域具有广泛的应用。

1.DDPM概述

在这里插入图片描述

  • q q q : 一个固定的(或预定义的)正向扩散过程,将高斯噪声逐渐添加到图像中,直到最终变成纯噪声
  • p θ p_θ pθ :一个学习的反向去噪扩散过程,其中神经网络被训练从纯噪声开始逐渐去噪图像,直到最终得到实际图像。

β × ϵ + 1 − β × x \sqrt{\beta}\times\epsilon+\sqrt{1-\beta}\times x β ×ϵ+1β ×x
x t + 1 = β × ϵ + 1 − β × x t x_{t+1}=\sqrt{\beta}\times\epsilon+\sqrt{1-\beta}\times x_t xt+1=β ×ϵ+1β ×xt

模型的正向和反向过程都由t索引,并持续有限的时间步长 T (DDPM 作者使用 T=1000)。在 t=0 时,你从你的数据分布中采样一个真实图像 x 0 x_0 x0 ,然后在每个时间步长 t,正向过程都会从高斯分布中采样一些噪声,并将其添加到前一时间步长的图像中。给定足够大的 T 和一个在每个时间步长添加噪声的合理计划,你最终将在 t=T 时通过一个渐进的过程得到一个各向同性高斯分布

2. 前向过程Forward Process q q q

扩散模型的加噪过程本质是概率密度的扩散过程,类似于一滴墨水滴到杯子里的水

x 0 → q ( x 1 ∣ x 0 ) x 1 → q ( x 2 ∣ x 1 ) x 2 → ⋯ → x T − 1 → q ( x t ∣ x t − 1 ) x T x_0 \overset{q(x_1 | x_0)}{\rightarrow} x_1 \overset{q(x_2 | x_1)}{\rightarrow} x_2 \rightarrow \dots \rightarrow x_{T-1} \overset{q(x_{t} | x_{t-1})}{\rightarrow} x_T x0q(x1x0)x1q(x2x1)x2xT1q(xtxt1)xT

这一过程是一个马尔可夫链, x t x_t xt 只依赖于 x t − 1 x_{t-1} xt1 q ( x t ∣ x t − 1 ) q(x_{t} | x_{t-1}) q(xtxt1) 根据已知的方差计划 β t β_{t} βt 在每个时间步长 t t t 添加高斯噪声。

x t = 1 − β t × x t − 1 + β t × ϵ t \boxed{x_t = \sqrt{1-β_t}\times x_{t-1} + \sqrt{β_t}\times ϵ_{t}} xt=1βt ×xt1+βt ×ϵt

  • ϵ 的系数 β \epsilon的系数\sqrt{\beta} ϵ的系数β 和x的系数 1 − β \sqrt{1-\beta} 1β 平方和始终为1,可以看成是直角三角形的两条边,可以直观感受到噪声与原图在混合中所占比例的此消彼长。随着β增加x越来越小

  • ϵ \epsilon ϵ:高斯噪声,从高斯分布中取样,高斯分布的概率密度函数是: f ( x ) = 1 σ 2 π e − ( x − μ ) 2 2 σ 2 f(x) = \frac{1}{\sigma\sqrt{2\pi}} e^{-\frac{(x-\mu)^2}{2\sigma^2}} f(x)=σ2π 1e2σ2(xμ)2

  • β t β_t βt 在每个时间步长 t t t 并不是常数。事实上,它定义了一个被称为“variance schedule”的东西,它可以是线性的、二次的、余弦的等等。

0 < β 1 < β 2 < β 3 < ⋯ < β T < 1 0 < β_1 < β_2 < β_3 < \dots < β_T < 1 0<β1<β2<β3<<βT<1

  • a t = 1 − β t a_t = 1 - β_t at=1βt则用=有:

x t = a t × x t − 1 + 1 − a t × ϵ t x_t = \sqrt{a_{t}}\times x_{t-1} + \sqrt{1-a_t} \times ϵ_{t} xt=at ×xt1+1at ×ϵt

2.1 Relationship between x t x_t xt and x t − 2 x_{t-2} xt2

x t − 1 = a t − 1 × x t − 2 + 1 − a t − 1 × ϵ t − 1 x_{t-1} = \sqrt{a_{t-1}}\times x_{t-2} + \sqrt{1-a_{t-1}} \times ϵ_{t-1} xt1=at1 ×xt2+1at1 ×ϵt1

⇓ 将 x t − 1 带入 x t \Downarrow\\将x_{t-1}带入x_t xt1带入xt

x t = a t ( a t − 1 × x t − 2 + 1 − a t − 1 ϵ t − 1 ) + 1 − a t × ϵ t x_t = \sqrt{a_{t}} (\sqrt{a_{t-1}}\times x_{t-2} + \sqrt{1-a_{t-1}} ϵ_{t-1}) + \sqrt{1-a_t} \times ϵ_t xt=at (at1 ×xt2+1at1 ϵt1)+1at ×ϵt

⇓ \Downarrow

x t = a t a t − 1 × x t − 2 + a t ( 1 − a t − 1 ) ϵ t − 1 + 1 − a t × ϵ t x_t = \sqrt{a_{t}a_{t-1}}\times x_{t-2} + \sqrt{a_{t}(1-a_{t-1})} ϵ_{t-1} + \sqrt{1-a_t} \times ϵ_t xt=atat1 ×xt2+at(1at1) ϵt1+1at ×ϵt

两个独立的高斯分布的和仍然是一个高斯分布。具体来说,如果 X ∼ N ( μ 1 , σ 1 2 ) X \sim N(\mu_{1},\sigma_{1}^{2}) XN(μ1,σ12) Y ∼ N ( μ 2 , σ 2 2 ) Y \sim N(\mu_{2},\sigma_{2}^{2}) YN(μ2,σ22) 是两个独立的高斯分布,那么它们的和 Z = X + Y Z = X + Y Z=X+Y 也服从高斯分布,其均值和方差分别为 μ 1 + μ 2 \mu_{1}+\mu_{2} μ1+μ2 σ 1 2 + σ 2 2 \sigma_{1}^{2} + \sigma_{2}^{2} σ12+σ22 ϵ t , ϵ t − 1 ϵ_t,ϵ_{t-1} ϵt,ϵt1都是独立的从高斯分布中采集的样本。

x t = a t a t − 1 × x t − 2 + a t ( 1 − a t − 1 ) + 1 − a t × ϵ x_t = \sqrt{a_{t}a_{t-1}}\times x_{t-2} + \sqrt{a_{t}(1-a_{t-1}) + 1-a_t} \times ϵ xt=atat1 ×xt2+at(1at1)+1at ×ϵ

⇓ \Downarrow

x t = a t a t − 1 × x t − 2 + 1 − a t a t − 1 × ϵ x_t = \sqrt{a_{t}a_{t-1}}\times x_{t-2} + \sqrt{1-a_{t}a_{t-1}} \times ϵ xt=atat1 ×xt2+1atat1 ×ϵ

2.3 Relationship between x t x_t xt and x 0 x_0 x0

同理,如此循环往复…

  • x t = a t a t − 1 × x t − 2 + 1 − a t a t − 1 × ϵ x_t = \sqrt{a_{t}a_{t-1}}\times x_{t-2} + \sqrt{1-a_{t}a_{t-1}}\times ϵ xt=atat1 ×xt2+1atat1 ×ϵ
  • x t = a t a t − 1 a t − 2 × x t − 3 + 1 − a t a t − 1 a t − 2 × ϵ x_t = \sqrt{a_{t}a_{t-1}a_{t-2}}\times x_{t-3} + \sqrt{1-a_{t}a_{t-1}a_{t-2}}\times ϵ xt=atat1at2 ×xt3+1atat1at2 ×ϵ
  • x t = a t a t − 1 a t − 2 a t − 3 . . . a t − ( k − 2 ) a t − ( k − 1 ) × x t − k + 1 − a t a t − 1 a t − 2 a t − 3 . . . a t − ( k − 2 ) a t − ( k − 1 ) × ϵ x_t = \sqrt{a_{t}a_{t-1}a_{t-2}a_{t-3}...a_{t-(k-2)}a_{t-(k-1)}}\times x_{t-k} + \sqrt{1-a_{t}a_{t-1}a_{t-2}a_{t-3}...a_{t-(k-2)}a_{t-(k-1)}}\times ϵ xt=atat1at2at3...at(k2)at(k1) ×xtk+1atat1at2at3...at(k2)at(k1) ×ϵ
  • x t = a t a t − 1 a t − 2 a t − 3 . . . a 2 a 1 × x 0 + 1 − a t a t − 1 a t − 2 a t − 3 . . . a 2 a 1 × ϵ x_t = \sqrt{a_{t}a_{t-1}a_{t-2}a_{t-3}...a_{2}a_{1}}\times x_{0} + \sqrt{1-a_{t}a_{t-1}a_{t-2}a_{t-3}...a_{2}a_{1}}\times ϵ xt=atat1at2at3...a2a1 ×x0+1atat1at2at3...a2a1 ×ϵ

a ˉ t : = a t a t − 1 a t − 2 a t − 3 . . . a 2 a 1 \bar{a}_{t} := a_{t}a_{t-1}a_{t-2}a_{t-3}...a_{2}a_{1} aˉt:=atat1at2at3...a2a1,可以得到最终公式:
x t = a ˉ t × x 0 + 1 − a ˉ t × ϵ , ϵ ∼ N ( 0 , I ) \boxed{x_{t} = \sqrt{\bar{a}_t}\times x_0+ \sqrt{1-\bar{a}_t}\times ϵ , ϵ \sim N(0,I)} xt=aˉt ×x0+1aˉt ×ϵ,ϵN(0,I)

⇓ \Downarrow

q ( x t ∣ x 0 ) = 1 2 π 1 − a ˉ t e ( − 1 2 ( x t − a ˉ t x 0 ) 2 1 − a ˉ t ) q(x_{t}|x_{0}) = \frac{1}{\sqrt{2\pi } \sqrt{1-\bar{a}_{t}}} e^{\left ( -\frac{1}{2}\frac{(x_{t}-\sqrt{\bar{a}_{t}}x_0)^2}{1-\bar{a}_{t}} \right ) } q(xtx0)=2π 1aˉt 1e(211aˉt(xtaˉt x0)2)

3.反向传播Reverse Process p p p

问题就是如何从t时刻的图像变成零时刻的图像,因为DDPM中反向传播依赖马尔科夫链,需要已知 x t x_t xt,才可以得到 x t − 1 x_{t-1} xt1

在这里插入图片描述

x t − 1 x_{t-1} xt1 x t x_t xt的推导

由贝叶斯公式 P ( A ∣ B ) = P ( B ∣ A ) P ( A ) P ( B ) P(A|B) = \frac{ P(B|A)P(A) }{ P(B) } P(AB)=P(B)P(BA)P(A),P(B|A)为先验概率,P(A|B)为后验概率
可以知道由后面的 x t x_t xt反向推出前一个时刻的 x t − 1 x_{t-1} xt1的公式( x 0 x_0 x0为条件概率):

p ( x t − 1 ∣ x t , x 0 ) = p ( x t ∣ x t − 1 , x 0 ) × p ( x t − 1 ∣ x 0 ) p ( x t ∣ x 0 ) \boxed{p(x_{t-1}|x_{t},x_{0}) = \frac{ p(x_{t}|x_{t-1},x_{0})\times p(x_{t-1}|x_0)}{p(x_{t}|x_0)} } p(xt1xt,x0)=p(xtx0)p(xtxt1,x0)×p(xt1x0)

  • p ( x t ∣ x t − 1 , x 0 ) p(x_t | x_{t-1}, x_0) p(xtxt1,x0) 是在给定 x t − 1 x_{t-1} xt1 x 0 x_0 x0 的情况下, x t x_t xt 的条件概率密度函数。
  • p ( x t − 1 ∣ x 0 ) p(x_{t-1} | x_0) p(xt1x0) 是在给定 x 0 x_0 x0 的情况下, x t − 1 x_{t-1} xt1 的条件概率密度函数。
  • p ( x t ∣ x 0 ) p(x_t | x_0) p(xtx0) 是在给定 x 0 x_0 x0 的情况下, x t x_t xt 的边缘概率密度函数。
变量公式对应分布
x t x_t xt a t x t − 1 + 1 − a t × ϵ \sqrt{a_t}x_{t-1}+\sqrt{1-a_t}\times\epsilon at xt1+1at ×ϵ N ( a t x t − 1 , 1 − a t ) N(\sqrt{a_t}x_{t-1},1-a_t) N(at xt1,1at)
x t − 1 x_{t-1} xt1 a ˉ t − 1 x 0 + 1 − a ˉ t − 1 × ϵ \sqrt{\bar{a}_{t-1}}x_0+\sqrt{1-\bar{a}_{t-1}}\times\epsilon aˉt1 x0+1aˉt1 ×ϵ N ( a ˉ t − 1 x 0 , 1 − a ˉ t − 1 ) N(\sqrt{\bar{a}_{t-1}}x_0,1-\bar{a}_{t-1}) N(aˉt1 x0,1aˉt1)
x t x_t xt a ˉ t x 0 + 1 − a ˉ t × ϵ \sqrt{\bar{a}_t}x_0+\sqrt{1-\bar{a}_t}\times\epsilon aˉt x0+1aˉt ×ϵ N ( a ˉ t x 0 , 1 − a ˉ t ) N(\sqrt{\bar{a}_t}x_0,1-\bar{a}_t) N(aˉt x0,1aˉt)

带入可得:

p ( x t ∣ x t − 1 , x 0 ) × p ( x t − 1 ∣ x 0 ) p ( x t ∣ x 0 ) = [ 1 2 π 1 − a t e ( − 1 2 ( x t − a t x t − 1 ) 2 1 − a t ) ] ∗ [ 1 2 π 1 − a ˉ t − 1 e ( − 1 2 ( x t − 1 − a ˉ t − 1 x 0 ) 2 1 − a ˉ t − 1 ) ] ÷ [ 1 2 π 1 − a ˉ t e ( − 1 2 ( x t − a ˉ t x 0 ) 2 1 − a ˉ t ) ] \frac{ p(x_{t}|x_{t-1},x_{0})\times p(x_{t-1}|x_0)}{p(x_{t}|x_0)} = \left [ \frac{1}{\sqrt{2\pi} \sqrt{1-a_{t}}} e^{\left ( -\frac{1}{2}\frac{(x_{t}-\sqrt{a_t}x_{t-1})^2}{1-a_{t}} \right ) } \right ] * \left [ \frac{1}{\sqrt{2\pi} \sqrt{1-\bar{a}_{t-1}}} e^{\left ( -\frac{1}{2}\frac{(x_{t-1}-\sqrt{\bar{a}_{t-1}}x_0)^2}{1-\bar{a}_{t-1}} \right ) } \right ] \div \left [ \frac{1}{\sqrt{2\pi} \sqrt{1-\bar{a}_{t}}} e^{\left ( -\frac{1}{2}\frac{(x_{t}-\sqrt{\bar{a}_{t}}x_0)^2}{1-\bar{a}_{t}} \right ) } \right ] p(xtx0)p(xtxt1,x0)×p(xt1x0)=[2π 1at 1e(211at(xtat xt1)2)][2π 1aˉt1 1e(211aˉt1(xt1aˉt1 x0)2)]÷[2π 1aˉt 1e(211aˉt(xtaˉt x0)2)]

⇓ 带入高斯分布的概率密度函数 \Downarrow \\带入高斯分布的概率密度函数 带入高斯分布的概率密度函数

1 2 π ( 1 − a t 1 − a ˉ t − 1 1 − a ˉ t ) e x p [ − 1 2 ( x t − 1 − ( a t ( 1 − a ˉ t − 1 ) 1 − a ˉ t x t + a ˉ t − 1 ( 1 − a t ) 1 − a ˉ t x 0 ) ) 2 ( 1 − a t 1 − a ˉ t − 1 1 − a ˉ t ) 2 ] \frac{1}{\sqrt{2\pi} \left ( {\color{Red} \frac{ \sqrt{1-a_t} \sqrt{1-\bar{a}_{t-1}} } {\sqrt{1-\bar{a}_{t}}}} \right ) } exp \left[ -\frac{1}{2} \frac{ \left( x_{t-1} - \left( {\color{Purple} \frac{\sqrt{a_t}(1-\bar{a}_{t-1})}{1-\bar{a}_t}x_t + \frac{\sqrt{\bar{a}_{t-1}}(1-a_t)}{1-\bar{a}_t}x_0} \right) \right) ^2 } { \left( {\color{Red} \frac{ \sqrt{1-a_t} \sqrt{1-\bar{a}_{t-1}} } {\sqrt{1-\bar{a}_{t}}}} \right)^2 } \right] 2π (1aˉt 1at 1aˉt1 )1exp 21(1aˉt 1at 1aˉt1 )2(xt1(1aˉtat (1aˉt1)xt+1aˉtaˉt1 (1at)x0))2

⇓ 可以推导出 p ( x t − 1 ∣ x t ) 满足的分布 \Downarrow \\可以推导出p(x_{t-1}|x_{t}) 满足的分布 可以推导出p(xt1xt)满足的分布

p ( x t − 1 ∣ x t , x 0 ) ∼ N ( a t ( 1 − a ˉ t − 1 ) 1 − a ˉ t x t + a ˉ t − 1 ( 1 − a t ) 1 − a ˉ t x 0 , ( 1 − a t 1 − a ˉ t − 1 1 − a ˉ t ) 2 ) p(x_{t-1}|x_{t},x_0) \sim N\left( {\color{Purple} \frac{\sqrt{a_t}(1-\bar{a}_{t-1})}{1-\bar{a}_t}x_t + \frac{\sqrt{\bar{a}_{t-1}}(1-a_t)}{1-\bar{a}_t}x_0} , \left( {\color{Red} \frac{ \sqrt{1-a_t} \sqrt{1-\bar{a}_{t-1}} } {\sqrt{1-\bar{a}_{t}}}} \right)^2 \right) p(xt1xt,x0)N(1aˉtat (1aˉt1)xt+1aˉtaˉt1 (1at)x0,(1aˉt 1at 1aˉt1 )2)

x t = a ˉ t × x 0 + 1 − a ˉ t × ϵ x_{t} = \sqrt{\bar{a}_t}\times x_0+ \sqrt{1-\bar{a}_t}\times ϵ xt=aˉt ×x0+1aˉt ×ϵ可以反解出 x 0 = x t − 1 − a ˉ t × ϵ a ˉ t x_0 = \frac{x_t - \sqrt{1-\bar{a}_t}\times ϵ}{\sqrt{\bar{a}_t}} x0=aˉt xt1aˉt ×ϵ.
⇓ 将 x 0 带入上述公式可以得到 \Downarrow \\将x_0带入上述公式可以得到 x0带入上述公式可以得到

p ( x t − 1 ∣ x t , x 0 ) ∼ N ( a t ( 1 − a ˉ t − 1 ) 1 − a ˉ t x t + a ˉ t − 1 ( 1 − a t ) 1 − a ˉ t × x t − 1 − a ˉ t × ϵ a ˉ t , β t ( 1 − a ˉ t − 1 ) 1 − a ˉ t ) p(x_{t-1}|x_{t},x_0) \sim N\left( {\color{Purple} \frac{\sqrt{a_t}(1-\bar{a}_{t-1})}{1-\bar{a}_t}x_t + \frac{\sqrt{\bar{a}_{t-1}}(1-a_t)}{1-\bar{a}_t}\times \frac{x_t - \sqrt{1-\bar{a}_t}\times ϵ}{\sqrt{\bar{a}_t}} } , {\color{Red} \frac{ \beta_{t} (1-\bar{a}_{t-1}) } { 1-\bar{a}_{t}}} \right) p(xt1xt,x0)N(1aˉtat (1aˉt1)xt+1aˉtaˉt1 (1at)×aˉt xt1aˉt ×ϵ,1aˉtβt(1aˉt1))

其中,可以通过数学变换将上述 p ( x t − 1 ∣ x t , x 0 ) p(x_{t-1}|x_{t},x_0) p(xt1xt,x0)c变成到更简单的形式:
μ = a t ( 1 − a ˉ t − 1 ) 1 − a ˉ t x t + a ˉ t − 1 ( 1 − a t ) 1 − a ˉ t × x t − 1 − a ˉ t × ϵ a ˉ t = 1 α t ( α t − α t ˉ 1 − α t ˉ x t + α t ˉ ( 1 − α t ) 1 − α t ˉ ∗ x t − 1 − α t ˉ ϵ α t ˉ ) = 1 α t [ α t − α t ˉ 1 − α t ˉ x t + 1 − α t 1 − α t ˉ ∗ ( x t − 1 − α t ˉ ϵ ) ] = 1 α t ( x t − 1 − α t 1 − α t ˉ 1 − α t ˉ ϵ ) = 1 α t ( x t − 1 − α t 1 − α t ˉ ϵ ) \begin{aligned} &\mu=\frac{\sqrt{a_t}\left(1-\bar{a}_{t-1}\right)}{1-\bar{a}_t}x_t+\frac{\sqrt{\bar{a}_{t-1}}\left(1-a_t\right.)}{1-\bar{a}_t}\times\frac{x_t-\sqrt{1-\bar{a}_t}\times\epsilon}{\sqrt{\bar{a}_t}} \\ &=\frac{1}{\sqrt{\alpha_{t}}}\big(\frac{\alpha_{t}-\bar{\alpha_{t}}}{1-\bar{\alpha_{t}}}x_{t}+\frac{\sqrt{\bar{\alpha_{t}}}(1-\alpha_{t})}{1-\bar{\alpha_{t}}}*\frac{x_{t}-\sqrt{1-\bar{\alpha_{t}}}\epsilon}{\sqrt{\bar{\alpha_{t}}}}\big) \\ &=\frac{1}{\sqrt{\alpha_{t}}}[\frac{\alpha_{t}-\bar{\alpha_{t}}}{1-\bar{\alpha_{t}}}x_{t}+\frac{1-\alpha_{t}}{1-\bar{\alpha_{t}}}*(x_{t}-\sqrt{1-\bar{\alpha_{t}}}\epsilon)]=\frac{1}{\sqrt{\alpha_{t}}}(x_{t}-\frac{1-\alpha_{t}}{1-\bar{\alpha_{t}}}\sqrt{1-\bar{\alpha_{t}}}\epsilon) \\ &=\frac{1}{\sqrt{\alpha_{t}}}(x_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha_{t}}}}\epsilon) \end{aligned} μ=1aˉtat (1aˉt1)xt+1aˉtaˉt1 (1at)×aˉt xt1aˉt ×ϵ=αt 1(1αtˉαtαtˉxt+1αtˉαtˉ (1αt)αtˉ xt1αtˉ ϵ)=αt 1[1αtˉαtαtˉxt+1αtˉ1αt(xt1αtˉ ϵ)]=αt 1(xt1αtˉ1αt1αtˉ ϵ)=αt 1(xt1αtˉ 1αtϵ)
故最终:

μ t = 1 α t ( x t − 1 − α t 1 − α t ˉ ϵ ) \boxed{\mu_t=\frac{1}{\sqrt{\alpha_{t}}}(x_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha_{t}}}}\epsilon)} μt=αt 1(xt1αtˉ 1αtϵ)
⇓ 可以推导出 x ^ t − 1 \Downarrow \\可以推导出\hat x_{t-1} 可以推导出x^t1

x ^ t − 1 = μ t − 1 − a t − 1 ( 1 − a t ) 1 − a t × ϵ ′ a t \hat{x}_{t-1}=\mu_{t-1}-\frac{\sqrt{a_{t-1}}(1-a_t)}{1-a_t}\times\frac{\epsilon^{\prime}}{\sqrt{a_t}} x^t1=μt11atat1 (1at)×at ϵ
P ( x t − 1 ∣ x t , x 0 ) ∼ N ( 1 a t ( x t − 1 − a t 1 − a ˉ t ϵ ) , ( β t ( 1 − a ˉ t − 1 ) 1 − a ˉ t ) 2 ) \boxed{P(x_{t-1}\left|x_t,x_0\right.)\sim N(\frac{1}{\sqrt{a_t}}\left(x_t-\frac{1-a_t}{\sqrt{1-\bar{a}_t}}\textcolor{red}{\epsilon}\right),(\sqrt{\frac{\beta_t\left(1-\bar{a}_{t-1}\right)}{1-\bar{a}_t}})^2)} P(xt1xt,x0)N(at 1(xt1aˉt 1atϵ),(1aˉtβt(1aˉt1) )2)

  • 由之前前向过程的推导可以知道任意 x t x_t xt的图像都可以由 x 0 x_0 x0加载而来;
  • 只要知道了从 x 0 x_0 x0 x t x_t xt加入的噪声 ϵ \epsilon ϵ(从高斯分布中采样后加入),就能得到它前一时刻 x t − 1 x_{t-1} xt1的概率分布,即: P ( x t − 1 ∣ x t , x 0 ) P(x_t-1|x_t,x_0) P(xt1∣xt,x0)。​
  • 这里只有噪声ϵ是未知的了,就需要使用神经网络,输入 x t x_t xt 时刻的图像和对应时间步骤t,预测此图像相对于某个 x 0 x_0 x0 原图加入的噪声ϵ。

神经网络预测噪声ϵ

目前就只有噪声ϵ是未知的了,再来回顾一下关键变量:

Q ( x t ∣ x t − 1 ) Q(x_t|x_{t-1}) Q(xtxt1):扩散过程。表示在给定前一时刻的样本 x t − 1 x_{t-1} xt1 的条件下,生成当前时刻样本 x t x_t xt 的概率分布。

  • 在扩散过程中,每一步都涉及将一些噪声添加到前一时刻的样本 x t − 1 x_{t-1} xt1 上,以生成当前时刻的样本 x t x_t xt
  • 这个过程可以数学上表示为: x t = a ˉ t × x 0 + 1 − a ˉ t × ϵ x_{t} = \sqrt{\bar{a}_t}\times x_0+ \sqrt{1-\bar{a}_t}\times ϵ xt=aˉt ×x0+1aˉt ×ϵ
  • 前向过程不需要一步步推进,已知 x 0 {x_0} x0和时间步t步骤对应的噪声,可以得到 x t {x_t} xt

Q ( x t − 1 ∣ x t ) Q(x_{t-1}|x_{t}) Q(xt1xt:在逆向过程中,模型的目标是从当前时刻的样本 x t x_t xt恢复出前一时刻的样本 x t − 1 x_{t-1} xt1

  • 这个过程可以看作是在给定当前时刻的样本 x t x_t xt的条件下,寻找前一时刻可能的样本 x t − 1 x_{t-1} xt1 的概率分布。

  • t越大, x t x_{t} xt x t − 1 x_{t-1} xt1越接近,条件概率 Q ( x t − 1 ∣ x t ) Q(x_{t-1}|x_{t}) Q(xt1xt)的标准差越小

Q ( x 0 ) Q(x_0) Q(x0): Q ( x 0 ) Q(x_0) Q(x0)是扩散过程的起始点,即原始数据的分布。

  • 注意:或许原始数据 x 0 x_0 x0是确定的,但是其分布 Q ( x 0 ) Q(x_0) Q(x0)不能确定 , Q ( x 0 ) Q(x_0) Q(x0)指整个数据集的分布,它包含了数据集中所有可能的 x 0 x_0 x0 的概率。而单个数据点 x 0 x_0 x0是模型试图生成或复制的具体实例。在生成任务中,我们从一个真实的数据点 x 0 x_0 x0 开始,并观察它如何通过扩散过程逐渐加入噪声变成 x t x_t xt
  • 在实际应用中,我们通常不知道 Q ( x 0 ) Q(x_0) Q(x0)的确切形式,但我们假设在扩散过程的最后,即 t趋向于无穷大时,数据的分布将趋向于某个已知的噪声分布,通常是高斯分布。
  • 可以通过训练代表条件概率 P ( x t − 1 ∣ x t ) P(x_{t-1}|x_t) P(xt1xt)的神经网络去拟合条件概率 Q ( x t − 1 ∣ x t ) Q(x_{t-1}|x_t) Q(xt1xt),最终使边缘概率 P ( x 0 ) P(x_0) P(x0)拟合数据分布 Q ( x 0 ) Q(x_0) Q(x0)(原始数据的分布),每一次的分布都根据上一次的来进行调整,目的是使生成的分布尽可能接近真实的分布
    P ( x t − 1 ∣ x t ) ⟹ Q ( x t − 1 ∣ x t ) ∼ N ( a t ( 1 − a t − 1 ) 1 − a t x t + a t − 1 ( 1 − a t ) 1 − a t × x t − 1 − a t × ϵ a t , ( β t ( 1 − a t − 1 ) 1 − a t ) 2 ) P(x_{t-1}|x_t) \quad \Longrightarrow \quad Q(x_{t-1}|x_t) \sim N\left(\frac{\sqrt{a_t}(1-a_{t-1})}{1-a_t}x_t+\frac{\sqrt{a_{t-1}}(1-a_t)}{1-a_t}\times\frac{x_t-\sqrt{1-a_t}\times \color{red}{\epsilon}}{\sqrt{a_t}},\left(\sqrt{\frac{\beta_t(1-a_{t-1})}{1-a_t}}\right)^2\right) P(xt1xt)Q(xt1xt)N 1atat (1at1)xt+1atat1 (1at)×at xt1at ×ϵ, 1atβt(1at1) 2

神经网络训练:

在这里插入图片描述

训练数据准备

  • 原始图像 x 0 x_0 x0:初始化分布 Q ( x 0 ) Q(x_0) Q(x0)采样得到 x 0 x_0 x0
  • 选择时间步 t t t,从集合从集合1,…,T 中均匀采样
  • 噪声采样:从标准正态分布 N ( 0 , 1 ) N(0, 1) N(0,1)中采样噪声 ϵ \epsilon ϵ
  • 加噪图像 x t x_t xt:使用 x 0 x_0 x0 ϵ \epsilon ϵ通过公式 x t = a ˉ t × x 0 + 1 − a ˉ t × ϵ x_t = \sqrt{\bar{a}_t} \times x_0 + \sqrt{1-\bar{a}_t} \times \epsilon xt=aˉt ×x0+1aˉt ×ϵ 生成。

训练过程

  • 输入:神经网络接收加噪图像 x t x_t xt作为输入。

  • 输出:网络输出对应于 x t x_t xt的预测噪声 ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) \boldsymbol{\epsilon}_{\theta}(\sqrt{\bar{\alpha}_{t}}\mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}}\boldsymbol{\epsilon},t) ϵθ(αˉt x0+1αˉt ϵ,t)

  • 目标:使预测的噪声尽可能接近扩散过程中实际加入的噪声。

  • 损失函数:
    ∇ θ ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∥ 2 \nabla_{\theta}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\theta}(\sqrt{\bar{\alpha}_{t}}\mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}}\boldsymbol{\epsilon},t)\right\|^{2} θ ϵϵθ(αˉt x0+1αˉt ϵ,t) 2

    • ϵ \boldsymbol{\epsilon} ϵ:这是扩散过程中添加到原始数据 x 0 \mathbf{x}_{0} x0 上的噪声。

    • 噪声预测函数 ϵ = f ( x t ) \epsilon=f(x_t) ϵ=f(xt),训练神经网络f学习 x 0 x_0 x0添加什么样子的噪声 ϵ ′ \epsilon' ϵ可以得到(无限接近) x t x_t xt

    • ϵ θ \boldsymbol{\epsilon}_{\theta} ϵθ:这是神经网络参数 θ \theta θ 的函数,它尝试预测给定噪声图像 x t x_t xt 的噪声。

    • α ˉ t x 0 + 1 − α ˉ t ϵ \sqrt{\bar{\alpha}_{t}}\mathbf{x}_{0}+\sqrt{1-\bar{\alpha}_{t}}\boldsymbol{\epsilon} αˉt x0+1αˉt ϵ:这是生成当前噪声图像 x t x_t xt的扩散过程公式,其中 * α ˉ t \bar{\alpha}_{t} αˉt 是与时间步 t 相关的预定义的方差参数。

    • t t t :表示扩散步骤的索引,它控制了扩散过程中噪声的量。

    • ∇ θ \nabla_{\theta} θ:表示对神经网络参数 θ \theta θ的梯度运算。

  • 每次迭代时,使用预测的噪声 ϵ ′ \epsilon' ϵ μ t − 1 \mu_{t-1} μt1更新函数然后, x ^ t − 1 \hat{x}_{t-1} x^t1作为新的输入 x t x_t xt 进入下一轮迭代,神经网络继续预测新的噪声 ϵ ′ \epsilon' ϵ,直到生成原始图像 x 0 x_0 x0
    x ^ t − 1 = μ t − 1 − a t − 1 ( 1 − a t ) 1 − a t × ϵ ′ a t \hat{x}_{t-1}=\mu_{t-1}-\frac{\sqrt{a_{t-1}}(1-a_t)}{1-a_t}\times\frac{\epsilon^{\prime}}{\sqrt{a_t}} x^t1=μt11atat1 (1at)×at ϵ

逆向重构

在这里插入图片描述

  1. 初始化

    • 初始化 x T x_T xT 为一个服从正态分布 N ( 0 , I ) N(0, I) N(0,I) 的随机向量,其中 I I I 是单位矩阵。
  2. 逆向重构

    • t从时间步 T T T 反向遍历到时间步 1 1 1
    • 对于每个时间步 t t t
      • 如果 t > 1 t > 1 t>1,则从标准正态分布 N ( 0 , I ) N(0, I) N(0,I) 中采样一个随机噪声向量 z z z
      • 如果 t = 1 t = 1 t=1,则 z z z 被设为零向量。第一步不需要额外的噪声
    • 计算 x t − 1 x_{t-1} xt1 x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z \mathbf{x}_{t-1}=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}}\boldsymbol{\epsilon}_{\theta}(\mathbf{x}_{t},t)\right)+ \sigma_t z xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))+σtz
      • 预测噪声的逆转:第一部分 1 α t ( x t − 1 − α t 1 − α t ϵ θ ( x t , t ) ) \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \alpha_t}} \epsilon_{\theta}(x_t, t) \right) αt 1(xt1αt 1αtϵθ(xt,t))是根据预测的噪声 ϵ θ ( x t , t ) \epsilon_{\theta}(x_t, t) ϵθ(xt,t)逆转扩散过程的影响,尝试恢复出原始图像的一个估计。
      • 随机噪声的引入:第二部分 σ t z \sigma_t z σtz 是引入额外的随机噪声项,其中 z z z是从标准正态分布 N ( 0 , I ) N(0, I) N(0,I)中采样的随机向量, σ t \sigma_t σt是与时间步 t t t 相关的噪声标准差。这有助于模型探索数据空间,增加生成样本的多样性,并模拟逆向过程中的不确定性。
  3. 输出

    • 最终生成的数据样本为 x 0 x_0 x0

在这里插入图片描述

  • 可以通过多个简单低维离散型概率组合成复杂高维离散型概率,对于高维度空间
    • 例如一张图像(RGB三通道:64*64的彩色图像有12288维)
    • 概率空间涵盖所有可能像素颜色组合
    • 任意一点对应某张图片
    • 条件概率是其(超)截面
    • 某函数描述任意一点的概率密度

代码分析


# Import of libraries
import random
import imageio
import numpy as np
from argparse import ArgumentParser

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import einops
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.datasets.mnist import MNIST, FashionMNIST

# Setting reproducibility
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Definitions
STORE_PATH_MNIST = f"ddpm_model_mnist.pt"
STORE_PATH_FASHION = f"ddpm_model_fashion.pt"
no_train = False
fashion = True
batch_size = 128
n_epochs = 20
lr = 0.001
store_path = "ddpm_fashion.pt" if fashion else "ddpm_mnist.pt"

# Loading the data (converting each image into a tensor and normalizing between [-1, 1])
transform = Compose([ToTensor(), Lambda(lambda x: (x - 0.5) * 2)])
ds_fn = FashionMNIST if fashion else MNIST
dataset = ds_fn("./datasets", download=True, train=True, transform=transform)
loader = DataLoader(dataset, batch_size, shuffle=True)

DDPM 类

# DDPM class
class MyDDPM(nn.Module):
    def __init__(
        self,
        network,
        n_steps=200,
        min_beta=10**-4,
        max_beta=0.02,
        device=None,
        image_chw=(1, 28, 28),#图像的通道数,高度和宽度
    ):
        super(MyDDPM, self).__init__()
        self.n_steps = n_steps
        self.device = device
        self.image_chw = image_chw
        self.network = network.to(device)
        self.betas = torch.linspace(min_beta, max_beta, n_steps).to(device)  

        self.alphas = 1 - self.betas
        self.alpha_bars = torch.tensor(
            [torch.prod(self.alphas[: i + 1]) for i in range(len(self.alphas))]
        ).to(device)#计算累乘

    def forward(self, x0, t, eta=None):
        #t是(n,1)的张量,n是张量x的图像数量,对于x的每个图像我们可以指定不同的时间步(这样更加随机)
        n, c, h, w = x0.shape
        #batch size,通道数,高度,宽度
        a_bar = self.alpha_bars[t]

        if eta is None:
            eta = torch.randn(n, c, h, w).to(self.device)
            #如果没有输入的噪声张量,使用torch.randn生成与输入图像相同大小的标准正态分布的噪声

        noisy = (
            a_bar.sqrt().reshape(n, 1, 1, 1) * x0
            + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
        )
        return noisy

    def backward(self, x, t):
        # Run each image through the network for each timestep t in the vector t.
        # The network returns its estimation of the noise that was added.
        return self.network(x, t)#噪声估计

def show_forward(ddpm, loader, device):
    # Showing the forward process
    for batch in loader:
        imgs = batch[0]

        show_images(imgs, "Original images")

        for percent in [0.25, 0.5, 0.75, 1]:
            show_images(
                ddpm(
                    imgs.to(device),
                    [int(percent * ddpm.n_steps) - 1 for _ in range(len(imgs))],
                ),
                f"DDPM Noisy images {int(percent * 100)}%",
            )
        break

generate_new_images使用一个给定的去噪扩散概率模型(DDPM)来生成新的图像样本,并可选地将生成过程保存为一个GIF动画

#从随机噪声开始,让时间戳由T到0
#run the backward pass and generate new images
#每一步的噪声估计为eta_theta
# 定义生成新图像的函数,从一个DDPM模型生成样本
def generate_new_images(
    ddpm,
    n_samples=16,  # 要生成的样本数量
    device=None,  # 用于计算的设备
    frames_per_gif=100,  # 生成的GIF动画中的帧数
    gif_name="sampling.gif",  # 生成的GIF动画的文件名
    c=1,  # 图像通道数
    h=28,  # 图像高度
    w=28  # 图像宽度
):
    # 计算GIF中每一帧的时间索引
    frame_idxs = np.linspace(0, ddpm.n_steps, frames_per_gif).astype(np.uint)
    frames = []  # 初始化存储GIF帧的列表

    # 使用torch.no_grad()确保所有张量不跟踪梯度,以加快计算速度
    with torch.no_grad():
        if device is None:  # 如果没有指定设备,使用DDPM模型的设备
            device = ddpm.device

        # 从随机噪声开始生成过程
        x = torch.randn(n_samples, c, h, w).to(device)  # 随机噪声图像

        # 逆向遍历时间步
        for idx, t in enumerate(list(range(ddpm.n_steps))[::-1]):
            # 估计要移除的噪声
            time_tensor = (torch.ones(n_samples, 1) * t).to(device).long()
            eta_theta = ddpm.backward(x, time_tensor)  # 噪声估计

            # 获取当前时间步的alpha值和累乘alpha值
            alpha_t = ddpm.alphas[t]
            alpha_t_bar = ddpm.alpha_bars[t]

            # 部分去噪图像
            x = (1 / alpha_t.sqrt()) * (
                x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta
            )

            # 如果当前时间步大于0,添加一些噪声
            if t > 0:
                z = torch.randn(n_samples, c, h, w).to(device)

                # 计算当前时间步的标准差
                beta_t = ddpm.betas[t]
                sigma_t = beta_t.sqrt()

                # 添加更多的噪声,类似于Langevin Dynamics的方式
                x = x + sigma_t * z

            # 将生成的图像帧添加到GIF中
            if idx in frame_idxs or t == 0:
                # 将图像值标准化到[0, 255]范围内
                normalized = x.clone()
                for i in range(len(normalized)):
                    normalized[i] -= torch.min(normalized[i])
                    normalized[i] *= 255 / torch.max(normalized[i])

                # 将批处理图像重塑为尽可能接近正方形的帧
                frame = einops.rearrange(
                    normalized,
                    "(b1 b2) c h w -> (b1 h) (b2 w) c",
                    b1=int(n_samples**0.5),
                )
                frame = frame.cpu().numpy().astype(np.uint8)

                # 将帧添加到GIF中
                frames.append(frame)

    # 将GIF帧写入文件
    with imageio.get_writer(gif_name, mode="I") as writer:
        for idx, frame in enumerate(frames):
            # 将灰度帧转换为RGB帧
            rgb_frame = np.repeat(frame, 3, axis=2)
            writer.append_data(rgb_frame)

            # 最后一帧显示更长时间
            if idx == len(frames) - 1:
                last_rgb_frame = np.repeat(frames[-1], 3, axis=2)
                for _ in range(frames_per_gif // 3):
                    writer.append_data(last_rgb_frame)

    # 返回生成的样本张量
    return x

sinusoidal_embedding 的函数,它生成一个正弦位置嵌入矩阵,通常用于神经网络中对序列或时间步的位置信息进行编码

def sinusoidal_embedding(n, d):
    # n 是位置的数量,d 是嵌入的维度
    # 返回标准的正弦位置嵌入
    embedding = torch.zeros(n, d)  # 初始化一个形状为 (n, d) 的零矩阵
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])  # 计算每个维度的缩放因子
    wk = wk.reshape((1, d))  # 将缩放因子重塑为 (1, d) 形状的张量
    t = torch.arange(n).reshape((n, 1))  # 生成一个从 0 到 n-1 的序列,并重塑为 (n, 1) 形状的张量

    # 使用正弦和余弦函数填充嵌入矩阵的奇数和偶数位置
    # 这种方法称为正弦位置编码,常用于Transformer模型
    embedding[:, ::2] = torch.sin(t * wk[:, ::2])  # 填充所有偶数索引的列(以2为步长)
    embedding[:, 1::2] = torch.cos(t * wk[:, 1::2])  # 填充所有奇数索引的列(以2为步长)

    return embedding  # 返回计算出的正弦位置嵌入矩阵

为构建Unet做准备:构建一个包含两个卷积层的神经网络块,它首先对输入数据进行可选的LayerNorm标准化处理,然后通过两个卷积层,每层后面跟着一个激活函数。MyUNet 类构建了一个典型的U-Net结构,它包括一个编码路径(由多个 MyBlock 组成),一个bottleneck layer,以及一个解码路径。每个 MyBlock 都是一个包含两个卷积层和激活函数的序列。在编码路径中,使用了降采样卷积来减少特征图的尺寸。在解码路径中,使用了上采样卷积(ConvTranspose2d)来恢复特征图的尺寸。

class MyBlock(nn.Module):
    def __init__(  # 类的构造函数
        self,
        shape,  # 传递给LayerNorm的形状参数
        in_c,  # 输入通道数
        out_c,  # 输出通道数
        kernel_size=3,  # 卷积核大小,默认为3
        stride=1,  # 步长,默认为1
        padding=1,  # 填充,默认为1
        activation=None,  # 激活函数,默认为None,即不指定
        normalize=True,  # 是否进行标准化处理,默认为True
    ):
        super(MyBlock, self).__init__()  # 调用基类的构造函数
        self.ln = nn.LayerNorm(shape)  # 初始化LayerNorm层
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)  # 初始化第一个卷积层
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size, stride, padding)  # 初始化第二个卷积层
        self.activation = nn.SiLU() if activation is None else activation  # 设置激活函数,如果没有指定则使用SiLU
        self.normalize = normalize  # 存储是否进行标准化的标志

    def forward(self, x):
        # 前向传播函数
        out = self.ln(x) if self.normalize else x  # 如果需要标准化,则对输入x进行LayerNorm处理
        out = self.conv1(out)  # 第一个卷积操作
        out = self.activation(out)  # 激活函数操作
        out = self.conv2(out)  # 第二个卷积操作
        out = self.activation(out)  # 再次应用激活函数操作
        return out  # 返回网络块的输出

用于预测噪声的Unet

#用于预测噪声的Unet

class MyUNet(nn.Module):
    def __init__(self, n_steps=1000, time_emb_dim=100):
        super(MyUNet, self).__init__()

        # Sinusoidal embedding
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            MyBlock((1, 28, 28), 1, 10),
            MyBlock((10, 28, 28), 10, 10),
            MyBlock((10, 28, 28), 10, 10),
        )
        self.down1 = nn.Conv2d(10, 10, 4, 2, 1)#降采样

        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = nn.Sequential(
            MyBlock((10, 14, 14), 10, 20),
            MyBlock((20, 14, 14), 20, 20),
            MyBlock((20, 14, 14), 20, 20),
        )
        self.down2 = nn.Conv2d(20, 20, 4, 2, 1)

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = nn.Sequential(
            MyBlock((20, 7, 7), 20, 40),
            MyBlock((40, 7, 7), 40, 40),
            MyBlock((40, 7, 7), 40, 40),
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(40, 40, 2, 1), nn.SiLU(), nn.Conv2d(40, 40, 4, 2, 1)
        )

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = nn.Sequential(
            MyBlock((40, 3, 3), 40, 20),
            MyBlock((20, 3, 3), 20, 20),
            MyBlock((20, 3, 3), 20, 40),
        )

        # Second half
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(40, 40, 4, 2, 1),
            nn.SiLU(),
            nn.ConvTranspose2d(40, 40, 2, 1),
        )

        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = nn.Sequential(
            MyBlock((80, 7, 7), 80, 40),
            MyBlock((40, 7, 7), 40, 20),
            MyBlock((20, 7, 7), 20, 20),
        )

        self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = nn.Sequential(
            MyBlock((40, 14, 14), 40, 20),
            MyBlock((20, 14, 14), 20, 10),
            MyBlock((10, 14, 14), 10, 10),
        )

        self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = nn.Sequential(
            MyBlock((20, 28, 28), 20, 10),
            MyBlock((10, 28, 28), 10, 10),
            MyBlock((10, 28, 28), 10, 10, normalize=False),
        )

        self.conv_out = nn.Conv2d(10, 1, 3, 1, 1)

    def forward(self, x, t):
    #网络的前向传播函数 forward 定义了如何通过这个网络结构传递输入数据 x 和时间步嵌入 t。在前向传播过程中,输入数据会通过编码路径的多个块,然后通过瓶颈层,最后通过解码路径的多个块来生成输出。
        # x is (N, 2, 28, 28) (image with positional embedding stacked on channel dimension)
        t = self.time_embed(t)
        n = len(x)
        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1))  # (N, 10, 28, 28)
        out2 = self.b2(
            self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1)
        )  # (N, 20, 14, 14)
        out3 = self.b3(
            self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1)
        )  # (N, 40, 7, 7)

        out_mid = self.b_mid(
            self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1)
        )  # (N, 40, 3, 3)

        out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # (N, 80, 7, 7)
        out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1))  # (N, 20, 7, 7)

        out5 = torch.cat((out2, self.up2(out4)), dim=1)  # (N, 40, 14, 14)
        out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1))  # (N, 10, 14, 14)

        out = torch.cat((out1, self.up3(out5)), dim=1)  # (N, 20, 28, 28)
        out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1))  # (N, 1, 28, 28)

        out = self.conv_out(out)

        return out
#用于创建单层MLP,用于映射位置嵌入
#_make_te 函数是一个辅助函数,用于创建一个单层的多层感知机(MLP),它用于将时间嵌入映射到不同的维度空间,这在U-Net的各层之间传递时间信息时非常有用。
    def _make_te(self, dim_in, dim_out):
        return nn.Sequential(
            nn.Linear(dim_in, dim_out), nn.SiLU(), nn.Linear(dim_out, dim_out)
        )

定义了一个去噪扩散概率模型(DDPM)实例,用于生成图像。


# Defining model
n_steps, min_beta, max_beta = 1000, 10**-4, 0.02  # Originally used by the authors
# 定义模型参数
# n_steps 是扩散过程中的时间步数
# min_beta 和 max_beta 是用于计算扩散过程中噪声的最小和最大 beta 值

# 创建 MyDDPM 实例
ddpm = MyDDPM(
    MyUNet(n_steps),  # 使用 MyUNet 作为去噪模型,n_steps 作为构建 U-Net 的时间步数参数
    n_steps=n_steps,  # DDPM 模型中使用的时间步数
    min_beta=min_beta,  # 最小 beta 值
    max_beta=max_beta,  # 最大 beta 值
    device=device,  # 指定计算设备,CPU 或 GPU
)
# ddpm 是一个 MyDDPM 对象,它是去噪扩散概率模型的实例
  • 29
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值