训练及推理步骤如下
训练:
① 准备真实图片样本数据
② 根据递推公式直接加噪到第t步,得到第t步数据xt,并记录真实加噪噪声noise
③ 根据第t步数据xt 和t,喂给神经网络得到预测的 噪声predict_noise
④ 真实噪声noise 和预测噪声predict_noise进行损失函数计算
采样:
① 随机生成第噪声数据,作为第t步的数据 xt
② 根据公式q(x_{t-1} | x_t, x_0),得到x{t-1}步数据
③ 直到②计算到x0时,即得到最终的x0数据
重点公式代码实现如下
1 参数准备
- beta 越来越大,alpha越来越小
x t = α t ∗ x t − 1 + β z 1 \quad x_t=\sqrt{\alpha_t}*x_{t-1} + \sqrt{\beta}z_1 xt=αt∗xt−1+βz1
beta=torch.linspace(beta_start, beta_end, timesteps)
alphas=1-betas
- 第xt个样本值为
x t = α ˉ t ∗ x 0 + 1 − α ˉ t z t x_t=\sqrt{\bar \alpha_t}*x_{0}+\sqrt{1-\bar \alpha_t}z_t xt=αˉt∗x0+1−αˉtzt
sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
- 先验分布q(x_{t-1} | x_t, x_0) 方差为
e x p ( − 1 2 ( α t 1 − α t + 1 1 − α ˉ t − 1 ) x t − 1 2 ( α t 1 − α t + 1 1 − α ˉ t − 1 ) = 1 σ 2 σ 2 = β ( 1 − α ˉ t − 1 ) 1 − α ˉ t exp(-\frac{1}{2} (\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\bar \alpha_{t-1}})x_{t-1}^2 \\ (\frac{\alpha_t}{1-\alpha_t}+\frac{1}{1-\bar \alpha_{t-1}})=\frac{1}{\sigma^2} \\ \sigma^2=\frac{\beta(1-\bar \alpha_{t-1})}{1-\bar \alpha_{t}} exp(−21(1−αtαt+1−αˉt−11)xt−12(1−αtαt+1−αˉt−11)=σ21σ2=1−αˉtβ(1−αˉt−1)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
2 模型训练损失函数
1 正向传播时,随机给noise,通过下式计算真实xt
x
t
=
α
ˉ
t
∗
x
0
+
1
−
α
ˉ
t
z
t
x_t=\sqrt{\bar \alpha_t}*x_{0}+\sqrt{1-\bar \alpha_t}z_t
xt=αˉt∗x0+1−αˉtzt
sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
2 根据上面真实x_t,和时间步t,计算模型预测的噪声,(模型选择unet网络)
predicted_noise = denoise_model(x_noisy, t)
3 损失函数
最小化noise与predicted_noise距离
loss = F.mse_loss(noise, predicted_noise)
3 采样去噪
随机x_t ——> 根据 q(x_{t-1} | x_t, x_0) 公式,采样x_{t-1}步数——>递推最终得到第x0步数据。
- 1 随机生成图片数据
随机生成第t步数据x_t
x_t = torch.randn(shape, device=device)
- 2 时间步去噪
根据贝叶斯公司推导
先验分布q(x_{t-1} | x_t, x_0) 方差为
σ 2 = β ( 1 − α ˉ t − 1 ) 1 − α ˉ t \sigma^2=\frac{\beta(1-\bar \alpha_{t-1})}{1-\bar \alpha_{t}} σ2=1−αˉtβ(1−αˉt−1)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
均值mu为
μ ˉ t = 1 α t ( x t − 1 − α t 1 − α ˉ t z t ) \bar \mu_t=\frac{1 } {\sqrt \alpha_{t}} (x_t -\frac{1-\alpha_t } {\sqrt{1- \bar \alpha_{t}}} z_t) μˉt=αt1(xt−1−αˉt1−αtzt)
model_mean = sqrt_recip_alphas_t * (
x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
)
第t-1步数据
根据上面均值和方差进行采样得到第t-1步数据
noise = torch.randn_like(x)
x_(t-1)=model_mean + torch.sqrt(posterior_variance_t) * noise
- 3 当t=0时,即得到x0数据