Diffusion Models Explained (Enhanced Edition)

Diffusion models have become, in recent years, a major direction in generative modeling alongside generative adversarial networks (GANs). Their core idea can be summarized as: gradually add noise to a clean image until it becomes pure random noise, then train a neural network to reverse the process, removing the noise step by step until a clean image is recovered.

1. Physical and Mathematical Interpretation of the Diffusion Process

1.1 Physical Intuition

In physics, diffusion describes the movement of particles from regions of high concentration to regions of low concentration, following the principle of increasing entropy. By the second law of thermodynamics, a system spontaneously evolves from order toward disorder. In probabilistic modeling, we formalize the diffusion process as gradually transforming the data distribution into a simple prior distribution (such as a standard Gaussian).

1.2 Information-Theoretic Foundations

From an information-theoretic viewpoint, the diffusion process can be understood as a gradual loss of information. Let the entropy of the original data distribution be $H(X_0)$ and the entropy of the noise distribution be $H(X_T)$; then

$$H(X_0) < H(X_1) < \cdots < H(X_T)$$

where the entropy increment $\Delta H_t = H(X_t) - H(X_{t-1}) > 0$ quantifies the information lost at each step.

1.3 Probabilistic Framework

A diffusion model defines two processes:

  1. Forward diffusion: the data distribution is gradually transformed into a noise distribution, with joint density
    $$q(x_0, x_1, \ldots, x_T) = q(x_0) \prod_{t=1}^T q(x_t \mid x_{t-1})$$

  2. Reverse diffusion: the data distribution is recovered from noise by learning
    $$p_\theta(x_0, x_1, \ldots, x_T) = p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} \mid x_t)$$

2. Rigorous Mathematical Derivation

2.1 Detailed Analysis of the Forward Diffusion Process

Let a data point $\mathbf{x}_0 \sim q(\mathbf{x}_0)$ be drawn from the data distribution. We construct a Markovian noising process:

$$q(\mathbf{x}_{1:T} \mid \mathbf{x}_0) = \prod_{t=1}^T q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$$

whose transition kernel is a conditional Gaussian:

$$q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}\big(\mathbf{x}_t;\ \sqrt{1-\beta_t}\,\mathbf{x}_{t-1},\ \beta_t\mathbf{I}\big)$$

where $\{\beta_t \in (0,1)\}_{t=1}^T$ is a predefined variance schedule.

2.1.1 Verifying the Markov Property

The Markov property requires $q(x_t \mid x_{t-1}, x_{t-2}, \ldots, x_0) = q(x_t \mid x_{t-1})$.

This follows from conditional independence: given $x_{t-1}$, $x_t$ is independent of all earlier states.

2.1.2 The Reparameterization Trick for Cumulative Noising

To sample intermediate states efficiently, we can derive a closed-form expression that produces $\mathbf{x}_t$ at any timestep directly from $\mathbf{x}_0$. Define $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^t \alpha_s$. Then:

$$\begin{aligned} \mathbf{x}_t &= \sqrt{\alpha_t}\,\mathbf{x}_{t-1} + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_{t-1} \\ &= \sqrt{\alpha_t}\big(\sqrt{\alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{1-\alpha_{t-1}}\,\boldsymbol{\epsilon}_{t-2}\big) + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_{t-1} \end{aligned}$$

Since a linear combination of independent Gaussian random variables is again Gaussian, we have:

$$\sqrt{\alpha_t(1-\alpha_{t-1})}\,\boldsymbol{\epsilon}_{t-2} + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_{t-1} \sim \mathcal{N}\big(0,\ (\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t))\mathbf{I}\big)$$

with variance

$$\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = \alpha_t - \alpha_t\alpha_{t-1} + 1 - \alpha_t = 1 - \alpha_t\alpha_{t-1} = 1 - \bar{\alpha}_t/\bar{\alpha}_{t-2}$$

Continuing the recursive expansion and merging the independent Gaussian noises at each step yields:

$$q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\big(\mathbf{x}_t;\ \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\ (1-\bar{\alpha}_t)\mathbf{I}\big)$$

Lemma 2.1: Let $\epsilon_1, \epsilon_2, \ldots, \epsilon_t$ be independent standard Gaussian random variables and $a_1, a_2, \ldots, a_t$ constants. Then:
$$\sum_{i=1}^t a_i \epsilon_i \sim \mathcal{N}\Big(0,\ \sum_{i=1}^t a_i^2\Big)$$

This allows us to sample $\mathbf{x}_t$ directly via:

$$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

2.1.3 Signal-to-Noise Ratio Analysis

Define the signal-to-noise ratio (SNR) as:
$$\text{SNR}(t) = \frac{\bar{\alpha}_t}{1-\bar{\alpha}_t}$$

As the timestep increases, $\bar{\alpha}_t$ decreases monotonically, so the SNR steadily drops and eventually approaches 0.
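As a quick sanity check of this monotone decay, the sketch below (assuming the standard linear schedule $\beta_1 = 10^{-4}$, $\beta_T = 0.02$, $T = 1000$, which is also used in the implementation in Section 6) computes $\bar{\alpha}_t$ and $\text{SNR}(t)$:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # linear variance schedule (assumed)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)     # \bar{alpha}_t
snr = alpha_bar / (1.0 - alpha_bar)          # SNR(t) = \bar{alpha}_t / (1 - \bar{alpha}_t)

# SNR decays monotonically from a large value toward ~0
print(snr[0].item(), snr[T // 2].item(), snr[-1].item())
```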

2.2 Exact Derivation of the Reverse Diffusion Process

2.2.1 Bayesian Derivation

Ideally, the reverse process is also a Markov chain: starting from $p(\mathbf{x}_T) = \mathcal{N}(\mathbf{0}, \mathbf{I})$, it recovers $\mathbf{x}_0$ step by step by learning the conditional distributions $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$.

Let us derive the exact form of $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$. By Bayes' theorem:

$$q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = \frac{q(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0)\, q(\mathbf{x}_{t-1} \mid \mathbf{x}_0)}{q(\mathbf{x}_t \mid \mathbf{x}_0)}$$

Because the forward process is Markovian, $q(\mathbf{x}_t \mid \mathbf{x}_{t-1}, \mathbf{x}_0) = q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$.

2.2.2 Completing the Square for Gaussians

Substituting the known Gaussian densities and completing the square, we work with three distributions:

  1. $q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}\big(\sqrt{\alpha_t}\,\mathbf{x}_{t-1},\ (1-\alpha_t)\mathbf{I}\big)$
  2. $q(\mathbf{x}_{t-1} \mid \mathbf{x}_0) = \mathcal{N}\big(\sqrt{\bar{\alpha}_{t-1}}\,\mathbf{x}_0,\ (1-\bar{\alpha}_{t-1})\mathbf{I}\big)$
  3. $q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\big(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\ (1-\bar{\alpha}_t)\mathbf{I}\big)$

Theorem 2.1 (Bayesian update for Gaussians): Let $X \sim \mathcal{N}(\mu_1, \sigma_1^2)$ and $Y \mid X \sim \mathcal{N}(aX + b, \sigma_2^2)$. Then:
$$X \mid Y \sim \mathcal{N}\left(\frac{\sigma_2^2\mu_1 + a\sigma_1^2(Y-b)}{\sigma_2^2 + a^2\sigma_1^2},\ \frac{\sigma_1^2\sigma_2^2}{\sigma_2^2 + a^2\sigma_1^2}\right)$$

Applying this theorem gives:

$$q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}\big(\mathbf{x}_{t-1};\ \tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0),\ \tilde{\beta}_t\mathbf{I}\big)$$

where

$$\tilde{\boldsymbol{\mu}}_t(\mathbf{x}_t, \mathbf{x}_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,\mathbf{x}_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,\mathbf{x}_t$$

and

$$\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t$$

2.2.3 The Mathematical Basis of Noise Prediction

The conditional distribution above is exact, but at sampling time we do not have access to $\mathbf{x}_0$. We therefore train a neural network $\epsilon_\theta(\mathbf{x}_t, t)$ to predict the noise $\boldsymbol{\epsilon}$, which in turn gives an estimate of $\mathbf{x}_0$:

$$\hat{\mathbf{x}}_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\big(\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(\mathbf{x}_t, t)\big)$$

This estimate is simply the inverse of the reparameterization formula. Substituting it into the expression for $\tilde{\boldsymbol{\mu}}_t$ yields:

$$\boldsymbol{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(\mathbf{x}_t, t)\right)$$
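A minimal sketch of one reverse step built from these two formulas (assuming precomputed schedule tensors `betas`, `alphas`, `alpha_bar` as defined above; the full implementation appears in Section 6):

```python
import torch

def ddpm_reverse_step(x_t: torch.Tensor, eps_pred: torch.Tensor, t: int,
                      betas: torch.Tensor, alphas: torch.Tensor,
                      alpha_bar: torch.Tensor) -> torch.Tensor:
    """One reverse step x_t -> x_{t-1} using mu_theta and the posterior variance beta_tilde."""
    mean = (x_t - betas[t] / torch.sqrt(1.0 - alpha_bar[t]) * eps_pred) / torch.sqrt(alphas[t])
    if t == 0:
        return mean
    beta_tilde = (1.0 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t]) * betas[t]
    return mean + torch.sqrt(beta_tilde) * torch.randn_like(x_t)
```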

2.3 The Variational Bound and the Loss Function

2.3.1 Deriving the Evidence Lower Bound (ELBO)

The training objective of a diffusion model is to maximize the log-likelihood of the observed data. Following variational inference, we instead maximize the evidence lower bound (ELBO):

$$\log p_\theta(\mathbf{x}_0) \geq \mathbb{E}_{q(\mathbf{x}_{1:T}\mid\mathbf{x}_0)}\left[\log\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}\mid\mathbf{x}_0)}\right] = \mathcal{L}$$

Expanding and regrouping terms:

$$\begin{aligned} \mathcal{L} &= \mathbb{E}_q\left[\log\frac{p_\theta(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}\mid\mathbf{x}_0)}\right] \\ &= \mathbb{E}_q\left[\log\frac{p(\mathbf{x}_T)\prod_{t=1}^T p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)}{\prod_{t=1}^T q(\mathbf{x}_t\mid\mathbf{x}_{t-1})}\right] \\ &= \mathbb{E}_q\left[\log p(\mathbf{x}_T) + \sum_{t=1}^T \log p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t) - \sum_{t=1}^T \log q(\mathbf{x}_t\mid\mathbf{x}_{t-1})\right] \end{aligned}$$

2.3.2 KL-Divergence Decomposition

After regrouping the terms and applying Bayes' rule, the ELBO decomposes as:

$$\mathcal{L} = \mathbb{E}_q\left[\log p_\theta(\mathbf{x}_0\mid\mathbf{x}_1) - D_{KL}\big(q(\mathbf{x}_T\mid\mathbf{x}_0)\,\|\,p(\mathbf{x}_T)\big) - \sum_{t=2}^T D_{KL}\big(q(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{x}_0)\,\|\,p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)\big)\right]$$

Interpretation of the terms:

  • Reconstruction term $\mathbb{E}_q[\log p_\theta(\mathbf{x}_0\mid\mathbf{x}_1)]$: ensures correct reconstruction from lightly noised data.
  • Prior matching term $D_{KL}\big(q(\mathbf{x}_T\mid\mathbf{x}_0)\,\|\,p(\mathbf{x}_T)\big)$: ensures that the final state matches the prior.
  • Denoising matching terms $\sum_{t=2}^T D_{KL}\big(q(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{x}_0)\,\|\,p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)\big)$: ensure that each reverse step is correct.

2.3.3 Deriving the Simplified Loss

Theorem 2.2: When $p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t) = \mathcal{N}\big(\boldsymbol{\mu}_\theta(\mathbf{x}_t, t),\ \tilde{\beta}_t\mathbf{I}\big)$, each KL term becomes:

$$D_{KL}\big(q(\mathbf{x}_{t-1}\mid\mathbf{x}_t,\mathbf{x}_0)\,\|\,p_\theta(\mathbf{x}_{t-1}\mid\mathbf{x}_t)\big) = \frac{1}{2\tilde{\beta}_t}\,\|\tilde{\boldsymbol{\mu}}_t - \boldsymbol{\mu}_\theta\|^2$$

With the noise-prediction parameterization, one can show that

$$\|\tilde{\boldsymbol{\mu}}_t - \boldsymbol{\mu}_\theta\|^2 = \frac{\beta_t^2}{\alpha_t(1-\bar{\alpha}_t)}\,\|\boldsymbol{\epsilon} - \epsilon_\theta(\mathbf{x}_t, t)\|^2$$

so each KL term is a weighted noise-prediction error. Dropping the weights, the whole objective simplifies to a mean squared error on the predicted noise:

$$\mathcal{L}_{\text{simple}} = \mathbb{E}_{t,\mathbf{x}_0,\boldsymbol{\epsilon}}\left[\|\boldsymbol{\epsilon} - \epsilon_\theta(\mathbf{x}_t, t)\|^2\right]$$

where $t$ is a uniformly sampled timestep and $\boldsymbol{\epsilon}$ is standard Gaussian noise.

2.3.4 Theoretical Analysis of the Loss Weighting

The full loss function actually carries per-timestep weights:

$$\mathcal{L}_{\text{weighted}} = \mathbb{E}_{t,\mathbf{x}_0,\boldsymbol{\epsilon}}\left[\lambda_t\,\|\boldsymbol{\epsilon} - \epsilon_\theta(\mathbf{x}_t, t)\|^2\right]$$

where $\lambda_t = \frac{\beta_t^2}{2\tilde{\beta}_t\,\alpha_t(1-\bar{\alpha}_t)}$.
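A small sketch of these weights under the same linear schedule as before; note that $\tilde{\beta}_1 = 0$, so the $t{=}1$ term must be clamped or handled separately, exactly as the scheduler in Section 6 clamps its posterior variance:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)
alpha_bar_prev = torch.cat([torch.ones(1), alpha_bar[:-1]])

beta_tilde = (1 - alpha_bar_prev) / (1 - alpha_bar) * betas      # posterior variance, 0 at t=1
lambda_t = betas**2 / (2 * beta_tilde.clamp(min=1e-20) * alphas * (1 - alpha_bar))
```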

2.4 Diffusion Models from the Probability-Flow ODE Perspective

2.4.1 Stochastic Differential Equation (SDE) Formulation

Diffusion models can also be understood through continuous-time dynamics. As the number of steps $T \rightarrow \infty$, the discrete Markov chain approaches a continuous diffusion process described by the stochastic differential equation (SDE):

$$d\mathbf{x} = \mathbf{f}(\mathbf{x}, t)\,dt + g(t)\,d\mathbf{w}$$

where:

  • $\mathbf{f}(\mathbf{x}, t)$ is the drift function
  • $g(t)$ is the diffusion coefficient
  • $\mathbf{w}$ is a Wiener process (Brownian motion)

For diffusion models the concrete form is:
$$d\mathbf{x} = -\tfrac{1}{2}\beta(t)\,\mathbf{x}\,dt + \sqrt{\beta(t)}\,d\mathbf{w}$$

2.4.2 Deriving the Reverse SDE

By Anderson's theorem, the reverse-time process satisfies:

$$d\mathbf{x} = \big[\mathbf{f}(\mathbf{x}, t) - g(t)^2\nabla_\mathbf{x}\log p_t(\mathbf{x})\big]dt + g(t)\,d\bar{\mathbf{w}}$$

where $\nabla_\mathbf{x}\log p_t(\mathbf{x})$ is the score function and $\bar{\mathbf{w}}$ is a reverse-time Brownian motion.

Theorem 2.3 (Anderson's reverse SDE): Given the forward SDE $d\mathbf{x} = \mathbf{f}(\mathbf{x}, t)\,dt + g(t)\,d\mathbf{w}$, its reverse-time process is:
$$d\mathbf{x} = \big[\mathbf{f}(\mathbf{x}, t) - g(t)^2\nabla_\mathbf{x}\log p_t(\mathbf{x})\big]dt + g(t)\,d\bar{\mathbf{w}}$$

2.4.3 The Probability-Flow ODE

More importantly, there exists a deterministic ODE whose marginal densities coincide with those of the SDE:

$$d\mathbf{x} = \left[\mathbf{f}(\mathbf{x}, t) - \frac{1}{2}g(t)^2\nabla_\mathbf{x}\log p_t(\mathbf{x})\right]dt$$

In a diffusion model we are effectively learning $s_\theta(\mathbf{x}, t) \approx \nabla_\mathbf{x}\log p_t(\mathbf{x})$, with

$$s_\theta(\mathbf{x}, t) = -\frac{\epsilon_\theta(\mathbf{x}, t)}{\sqrt{1-\bar{\alpha}_t}}$$
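The conversion between the two parameterizations is a one-liner; a minimal sketch:

```python
import torch

def score_from_eps(eps_pred: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Score estimate s_theta = -eps_theta / sqrt(1 - alpha_bar_t)."""
    return -eps_pred / torch.sqrt(1.0 - alpha_bar_t)
```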

3. Mathematical Foundations of Efficient Sampling Strategies

3.1 DDIM: Deterministic Non-Markovian Sampling

3.1.1 Constructing a Non-Markovian Process

DDPM samples stochastically, whereas DDIM introduces a deterministic variant. The key insight: given the marginals $q(\mathbf{x}_t \mid \mathbf{x}_0)$, there are infinitely many joint distributions $q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)$ with the same marginals.

DDIM redefines $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$ so that the update becomes:

$$\mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\hat{\mathbf{x}}_0 + \sqrt{1-\bar{\alpha}_{t-1} - \sigma_t^2}\cdot\frac{\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\,\hat{\mathbf{x}}_0}{\sqrt{1-\bar{\alpha}_t}} + \sigma_t\boldsymbol{\epsilon}$$

where $\hat{\mathbf{x}}_0 = \dfrac{\mathbf{x}_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(\mathbf{x}_t, t)}{\sqrt{\bar{\alpha}_t}}$.

3.1.2 Choosing the Interpolation Parameter

With $\sigma_t = \sqrt{\dfrac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t}$ the update recovers DDPM; with $\sigma_t = 0$ it becomes the deterministic DDIM.

Lemma 3.1: The DDIM update preserves the marginals, i.e. $q_{\text{DDIM}}(\mathbf{x}_t \mid \mathbf{x}_0) = q_{\text{DDPM}}(\mathbf{x}_t \mid \mathbf{x}_0)$.

3.1.3 The Mathematics of Accelerated Sampling

DDIM permits skipping steps: choose a subsequence $\{t_1, t_2, \ldots, t_S\}$ with $S \ll T$ and update

$$\mathbf{x}_{t_{i-1}} = \sqrt{\bar{\alpha}_{t_{i-1}}}\,\hat{\mathbf{x}}_0 + \sqrt{1-\bar{\alpha}_{t_{i-1}}}\cdot\frac{\mathbf{x}_{t_i}-\sqrt{\bar{\alpha}_{t_i}}\,\hat{\mathbf{x}}_0}{\sqrt{1-\bar{\alpha}_{t_i}}}$$

The theoretical justification for such jumps is the ODE view: a deterministic ODE can be integrated with larger step sizes.
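A minimal sketch of one such deterministic jump (assuming the noise prediction `eps_pred` has already been computed; compare with `ddim_sample` in Section 6):

```python
import torch

def ddim_step(x_t: torch.Tensor, eps_pred: torch.Tensor,
              ab_t: torch.Tensor, ab_prev: torch.Tensor) -> torch.Tensor:
    """One deterministic DDIM jump (sigma_t = 0) from the current timestep to an earlier one.
    ab_t and ab_prev are \\bar{alpha} at the current and target timesteps."""
    x0_pred = (x_t - torch.sqrt(1.0 - ab_t) * eps_pred) / torch.sqrt(ab_t)
    # the "direction" term (x_t - sqrt(ab_t) * x0_pred) / sqrt(1 - ab_t) equals eps_pred
    return torch.sqrt(ab_prev) * x0_pred + torch.sqrt(1.0 - ab_prev) * eps_pred
```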

3.2 Score Matching and Langevin Dynamics

3.2.1 The Score-Matching Objective

Training a diffusion model can be viewed as a score-matching problem. Given the data distribution $p_{\text{data}}(\mathbf{x})$, score matching learns

$$s_\theta(\mathbf{x}) \approx \nabla_\mathbf{x} \log p_{\text{data}}(\mathbf{x})$$

via the objective

$$\mathcal{L}_{\text{SM}} = \mathbb{E}_{p_{\text{data}}(\mathbf{x})}\left[\frac{1}{2}\|s_\theta(\mathbf{x}) - \nabla_\mathbf{x} \log p_{\text{data}}(\mathbf{x})\|^2\right]$$

3.2.2 Denoising Score Matching

Since $\nabla_\mathbf{x} \log p_{\text{data}}(\mathbf{x})$ cannot be computed directly, we use denoising score matching:

$$\mathcal{L}_{\text{DSM}} = \mathbb{E}_{t}\, \mathbb{E}_{p_{\text{data}}(\mathbf{x}_0)}\, \mathbb{E}_{q(\mathbf{x}_t\mid\mathbf{x}_0)} \left[ \frac{1}{2} \left\| s_{\theta}(\mathbf{x}_t,t) -\nabla_{\mathbf{x}_t}\log q(\mathbf{x}_t\mid\mathbf{x}_0) \right\|^{2} \right]$$

Theorem 3.1: For the Gaussian transition kernel $q(\mathbf{x}_t\mid\mathbf{x}_0) = \mathcal{N}\big(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\ (1-\bar{\alpha}_t)\mathbf{I}\big)$,
$$\nabla_{\mathbf{x}_t}\log q(\mathbf{x}_t\mid\mathbf{x}_0) = -\frac{\mathbf{x}_t-\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0}{1-\bar{\alpha}_t} = -\frac{\boldsymbol{\epsilon}}{\sqrt{1-\bar{\alpha}_t}}.$$

Hence score matching is equivalent to noise prediction.

3.3 DPM-Solver and Higher-Order Numerical Methods

3.3.1 Exponential-Integrator Methods

DPM-Solver rewrites the diffusion ODE as a linear part plus a nonlinear part. Using the score-to-noise relation from Section 2.4.3, the probability-flow ODE becomes the semi-linear equation

$$\frac{d\mathbf{x}}{dt} = -\frac{1}{2}\beta(t)\,\mathbf{x} + \frac{\beta(t)}{2\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(\mathbf{x}, t)$$

With the change of variables to the log signal-to-noise ratio $\lambda_t = \log\big(\bar{\alpha}_t/(1-\bar{\alpha}_t)\big)$, the nonlinear part simplifies to

$$\frac{d\mathbf{x}}{d\lambda} = \sqrt{\frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}}\,\epsilon_\theta(\mathbf{x}, t)$$

3.3.2 Higher-Order Taylor Expansion

DPM-Solver improves accuracy with a Taylor expansion of the exact solution:

$$\mathbf{x}_{t_{i-1}} = \frac{\sqrt{\bar{\alpha}_{t_{i-1}}}}{\sqrt{\bar{\alpha}_{t_i}}}\,\mathbf{x}_{t_i} + \sqrt{\bar{\alpha}_{t_{i-1}}}\int_{\lambda_{t_i}}^{\lambda_{t_{i-1}}} \frac{e^{-\lambda}}{\sqrt{1-e^{-2\lambda}}}\,\epsilon_\theta(\hat{\mathbf{x}}_\lambda, \lambda)\, d\lambda$$

A higher-order approximation of $\epsilon_\theta$ along the integration interval is obtained by polynomial interpolation.

4. The Mathematical Framework of Latent Diffusion Models

4.1 Theoretical Basis of the Encoder-Decoder

Latent diffusion models (LDMs) run the diffusion process in a low-dimensional latent space, which greatly reduces the computational cost. Formally:

$$\mathbf{x} = \mathcal{D}(\mathbf{z}), \quad \mathbf{z} = \mathcal{E}(\mathbf{x})$$

where $\mathcal{E}: \mathbb{R}^{H \times W \times C} \rightarrow \mathbb{R}^{h \times w \times c}$ is the encoder, $\mathcal{D}: \mathbb{R}^{h \times w \times c} \rightarrow \mathbb{R}^{H \times W \times C}$ is the decoder, and $\mathbf{z}$ is the latent representation.

4.2 An Information-Theoretic Look at the Latent Space

Theorem 4.1 (information preservation in the latent space): Suppose the encoder-decoder pair satisfies the approximate reconstruction condition
$$\mathbb{E}\big[\|\mathbf{x} - \mathcal{D}(\mathcal{E}(\mathbf{x}))\|^2\big] \leq \epsilon$$

Then a diffusion process in the latent space preserves the principal statistical properties of the original data.

The diffusion process is applied in latent space:

$$q(\mathbf{z}_t \mid \mathbf{z}_0) = \mathcal{N}\big(\mathbf{z}_t;\ \sqrt{\bar{\alpha}_t}\,\mathbf{z}_0,\ (1-\bar{\alpha}_t)\mathbf{I}\big)$$

The full latent-diffusion objective is:

$$\mathcal{L}_{\text{LDM}} = \mathbb{E}_{\mathcal{E}(\mathbf{x}), \boldsymbol{\epsilon}, t}\left[\|\boldsymbol{\epsilon} - \epsilon_\theta(\mathbf{z}_t, t)\|^2\right] + \lambda \mathcal{L}_{\text{rec}} + \mathcal{L}_{\text{reg}}$$

where:

  • $\mathcal{L}_{\text{rec}} = \|\mathbf{x} - \mathcal{D}(\mathcal{E}(\mathbf{x}))\|^2$ is the reconstruction loss
  • $\mathcal{L}_{\text{reg}}$ is a regularization term (e.g. a KL penalty)
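A minimal sketch of the pipeline, with a hypothetical toy convolutional autoencoder standing in for the pretrained $\mathcal{E}/\mathcal{D}$ of a real LDM:

```python
import torch
import torch.nn as nn

# Hypothetical toy encoder/decoder (stand-ins for a pretrained VAE): 3xHxW <-> 4x(H/8)x(W/8)
encoder = nn.Conv2d(3, 4, kernel_size=8, stride=8)
decoder = nn.ConvTranspose2d(4, 3, kernel_size=8, stride=8)

x0 = torch.randn(2, 3, 64, 64)                 # a batch of images
z0 = encoder(x0)                               # latent representation z = E(x)

# Diffuse in latent space exactly as in pixel space
alpha_bar_t = torch.tensor(0.5)
noise = torch.randn_like(z0)
zt = torch.sqrt(alpha_bar_t) * z0 + torch.sqrt(1 - alpha_bar_t) * noise

x_rec = decoder(z0)                            # x ≈ D(E(x)); used by the reconstruction loss
```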

4.3 Mathematical Modeling of Conditional Generation

For conditional generation (e.g. text-to-image), the noise-prediction network is extended with a conditioning input:

$$\epsilon_\theta(\mathbf{z}_t, t, \mathbf{c}) \approx \mathbb{E}[\boldsymbol{\epsilon} \mid \mathbf{z}_t, \mathbf{c}]$$

where $\mathbf{c}$ is the conditioning information (e.g. a text embedding).

4.3.1 Classifier Guidance

Classifier guidance strengthens the conditioning using the gradient $\nabla_{\mathbf{z}} \log p_\phi(\mathbf{c} \mid \mathbf{z}_t)$:

$$\tilde{\epsilon}_\theta(\mathbf{z}_t, t, \mathbf{c}) = \epsilon_\theta(\mathbf{z}_t, t) - \sqrt{1-\bar{\alpha}_t}\cdot s \cdot \nabla_{\mathbf{z}_t} \log p_\phi(\mathbf{c} \mid \mathbf{z}_t)$$

where $s$ is the guidance scale.

4.3.2 Classifier-Free Guidance

To avoid training an extra classifier, classifier-free guidance uses:

$$\tilde{\epsilon}_\theta(\mathbf{z}_t, t, \mathbf{c}) = \epsilon_\theta(\mathbf{z}_t, t, \emptyset) + s \cdot \big(\epsilon_\theta(\mathbf{z}_t, t, \mathbf{c}) - \epsilon_\theta(\mathbf{z}_t, t, \emptyset)\big)$$

Theorem 4.2: Classifier-free guidance is equivalent to guidance with an implicit classifier, where
$$\log p_{\text{implicit}}(\mathbf{c} \mid \mathbf{z}_t) \propto \log \frac{p(\mathbf{z}_t \mid \mathbf{c})}{p(\mathbf{z}_t)}$$
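The combination rule itself is a one-liner; a sketch (the same rule appears inside `sample_with_guidance` in Section 6):

```python
import torch

def cfg_combine(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, s: float) -> torch.Tensor:
    """Classifier-free guidance: blend unconditional and conditional noise predictions."""
    return eps_uncond + s * (eps_cond - eps_uncond)
```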

5. Algorithmic Analysis of Diffusion-Model Training and Sampling

5.1 Complexity of the Training Algorithm

5.1.1 Standard Training Procedure

Diffusion model training
Input: dataset D, noise schedule {βₜ}, network εθ, learning rate η
Output: trained noise-prediction network εθ

1. Precompute: ᾱₜ = ∏ᵢ₌₁ᵗ(1-βᵢ) for t = 1,...,T
2. while not converged do:
3.    sample x₀ ~ D
4.    sample t ~ Uniform(1,T)
5.    sample ε ~ N(0,I)
6.    compute xₜ = √ᾱₜ x₀ + √(1-ᾱₜ) ε
7.    compute the loss L = ||ε - εθ(xₜ,t)||²
8.    update parameters θ ← θ - η∇θL
9. end while

Complexity analysis

  • Time complexity: $O(N \cdot M \cdot T_{\text{train}})$, where $N$ is the dataset size, $M$ is the number of network parameters, and $T_{\text{train}}$ is the number of training steps.
  • Space complexity: $O(M + B \cdot D)$, where $B$ is the batch size and $D$ is the data dimensionality.

5.2 Numerical Analysis of Sampling Algorithms

5.2.1 Convergence of DDPM Sampling

Theorem 5.2 (DDPM convergence): If the noise-prediction network $\epsilon_\theta$ is $L$-Lipschitz, the expected error of DDPM sampling satisfies
$$\mathbb{E}\big[\|\mathbf{x}_0^{\text{true}} - \mathbf{x}_0^{\text{sample}}\|^2\big] \leq C \cdot T^{-1/2}$$

where $C$ is a constant depending on $L$ and the data distribution.

5.2.2 Determinism of DDIM Sampling

Theorem 5.3 (DDIM consistency): For a fixed noise-prediction network and initial noise, DDIM sampling is deterministic, and
$$\mathbf{x}_0^{(1)} = \mathbf{x}_0^{(2)} \iff \mathbf{x}_T^{(1)} = \mathbf{x}_T^{(2)}$$

This property is what makes semantic interpolation and editing possible with DDIM.

5.3 Control-Theoretic View of Conditional Sampling

5.3.1 The Mathematics of Guided Sampling

For conditional generation from $p(\mathbf{x} \mid \mathbf{c})$, control can be exerted by modifying the score function:

$$\nabla_\mathbf{x} \log p(\mathbf{x} \mid \mathbf{c}) = \nabla_\mathbf{x} \log p(\mathbf{x}) + \nabla_\mathbf{x} \log p(\mathbf{c} \mid \mathbf{x})$$

In a diffusion model this corresponds to:

$$\epsilon_{\text{guided}}(\mathbf{x}_t, t, \mathbf{c}) = \epsilon_\theta(\mathbf{x}_t, t) - \sqrt{1-\bar{\alpha}_t}\cdot s \cdot \nabla_{\mathbf{x}_t} \log p(\mathbf{c} \mid \mathbf{x}_t)$$

5.3.2 Theoretical Analysis of the Guidance Scale

Theorem 5.4: The guidance scale $s$ controls a quality-diversity trade-off in the generated samples:

  • $s = 0$: unconditional generation, maximal diversity
  • $s \rightarrow \infty$: maximal adherence to the condition, minimal diversity

An optimal guidance scale can be characterized by the optimization problem

$$s^* = \arg\max_s\ \mathbb{E}[\log p(\mathbf{c} \mid \mathbf{x}_0)] - \lambda\, D_{KL}\big(p_s(\mathbf{x}_0)\,\|\,p(\mathbf{x}_0)\big)$$

where $p_s(\mathbf{x}_0)$ is the distribution of samples generated with guidance scale $s$.

6. Example Code Implementation of a Diffusion Model

Below is a diffusion model implemented in PyTorch:

```python

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math
import copy
from tqdm import tqdm
from typing import Optional, Tuple, List

# Sinusoidal position embeddings
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal positional encoding, used for the timestep embedding."""
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

# Attention module
class AttentionBlock(nn.Module):
    """Self-attention block for capturing global dependencies."""
    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        h = self.norm(x)
        qkv = self.qkv(h).view(B, 3, self.num_heads, C // self.num_heads, H * W)
        q, k, v = qkv.unbind(1)
        
        # Compute the attention weights
        w = torch.einsum('bhcn,bhcm->bhnm', q, k) * (C // self.num_heads) ** -0.5
        w = F.softmax(w, dim=-1)
        
        # Apply the attention
        h = torch.einsum('bhnm,bhcm->bhcn', w, v)
        h = h.view(B, C, H, W)
        
        return x + self.out(h)

# Residual block
class ResidualBlock(nn.Module):
    """Residual block conditioned on the timestep embedding."""
    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.dropout = nn.Dropout(dropout)
        
        self.residual_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(F.silu(self.norm1(x)))
        
        # Add the timestep embedding
        time_emb = F.silu(self.time_mlp(time_emb))
        h = h + time_emb[:, :, None, None]
        
        h = self.conv2(self.dropout(F.silu(self.norm2(h))))
        
        return h + self.residual_conv(x)

# Improved U-Net architecture
class ImprovedUNet(nn.Module):
    """Improved U-Net with attention blocks and richer feature extraction."""
    def __init__(
        self, 
        image_channels: int = 3, 
        base_channels: int = 128,
        channel_multipliers: List[int] = [1, 2, 4, 8],
        num_res_blocks: int = 2,
        dropout: float = 0.1,
        use_attention: List[bool] = [False, False, True, True]
    ):
        super().__init__()
        
        self.time_embed_dim = base_channels * 4
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(base_channels),
            nn.Linear(base_channels, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.time_embed_dim),
        )
        
        # Initial convolution
        self.init_conv = nn.Conv2d(image_channels, base_channels, 3, padding=1)
        
        # Downsampling path
        self.down_blocks = nn.ModuleList([])
        self.down_samples = nn.ModuleList([])
        
        channels = [base_channels] + [base_channels * m for m in channel_multipliers]
        for i in range(len(channel_multipliers)):
            in_ch, out_ch = channels[i], channels[i + 1]
            
            # Group of residual blocks
            blocks = nn.ModuleList([])
            for _ in range(num_res_blocks):
                blocks.append(ResidualBlock(in_ch, out_ch, self.time_embed_dim, dropout))
                in_ch = out_ch
            
            # Attention block
            if use_attention[i]:
                blocks.append(AttentionBlock(out_ch))
            
            self.down_blocks.append(blocks)
            
            # Downsample (except at the last level)
            if i < len(channel_multipliers) - 1:
                self.down_samples.append(nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1))
            else:
                self.down_samples.append(nn.Identity())
        
        # Middle blocks
        mid_channels = channels[-1]
        self.mid_block1 = ResidualBlock(mid_channels, mid_channels, self.time_embed_dim, dropout)
        self.mid_attention = AttentionBlock(mid_channels)
        self.mid_block2 = ResidualBlock(mid_channels, mid_channels, self.time_embed_dim, dropout)
        
        # Upsampling path
        self.up_samples = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])
        
        for i in reversed(range(len(channel_multipliers))):
            in_ch = channels[i + 1]
            out_ch = channels[i]
            
            # Upsample
            if i < len(channel_multipliers) - 1:
                self.up_samples.append(nn.ConvTranspose2d(in_ch, in_ch, 4, stride=2, padding=1))
            else:
                self.up_samples.append(nn.Identity())
            
            # Group of residual blocks (the skip connection doubles the channel count:
            # the stored skip has in_ch channels, matching h after upsampling)
            blocks = nn.ModuleList([])
            for j in range(num_res_blocks + 1):
                blocks.append(ResidualBlock(
                    in_ch * 2 if j == 0 else out_ch,
                    out_ch,
                    self.time_embed_dim,
                    dropout
                ))
            
            # Attention block
            if use_attention[i]:
                blocks.append(AttentionBlock(out_ch))
            
            self.up_blocks.append(blocks)
        
        # Output layers
        self.out_norm = nn.GroupNorm(32, base_channels)
        self.out_conv = nn.Conv2d(base_channels, image_channels, 3, padding=1)

    def forward(self, x: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
        # Timestep embedding
        time_emb = self.time_embed(time)
        
        # Initial convolution
        h = self.init_conv(x)
        
        # Store the skip connections
        skip_connections = [h]
        
        # Downsampling path
        for i, (blocks, downsample) in enumerate(zip(self.down_blocks, self.down_samples)):
            for block in blocks:
                if isinstance(block, ResidualBlock):
                    h = block(h, time_emb)
                else:  # AttentionBlock
                    h = block(h)
            skip_connections.append(h)
            h = downsample(h)
        
        # Middle blocks
        h = self.mid_block1(h, time_emb)
        h = self.mid_attention(h)
        h = self.mid_block2(h, time_emb)
        
        # Upsampling path
        for i, (upsample, blocks) in enumerate(zip(self.up_samples, self.up_blocks)):
            h = upsample(h)
            # Skip connection
            skip = skip_connections.pop()
            h = torch.cat([h, skip], dim=1)
            
            for block in blocks:
                if isinstance(block, ResidualBlock):
                    h = block(h, time_emb)
                else:  # AttentionBlock
                    h = block(h)
        
        # Output
        h = F.silu(self.out_norm(h))
        return self.out_conv(h)

# Noise scheduler
class NoiseScheduler:
    """Noise scheduler supporting several schedule strategies."""
    
    def __init__(
        self, 
        num_steps: int = 1000, 
        beta_start: float = 1e-4, 
        beta_end: float = 0.02,
        schedule_type: str = "linear"
    ):
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        
        if schedule_type == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_steps)
        elif schedule_type == "cosine":
            self.betas = self._cosine_schedule()
        elif schedule_type == "sigmoid":
            self.betas = self._sigmoid_schedule()
        else:
            raise ValueError(f"Unknown schedule type: {schedule_type}")
        
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # Precomputed values used for DDPM sampling
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        
        # Posterior variance
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_log_variance_clipped = torch.log(torch.clamp(self.posterior_variance, min=1e-20))
        
        # Posterior mean coefficients
        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)

    def _cosine_schedule(self) -> torch.Tensor:
        """余弦噪声调度"""
        steps = torch.arange(self.num_steps + 1, dtype=torch.float32) / self.num_steps
        alphas_cumprod = torch.cos((steps + 0.008) / 1.008 * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clamp(betas, 0.0001, 0.9999)

    def _sigmoid_schedule(self) -> torch.Tensor:
        """Sigmoid噪声调度"""
        steps = torch.arange(self.num_steps, dtype=torch.float32) / self.num_steps
        betas = torch.sigmoid((steps - 0.5) * 12) * (self.beta_end - self.beta_start) + self.beta_start
        return betas

    def add_noise(
        self, 
        original_samples: torch.Tensor, 
        noise: torch.Tensor, 
        timesteps: torch.Tensor
    ) -> torch.Tensor:
        """添加噪声到原始样本"""
        sqrt_alpha_prod = self.sqrt_alphas_cumprod[timesteps]
        sqrt_one_minus_alpha_prod = self.sqrt_one_minus_alphas_cumprod[timesteps]
        
        # Broadcast to the correct shape
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def get_posterior_mean_variance(
        self, 
        x_0: torch.Tensor, 
        x_t: torch.Tensor, 
        t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """计算后验分布的均值和方差"""
        coef1 = self.posterior_mean_coef1[t]
        coef2 = self.posterior_mean_coef2[t]
        variance = self.posterior_variance[t]
        
        while len(coef1.shape) < len(x_0.shape):
            coef1 = coef1.unsqueeze(-1)
            coef2 = coef2.unsqueeze(-1)
            variance = variance.unsqueeze(-1)
        
        mean = coef1 * x_0 + coef2 * x_t
        return mean, variance

# Diffusion model wrapper
class AdvancedDiffusionModel:
    """Diffusion model wrapper with several sampling strategies and loss functions."""
    
    def __init__(
        self, 
        model: nn.Module,
        noise_scheduler: NoiseScheduler,
        device: str = "cuda"
    ):
        self.model = model
        self.noise_scheduler = noise_scheduler
        self.device = device
        
        # Move the scheduler's tensors to the target device
        for attr_name in dir(self.noise_scheduler):
            attr = getattr(self.noise_scheduler, attr_name)
            if isinstance(attr, torch.Tensor):
                setattr(self.noise_scheduler, attr_name, attr.to(device))

    def train_step(
        self,
        clean_images: torch.Tensor,
        optimizer: torch.optim.Optimizer,
        loss_type: str = "l2",
        max_grad_norm: Optional[float] = None
    ) -> float:
        """Improved training step supporting several loss functions and optional gradient clipping."""
        batch_size = clean_images.shape[0]
        
        # Sample random timesteps
        timesteps = torch.randint(0, self.noise_scheduler.num_steps, (batch_size,), device=self.device)
        
        # Sample the noise
        noise = torch.randn_like(clean_images)
        
        # Add noise to the images
        noisy_images = self.noise_scheduler.add_noise(clean_images, noise, timesteps)
        
        # Predict the noise
        predicted_noise = self.model(noisy_images, timesteps)
        
        # Compute the loss
        if loss_type == "l1":
            loss = F.l1_loss(predicted_noise, noise)
        elif loss_type == "l2":
            loss = F.mse_loss(predicted_noise, noise)
        elif loss_type == "huber":
            loss = F.huber_loss(predicted_noise, noise)
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")
        
        # Backpropagation; gradients are clipped (if requested) before the optimizer step
        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None and max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
        optimizer.step()
        
        return loss.item()

    @torch.no_grad()
    def ddpm_sample(
        self, 
        batch_size: int = 1, 
        image_size: Tuple[int, int] = (64, 64),
        num_channels: int = 3,
        eta: float = 1.0
    ) -> torch.Tensor:
        """DDPM采样"""
        shape = (batch_size, num_channels, *image_size)
        x = torch.randn(shape, device=self.device)
        
        for t in tqdm(reversed(range(self.noise_scheduler.num_steps)), desc="DDPM Sampling"):
            timesteps = torch.full((batch_size,), t, device=self.device, dtype=torch.long)
            
            # Predict the noise
            predicted_noise = self.model(x, timesteps)
            
            # Compute the prediction of x_0
            alpha_t = self.noise_scheduler.alphas_cumprod[t]
            sqrt_alpha_t = torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t = torch.sqrt(1 - alpha_t)
            
            predicted_x0 = (x - sqrt_one_minus_alpha_t * predicted_noise) / sqrt_alpha_t
            
            if t > 0:
                # Compute the posterior mean and variance
                mean, variance = self.noise_scheduler.get_posterior_mean_variance(predicted_x0, x, timesteps)
                
                # Add noise (the stochastic part of DDPM)
                noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
                x = mean + torch.sqrt(variance) * noise * eta
            else:
                x = predicted_x0
        
        return x

    @torch.no_grad()
    def ddim_sample(
        self, 
        batch_size: int = 1, 
        image_size: Tuple[int, int] = (64, 64),
        num_channels: int = 3,
        num_inference_steps: int = 50,
        eta: float = 0.0
    ) -> torch.Tensor:
        """DDIM采样"""
        # 创建时间步子序列
        step_ratio = self.noise_scheduler.num_steps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
        timesteps = torch.from_numpy(timesteps).to(self.device)
        
        shape = (batch_size, num_channels, *image_size)
        x = torch.randn(shape, device=self.device)
        
        # Append the final timestep
        timesteps = torch.cat([timesteps, torch.tensor([self.noise_scheduler.num_steps-1], device=self.device)])
        
        for i in tqdm(reversed(range(len(timesteps) - 1)), desc="DDIM Sampling"):
            t = timesteps[i + 1]
            prev_t = timesteps[i]
            
            timestep_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)
            
            # Predict the noise
            predicted_noise = self.model(x, timestep_batch)
            
            # Fetch the corresponding alpha values
            alpha_t = self.noise_scheduler.alphas_cumprod[t]
            alpha_prev_t = self.noise_scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0)
            
            # Predict x_0
            predicted_x0 = (x - torch.sqrt(1 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t)
            
            # Noise scale sigma_t (zero when eta = 0, i.e. fully deterministic DDIM)
            sigma = eta * torch.sqrt((1 - alpha_prev_t) / (1 - alpha_t)) * torch.sqrt(1 - alpha_t / alpha_prev_t)
            noise = torch.randn_like(x) if eta > 0 else torch.zeros_like(x)
            
            # Direction pointing toward x_t, consistent with sigma_t
            direction = torch.sqrt(1 - alpha_prev_t - sigma**2)
            
            # DDIM update
            x = torch.sqrt(alpha_prev_t) * predicted_x0 + direction * predicted_noise + sigma * noise
        
        return x

    @torch.no_grad()
    def sample_with_guidance(
        self, 
        batch_size: int = 1,
        image_size: Tuple[int, int] = (64, 64),
        num_channels: int = 3,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 50,
        conditional_input: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """带引导的采样(支持条件生成)"""
        if conditional_input is None:
            return self.ddim_sample(batch_size, image_size, num_channels, num_inference_steps)
        
        # Classifier-free guidance sampling
        shape = (batch_size, num_channels, *image_size)
        x = torch.randn(shape, device=self.device)
        
        step_ratio = self.noise_scheduler.num_steps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
        timesteps = torch.from_numpy(timesteps).to(self.device)
        timesteps = torch.cat([timesteps, torch.tensor([self.noise_scheduler.num_steps-1], device=self.device)])
        
        for i in tqdm(reversed(range(len(timesteps) - 1)), desc="Guided Sampling"):
            t = timesteps[i + 1]
            prev_t = timesteps[i]
            
            timestep_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)
            
            # Conditional and unconditional predictions
            if hasattr(self.model, 'forward_with_cond'):
                noise_pred_cond = self.model.forward_with_cond(x, timestep_batch, conditional_input)
                noise_pred_uncond = self.model.forward_with_cond(x, timestep_batch, None)
            else:
                # Fall back to unconditional sampling if the model has no conditional interface
                noise_pred_cond = self.model(x, timestep_batch)
                noise_pred_uncond = noise_pred_cond
            
            # Apply the guidance
            predicted_noise = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            
            # DDIM update step
            alpha_t = self.noise_scheduler.alphas_cumprod[t]
            alpha_prev_t = self.noise_scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else torch.tensor(1.0)
            
            predicted_x0 = (x - torch.sqrt(1 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t)
            direction = torch.sqrt(1 - alpha_prev_t)
            
            x = torch.sqrt(alpha_prev_t) * predicted_x0 + direction * predicted_noise
        
        return x

    def calculate_loss_weights(self, timesteps: torch.Tensor) -> torch.Tensor:
        """计算基于SNR的损失权重"""
        snr = self.noise_scheduler.alphas_cumprod[timesteps] / (1 - self.noise_scheduler.alphas_cumprod[timesteps])
        # Min-SNR权重策略
        snr_clamped = torch.clamp(snr, max=5.0)
        weights = snr_clamped / (snr_clamped + 1)
        return weights

# Training and evaluation utilities
class DiffusionTrainer:
    """Trainer for diffusion models."""
    
    def __init__(
        self,
        model: nn.Module,
        noise_scheduler: NoiseScheduler,
        optimizer: torch.optim.Optimizer,
        device: str = "cuda"
    ):
        self.diffusion_model = AdvancedDiffusionModel(model, noise_scheduler, device)
        self.optimizer = optimizer
        self.device = device
        
    def train_epoch(
        self, 
        dataloader, 
        use_ema: bool = True,
        ema_decay: float = 0.9999,
        gradient_clip_norm: float = 1.0
    ) -> float:
        """训练一个epoch"""
        self.diffusion_model.model.train()
        total_loss = 0.0
        num_batches = len(dataloader)
        
        # EMA model (optional)
        if use_ema and not hasattr(self, 'ema_model'):
            self.ema_model = self._create_ema_model(ema_decay)
        
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Training")):
            if isinstance(batch, (list, tuple)):
                images = batch[0].to(self.device)
            else:
                images = batch.to(self.device)
            
            # Training step (gradient clipping must happen before optimizer.step(),
            # so it is passed down to train_step rather than applied here afterwards)
            loss = self.diffusion_model.train_step(
                images, self.optimizer,
                max_grad_norm=gradient_clip_norm if gradient_clip_norm > 0 else None
            )
            total_loss += loss
            
            # Update the EMA weights
            if use_ema:
                self._update_ema_model()
        
        return total_loss / num_batches

    def _create_ema_model(self, decay: float):
        """Create the EMA model as a frozen deep copy of the current model."""
        ema_model = copy.deepcopy(self.diffusion_model.model).to(self.device)
        ema_model.eval()
        
        # Store the EMA parameters (these hold the exponentially averaged weights)
        self.ema_params = {}
        for name, param in ema_model.named_parameters():
            self.ema_params[name] = param.clone().detach()
        
        self.ema_decay = decay
        return ema_model

    def _update_ema_model(self):
        """更新EMA模型参数"""
        for name, param in self.diffusion_model.model.named_parameters():
            if name in self.ema_params:
                self.ema_params[name].mul_(self.ema_decay).add_(param.data, alpha=1 - self.ema_decay)

    def validate(self, dataloader, num_samples: int = 16) -> Tuple[float, torch.Tensor]:
        """验证模型性能"""
        self.diffusion_model.model.eval()
        total_loss = 0.0
        num_batches = 0
        
        with torch.no_grad():
            # Compute the validation loss
            for batch in dataloader:
                if isinstance(batch, (list, tuple)):
                    images = batch[0].to(self.device)
                else:
                    images = batch.to(self.device)
                
                batch_size = images.shape[0]
                timesteps = torch.randint(0, self.diffusion_model.noise_scheduler.num_steps, 
                                        (batch_size,), device=self.device)
                noise = torch.randn_like(images)
                noisy_images = self.diffusion_model.noise_scheduler.add_noise(images, noise, timesteps)
                predicted_noise = self.diffusion_model.model(noisy_images, timesteps)
                loss = F.mse_loss(predicted_noise, noise)
                total_loss += loss.item()
                num_batches += 1
            
            # Generate samples
            sample_images = self.diffusion_model.ddim_sample(
                batch_size=num_samples,
                image_size=(64, 64),  # adjust to the actual data
                num_channels=3,
                num_inference_steps=50
            )
        
        avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
        return avg_loss, sample_images

# Usage example and full training loop
def main_training_example():
    """A complete training example."""
    
    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Model configuration
    model_config = {
        "image_channels": 3,
        "base_channels": 128,
        "channel_multipliers": [1, 2, 4, 8],
        "num_res_blocks": 2,
        "dropout": 0.1,
        "use_attention": [False, False, True, True]
    }
    
    # Initialize the model
    model = ImprovedUNet(**model_config).to(device)
    
    # Noise scheduler
    noise_scheduler = NoiseScheduler(
        num_steps=1000,
        beta_start=1e-4,
        beta_end=0.02,
        schedule_type="cosine"
    )
    
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-4,
        weight_decay=1e-4,
        betas=(0.9, 0.999)
    )
    
    # Learning-rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, 
        T_max=1000,
        eta_min=1e-6
    )
    
    # Trainer
    trainer = DiffusionTrainer(model, noise_scheduler, optimizer, device)
    
    # Dummy data loader (replace with a real dataset in practice)
    class DummyDataLoader:
        def __init__(self, batch_size=16, num_batches=100):
            self.batch_size = batch_size
            self.num_batches = num_batches
        
        def __len__(self):
            return self.num_batches
        
        def __iter__(self):
            for _ in range(self.num_batches):
                yield torch.randn(self.batch_size, 3, 64, 64)
    
    train_dataloader = DummyDataLoader(batch_size=16, num_batches=100)
    val_dataloader = DummyDataLoader(batch_size=16, num_batches=20)
    
    # Training loop
    num_epochs = 100
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Train
        train_loss = trainer.train_epoch(
            train_dataloader,
            use_ema=True,
            gradient_clip_norm=1.0
        )
        
        # Validate
        if (epoch + 1) % 10 == 0:
            val_loss, sample_images = trainer.validate(val_dataloader)
            
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {train_loss:.6f}")
            print(f"Val Loss: {val_loss:.6f}")
            print(f"Learning Rate: {scheduler.get_last_lr()[0]:.2e}")
            
            # Save the best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_loss': val_loss,
                    'model_config': model_config
                }, 'best_diffusion_model.pth')
                print(f"Saved best model with val_loss: {val_loss:.6f}")
        
        # Update the learning rate
        scheduler.step()
        
        # Periodically generate samples for visual inspection
        if (epoch + 1) % 20 == 0:
            print("Generating samples...")
            with torch.no_grad():
                samples = trainer.diffusion_model.ddim_sample(
                    batch_size=4,
                    image_size=(64, 64),
                    num_channels=3,
                    num_inference_steps=50
                )
                # Samples could be saved or displayed here
                # save_image_samples(samples, f'samples_epoch_{epoch+1}.png')

if __name__ == "__main__":
    main_training_example()
```

7. Theoretical Comparison of Diffusion Models with Other Generative Models

7.1 A Unified Theoretical Framework for Generative Models

7.1.1 A Unified Information-Theoretic View

All generative models can be understood within a common information-theoretic framework:

Theorem 7.1 (information-theoretic equivalence of generative models): Given a data distribution $p_{\text{data}}(\mathbf{x})$, the objective of a generative model can be expressed as minimizing one of the following information-theoretic distances:

  1. KL divergence: $D_{KL}(p_{\text{data}} \,\|\, p_\theta) = \mathbb{E}_{p_{\text{data}}}[\log p_{\text{data}}(\mathbf{x}) - \log p_\theta(\mathbf{x})]$
  2. JS divergence: $D_{JS}(p_{\text{data}}, p_\theta) = \frac{1}{2}D_{KL}(p_{\text{data}} \,\|\, p_m) + \frac{1}{2}D_{KL}(p_\theta \,\|\, p_m)$
  3. Wasserstein distance: $W_1(p_{\text{data}}, p_\theta) = \inf_{\gamma \in \Gamma} \mathbb{E}_{(\mathbf{x}, \mathbf{y}) \sim \gamma}[\|\mathbf{x} - \mathbf{y}\|]$

where $p_m = \frac{1}{2}(p_{\text{data}} + p_\theta)$ and $\Gamma$ is the set of all joint distributions whose marginals are $p_{\text{data}}$ and $p_\theta$.

7.1.2 Detailed Comparison Table

| Property | Diffusion models | GAN | VAE | Flow | Autoregressive models |
| --- | --- | --- | --- | --- | --- |
| Mathematical basis | Markov chain / SDE | Game theory | Variational inference | Invertible transforms | Chain rule of conditional probability |
| Likelihood estimation | Implicit (via the ELBO) | Implicit | Explicit lower bound | Explicit, exact | Explicit, exact |
| Sampling cost | $O(T)$ (iterative) | $O(1)$ (single pass) | $O(1)$ (single pass) | $O(1)$ (single pass) | $O(n)$ (sequence length) |
| Training stability | High (convex optimization) | Low (saddle-point problem) | Medium (local optima) | High (maximum likelihood) | High (maximum likelihood) |
| Mode collapse | Rare | Common | — | — | — |
| Sample quality | Medium-high | — | — | — | — |
| Memory cost | — | — | — | — | High (long sequences) |

7.2 Comparison of Convergence Theory

7.2.1 Convergence Guarantees for Diffusion Models

Theorem 7.2 (convergence of diffusion models): Suppose the noise-prediction network $\epsilon_\theta$ belongs to a function class $\mathcal{F}$ satisfying:

  1. $\mathcal{F}$ has bounded Rademacher complexity $\mathcal{R}_n(\mathcal{F}) = O(n^{-1/2})$
  2. the loss $\ell$ is $L$-Lipschitz continuous

Then there exists a constant $C$ such that the generalization error of the diffusion model satisfies:
$$\mathbb{E}[\mathcal{L}(\epsilon_\theta)] - \min_{\epsilon \in \mathcal{F}} \mathcal{L}(\epsilon) \leq C \cdot n^{-1/2}$$

7.2.2 The Convergence Challenges of GANs

Theorem 7.3 (non-convergence of GANs): Over continuous strategy spaces, a pure-strategy Nash equilibrium may not exist; even when it does, gradient-descent dynamics may fail to converge to it.

This explains why GAN training is difficult and unstable.

7.3 Theoretical Analysis of Expressive Power

7.3.1 Universal Approximation

Theorem 7.4 (universal approximation by diffusion models): Given sufficient network capacity and training, a diffusion model can learn any continuous probability distribution $p(\mathbf{x})$ satisfying
$$\int \|\nabla_\mathbf{x} \log p(\mathbf{x})\|^2\, p(\mathbf{x})\, d\mathbf{x} < \infty$$

This condition holds for most natural data distributions.

8. Theoretical Limitations and Frontier Research Directions

8.1 Mathematical Analysis of Current Limitations

8.1.1 A Lower Bound on Sampling Cost

Theorem 8.1 (sampling-complexity lower bound): To draw a sample whose distance from the true distribution is at most $\epsilon > 0$, the minimum number of iterations required is
$$T_{\min} = \Omega\left(\frac{d \log(1/\epsilon)}{\epsilon^2}\right)$$

where $d$ is the data dimensionality. Thus the sampling cost of diffusion models grows linearly with dimension.

8.1.2 Theoretical Limits for Discrete Data

For discrete data (such as text), approximating with a continuous diffusion process introduces an additional error:

Lemma 8.1: For a discrete distribution $p_{\text{discrete}}$ and its continuous relaxation $p_{\text{continuous}}$, the total variation distance satisfies
$$\text{TV}(p_{\text{discrete}}, p_{\text{continuous}}) \geq \frac{1}{2}\sum_{k} \left|p_{\text{discrete}}(k) - \int_{k-0.5}^{k+0.5} p_{\text{continuous}}(x)\, dx\right|$$

8.1.3 The Challenge of Modeling Long-Range Dependencies

Theorem 8.2: For sequences of length $L$, the sample complexity for a diffusion model to learn long-range dependencies is $O(L^2 \log L)$, which limits its efficiency for long-sequence modeling.

8.2 Mathematical Foundations of Frontier Research Directions

8.2.1 Consistency Models

Consistency models learn a consistency function $f: (\mathbf{x}_t, t) \mapsto \mathbf{x}_0$ in order to generate in a single step:

Definition 8.1: A consistency function satisfies the self-consistency condition
$$f(\mathbf{x}_t, t) = f(f(\mathbf{x}_s, s), t) \quad \forall s < t$$

Theorem 8.3: The consistency function exists and is unique, and it can be learned with the objective
$$\mathcal{L}_{\text{consistency}} = \mathbb{E}\big[d\big(f(\mathbf{x}_{t+1}, t+1),\ f(\text{stopgrad}(\hat{\mathbf{x}}_t^\phi), t)\big)\big]$$

where $d$ is a distance function and $\text{stopgrad}$ blocks gradient propagation.
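A minimal training-step sketch of this objective; all callable arguments are assumptions (`f_theta` is the consistency network, `f_ema` a stop-gradient/EMA copy, and `solver_step` one step of a pretrained ODE solver used as the teacher), and the MSE stands in for the generic distance $d$:

```python
import torch
import torch.nn.functional as F

def consistency_distillation_loss(f_theta, f_ema, solver_step,
                                  x0: torch.Tensor, t: int, t_next: int,
                                  alpha_bar: torch.Tensor) -> torch.Tensor:
    """Sketch: f(x_{t_next}, t_next) should match f(x_t_hat, t), where x_t_hat is obtained
    by one teacher ODE-solver step from x_{t_next} toward timestep t."""
    noise = torch.randn_like(x0)
    x_next = torch.sqrt(alpha_bar[t_next]) * x0 + torch.sqrt(1 - alpha_bar[t_next]) * noise
    with torch.no_grad():
        x_t_hat = solver_step(x_next, t_next, t)   # teacher: one ODE-solver step toward t
        target = f_ema(x_t_hat, t)                 # stop-gradient target
    return F.mse_loss(f_theta(x_next, t_next), target)
```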

8.2.2 Rectified Flow

Rectified flow improves sampling efficiency by learning (nearly) straight transport paths:

Definition 8.2: The velocity field of a rectified flow is
$$\mathbf{v}_t(\mathbf{x}) = \mathbb{E}[\mathbf{x}_1 - \mathbf{x}_0 \mid \mathbf{x}_t = \mathbf{x}]$$

with the corresponding ODE
$$\frac{d\mathbf{x}}{dt} = \mathbf{v}_t(\mathbf{x})$$

Theorem 8.4: The theoretical advantage of rectified flow lies in the straightness of its paths:
$$\mathbb{E}\left[\int_0^1 \|\mathbf{v}_t(\mathbf{x}_t)\|^2\, dt\right] \leq \mathbb{E}\big[\|\mathbf{x}_1 - \mathbf{x}_0\|^2\big]$$
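A minimal sketch of the corresponding training objective: regress a velocity network onto the straight-line displacement along the linear interpolation $\mathbf{x}_t = (1-t)\,\mathbf{x}_0 + t\,\mathbf{x}_1$ (here `v_theta` is an assumed velocity network taking $(\mathbf{x}_t, t)$):

```python
import torch
import torch.nn.functional as F

def rectified_flow_loss(v_theta, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
    """Sketch of the rectified-flow objective: v_theta(x_t, t) should predict x1 - x0."""
    t = torch.rand(x0.shape[0], *([1] * (x0.dim() - 1)), device=x0.device)
    xt = (1 - t) * x0 + t * x1          # linear interpolation between the two endpoints
    return F.mse_loss(v_theta(xt, t), x1 - x0)
```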

8.2.3 Structured Prediction and Constrained Generation

For structured data, hard constraints can be imposed during generation:

Definition 8.3: The update rule of a constrained diffusion model is
$$\mathbf{x}_{t-1} = \Pi_{\mathcal{C}}\big(\boldsymbol{\mu}_\theta(\mathbf{x}_t, t) + \sigma_t \boldsymbol{\epsilon}\big)$$

where $\Pi_{\mathcal{C}}$ is the projection operator onto the constraint set $\mathcal{C}$.

Theorem 8.5: Constraint projection does not significantly increase the sampling error; when $\mathcal{C}$ is convex, the projection is non-expansive:
$$\|\Pi_{\mathcal{C}}(\mathbf{x}) - \Pi_{\mathcal{C}}(\mathbf{y})\| \leq \|\mathbf{x} - \mathbf{y}\|$$
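For a simple convex constraint such as a valid pixel range, the projection is just a clamp; a minimal sketch:

```python
import torch

def project_to_box(x: torch.Tensor, low: float = -1.0, high: float = 1.0) -> torch.Tensor:
    """Projection onto the convex box C = [low, high]^d (e.g. the valid pixel range)."""
    return torch.clamp(x, low, high)

# Constrained reverse step (sketch): project the usual proposal back onto C.
# x_prev = project_to_box(mu_theta + sigma_t * torch.randn_like(mu_theta))
```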

8.3 A Unified Theory of Multimodal Diffusion

8.3.1 Joint Distribution Modeling

For multimodal data $(\mathbf{x}, \mathbf{y})$, the joint diffusion process is
$$q\big((\mathbf{x}_t, \mathbf{y}_t) \mid (\mathbf{x}_0, \mathbf{y}_0)\big) = \mathcal{N}\big((\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\ \sqrt{\bar{\alpha}_t}\,\mathbf{y}_0),\ (1-\bar{\alpha}_t)\mathbf{I}\big)$$

Theorem 8.6: Conditional generation $p(\mathbf{x} \mid \mathbf{y})$ with a joint diffusion model can be realized via
$$\epsilon_{\mathbf{x}\mid\mathbf{y}}(\mathbf{x}_t, \mathbf{y}_t, t) = \epsilon_\theta([\mathbf{x}_t, \mathbf{y}_t], t)_{\mathbf{x}}$$

where $[\cdot, \cdot]$ denotes concatenation and the subscript $\mathbf{x}$ selects the corresponding part of the output.

8.3.2 An Information-Theoretic View of Cross-Modal Alignment

Theorem 8.7: A multimodal diffusion model implicitly maximizes the mutual information between modalities:
$$I(\mathbf{X}; \mathbf{Y}) = \mathbb{E}[\log p(\mathbf{x}, \mathbf{y})] - \mathbb{E}[\log p(\mathbf{x})] - \mathbb{E}[\log p(\mathbf{y})]$$

This explains why joint training learns semantically aligned representations.
