实现了Euler-Maruyama方法,用来从基于得分匹配模型的扩散过程生成样本。它模拟了一个随机微分方程(SDE),逐步从噪声中采样数据点,并通过时间反转过程回归到数据分布。
参数说明
-
score_model
:
一个 PyTorch 模型,表示时间相关的得分函数,提供对数概率密度的梯度 ∇ x log p t ( x ) \nabla_x \log p_t(x) ∇xlogpt(x)。 -
marginal_prob_std
:
给定时间 t t t,返回扰动核的标准差 σ ( t ) \sigma(t) σ(t)。用于定义扩散过程的噪声。 -
diffusion_coeff
:
返回 SDE 的扩散系数 g ( t ) g(t) g(t)。这是 SDE 中噪声的强度。 -
batch_size
:
每次采样生成的样本数。 -
num_steps
:
时间反转过程的离散步数(采样步数),即将时间轴分割为 n u m _ s t e p s num\_steps num_steps 段。 -
device
:
运行设备,通常为 ‘cuda’ 或 ‘cpu’。 -
eps
:
最小时间步长,用于数值稳定性,避免时间 t = 0 t=0 t=0。
解析
1. 初始化采样时间 t = 1 t=1 t=1
t = torch.ones(batch_size, device=device)
初始化每个样本的时间 t = 1 t=1 t=1,对应扩散过程的初始点(完全被噪声扰动的状态)。
2. 初始化样本 x x x
init_x = torch.randn(batch_size, 32, 48, device=device) * marginal_prob_std(t)[:, None, None]
生成随机噪声初始化样本 x ∼ N ( 0 , σ ( t ) 2 ) x \sim \mathcal{N}(0, \sigma(t)^2) x∼N(0,σ(t)2),符合扩散过程的初始条件。
- 形状解释: 假设采样的数据是图像,每个样本形状为 32 × 48 32 \times 48 32×48。
marginal_prob_std(t)
: 标准差控制噪声幅度,模拟扩散过程的终点。
3. 时间反转离散化
time_steps = torch.linspace(1., eps, num_steps, device=device)
step_size = time_steps[0] - time_steps[1]
time_steps
: 从 t = 1 t=1 t=1 到 t = eps t=\text{eps} t=eps 的等间距离散时间点。- t = 1 t=1 t=1 是扩散过程的初始点。
- t = eps t=\text{eps} t=eps 是反转过程的终止点。
step_size
: 每一步的时间间隔。
4. 迭代更新样本
核心采样循环:
x = init_x
with torch.no_grad():
for time_step in time_steps:
batch_time_step = torch.ones(batch_size, device=device) * time_step
g = diffusion_coeff(batch_time_step) # 扩散系数 g(t)
逐步更新样本 x x x,模拟 SDE 的解法。
5. 计算均值部分
if condition is not None:
perturbed_condition = condition + torch.randn(batch_size, 32, 48, device=device) * marginal_prob_std(batch_time_step)[:, None, None]
mean_x = x + (g ** 2)[:, None, None] * score_model(x, batch_time_step, perturbed_condition) * step_size
else:
mean_x = x + (g ** 2)[:, None, None] * score_model(x, batch_time_step) * step_size
-
扩散系数 g ( t ) g(t) g(t):
控制噪声强度随时间变化。 -
得分函数 s θ s_\theta sθ:
模型预测的得分函数,用于估计数据分布的梯度 ∇ x log p t ( x ) \nabla_x \log p_t(x) ∇xlogpt(x)。 -
均值更新公式:
μ = x + g ( t ) 2 ⋅ s θ ( x , t ) ⋅ Δ t \mu = x + g(t)^2 \cdot s_\theta(x, t) \cdot \Delta t μ=x+g(t)2⋅sθ(x,t)⋅Δt
这是 SDE 的确定性部分,控制样本逐步接近真实数据分布。 -
条件生成: 如果提供了
condition
参数,则模型根据扰动条件预测得分。
6. 添加随机噪声(随机部分)
x = mean_x + torch.sqrt(step_size) * g[:, None, None] * torch.randn_like(x)
在均值更新基础上,加入随机噪声,模拟 SDE 的随机部分:
x
←
μ
+
Δ
t
⋅
g
(
t
)
⋅
N
(
0
,
I
)
x \leftarrow \mu + \sqrt{\Delta t} \cdot g(t) \cdot \mathcal{N}(0, I)
x←μ+Δt⋅g(t)⋅N(0,I)
7. 返回最后的采样结果
return mean_x
返回采样后的最后一组样本。注意: 虽并未明确移除噪声,但通常最后一步可以跳过噪声项以提高采样质量。
样本生成流程总结
-
初始化:
随机生成噪声样本 x x x,对应扩散过程的终点。 -
时间反转:
通过 n u m _ s t e p s num\_steps num_steps 步时间离散化,从 t = 1 t=1 t=1 逐步逼近 t = eps t=\text{eps} t=eps。 -
更新规则:
- 均值更新:基于得分函数的确定性部分。
- 随机扰动:模拟 SDE 的噪声项。
-
最终结果:
返回模拟扩散过程逆过程生成的样本。
适用场景
- 扩散模型: 用于从预训练得分模型中生成样本。
- 条件生成: 可结合条件变量进行特定样本生成(如条件图像生成或文本生成)。
- 灵活性: 可调整时间步数、扩散系数和噪声特性,以适配不同数据分布。