Euler-Maruyama方法,模拟随机微分方程(SDE)

实现了Euler-Maruyama方法,用来从基于得分匹配模型的扩散过程生成样本。它模拟了一个随机微分方程(SDE),逐步从噪声中采样数据点,并通过时间反转过程回归到数据分布。


参数说明

  1. score_model:
    一个 PyTorch 模型,表示时间相关的得分函数,提供对数概率密度的梯度 ∇ x log ⁡ p t ( x ) \nabla_x \log p_t(x) xlogpt(x)

  2. marginal_prob_std:
    给定时间 t t t,返回扰动核的标准差 σ ( t ) \sigma(t) σ(t)。用于定义扩散过程的噪声。

  3. diffusion_coeff:
    返回 SDE 的扩散系数 g ( t ) g(t) g(t)。这是 SDE 中噪声的强度。

  4. batch_size:
    每次采样生成的样本数。

  5. num_steps:
    时间反转过程的离散步数(采样步数),即将时间轴分割为 n u m _ s t e p s num\_steps num_steps 段。

  6. device:
    运行设备,通常为 ‘cuda’ 或 ‘cpu’。

  7. eps:
    最小时间步长,用于数值稳定性,避免时间 t = 0 t=0 t=0


解析

1. 初始化采样时间 t = 1 t=1 t=1
t = torch.ones(batch_size, device=device)

初始化每个样本的时间 t = 1 t=1 t=1,对应扩散过程的初始点(完全被噪声扰动的状态)。

2. 初始化样本 x x x
init_x = torch.randn(batch_size, 32, 48, device=device) * marginal_prob_std(t)[:, None, None]

生成随机噪声初始化样本 x ∼ N ( 0 , σ ( t ) 2 ) x \sim \mathcal{N}(0, \sigma(t)^2) xN(0,σ(t)2),符合扩散过程的初始条件。

  • 形状解释: 假设采样的数据是图像,每个样本形状为 32 × 48 32 \times 48 32×48
  • marginal_prob_std(t): 标准差控制噪声幅度,模拟扩散过程的终点。
3. 时间反转离散化
time_steps = torch.linspace(1., eps, num_steps, device=device)
step_size = time_steps[0] - time_steps[1]
  • time_steps: 从 t = 1 t=1 t=1 t = eps t=\text{eps} t=eps 的等间距离散时间点。
    • t = 1 t=1 t=1 是扩散过程的初始点。
    • t = eps t=\text{eps} t=eps 是反转过程的终止点。
  • step_size: 每一步的时间间隔。
4. 迭代更新样本

核心采样循环:

x = init_x
with torch.no_grad():
    for time_step in time_steps:
        batch_time_step = torch.ones(batch_size, device=device) * time_step
        g = diffusion_coeff(batch_time_step)  # 扩散系数 g(t)

逐步更新样本 x x x,模拟 SDE 的解法。


5. 计算均值部分
if condition is not None:
    perturbed_condition = condition + torch.randn(batch_size, 32, 48, device=device) * marginal_prob_std(batch_time_step)[:, None, None]
    mean_x = x + (g ** 2)[:, None, None] * score_model(x, batch_time_step, perturbed_condition) * step_size
else:
    mean_x = x + (g ** 2)[:, None, None] * score_model(x, batch_time_step) * step_size
  • 扩散系数 g ( t ) g(t) g(t):
    控制噪声强度随时间变化。

  • 得分函数 s θ s_\theta sθ:
    模型预测的得分函数,用于估计数据分布的梯度 ∇ x log ⁡ p t ( x ) \nabla_x \log p_t(x) xlogpt(x)

  • 均值更新公式:
    μ = x + g ( t ) 2 ⋅ s θ ( x , t ) ⋅ Δ t \mu = x + g(t)^2 \cdot s_\theta(x, t) \cdot \Delta t μ=x+g(t)2sθ(x,t)Δt
    这是 SDE 的确定性部分,控制样本逐步接近真实数据分布。

  • 条件生成: 如果提供了 condition 参数,则模型根据扰动条件预测得分。


6. 添加随机噪声(随机部分)
x = mean_x + torch.sqrt(step_size) * g[:, None, None] * torch.randn_like(x)

在均值更新基础上,加入随机噪声,模拟 SDE 的随机部分:
x ← μ + Δ t ⋅ g ( t ) ⋅ N ( 0 , I ) x \leftarrow \mu + \sqrt{\Delta t} \cdot g(t) \cdot \mathcal{N}(0, I) xμ+Δt g(t)N(0,I)


7. 返回最后的采样结果
return mean_x

返回采样后的最后一组样本。注意: 虽并未明确移除噪声,但通常最后一步可以跳过噪声项以提高采样质量。


样本生成流程总结

  1. 初始化:
    随机生成噪声样本 x x x,对应扩散过程的终点。

  2. 时间反转:
    通过 n u m _ s t e p s num\_steps num_steps 步时间离散化,从 t = 1 t=1 t=1 逐步逼近 t = eps t=\text{eps} t=eps

  3. 更新规则:

    • 均值更新:基于得分函数的确定性部分。
    • 随机扰动:模拟 SDE 的噪声项。
  4. 最终结果:
    返回模拟扩散过程逆过程生成的样本。


适用场景

  • 扩散模型: 用于从预训练得分模型中生成样本。
  • 条件生成: 可结合条件变量进行特定样本生成(如条件图像生成或文本生成)。
  • 灵活性: 可调整时间步数、扩散系数和噪声特性,以适配不同数据分布。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值