Improved-DDPM
paper: https://arxiv.org/pdf/2102.09672.pdf
github: https://github.com/openai/improved-diffusion
DDPM 训练出来的扩散模型虽然其生成效果不错, 但由于对数似然相比于 GAN 等模型不够好, 因此其生成的多样性也会打一个折扣,
Improved DDPM做了如下改动
-
学习 Σ θ ( x t , t ) \Sigma_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right) Σθ(xt,t)
DDPM 的 σ t 2 \sigma_t^2 σt2 是固定的, 并且探讨了 σ t 2 \sigma_t^2 σt2 取两种极端情况 σ t 2 = β t \sigma_t^2=\beta_t σt2=βt 和 σ t 2 = β ~ t = 1 − α ˉ t − 1 1 − α ˉ t β t \sigma_t^2=\tilde{\beta}_t=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_t σt2=β~t=1−αˉt1−αˉt−1βt 下的模型表现差不多.
因此提出学习一组方差:
Σ θ ( x t , t ) = exp ( v log β t + ( 1 − v ) log β ~ t ) \Sigma_{\boldsymbol{\theta}}\left(\boldsymbol{x}_t, t\right)=\exp \left(v \log \beta_t+(1-v) \log \tilde{\beta}_t\right) Σθ(xt,t)=exp(vlogβt+(1−v)logβ~t)
采用混合损失:
L h y b r i d = L s i m p l e + λ L v l b L_{\mathrm{hybrid}}=L_{\mathrm{simple}}+\lambda L_{\mathrm{vlb}} Lhybrid=Lsimple+λLvlb
其中 L v l b L_{\mathrm{vlb}} Lvlb就是未简化版的变分下界损失, λ = 0.001 \lambda=0.001 λ=0.001保证 vlb 损失的影响不要太大影响了 simple 损失. -
余弦加噪
Improved DDPM 提出了余弦方案 (cosine schedule):
β t = min ( 1 − α ˉ t α ˉ t − 1 , 0.999 ) , α ˉ t = f ( t ) f ( 0 ) , f ( t ) = cos ( t / T + s 1 + s ⋅ π 2 ) 2 \beta_t=\min \left(1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}, 0.999\right), \quad \bar{\alpha}_t=\frac{f(t)}{f(0)}, \quad f(t)=\cos \left(\frac{t / T+s}{1+s} \cdot \frac{\pi}{2}\right)^2 βt=min(1−αˉt−1αˉt,0.999),αˉt=f(0)f(t),f(t)=cos(1+st/T+s⋅2π)2 -
降低梯度噪声
为时间步 t均匀采样导致的梯度噪声大, 所以提出了时间步重要性采样的方法:
L v l b = E t ∼ p t [ L t p t ] , p t ∝ E [ L t 2 ] , ∑ p t = 1 L_{\mathrm{vlb}}=E_{t \sim p_t}\left[\frac{L_t}{p_t}\right], \quad p_t \propto \sqrt{E\left[L_t^2\right]}, \quad \sum p_t=1 Lvlb=Et∼pt[ptLt],pt∝E[Lt2],∑pt=1
保存每个时间步前 10 次的损失求平均来估计, 这样损失越大的时间步采样频率越低, 从而整体上可以保证损失的稳定性.