Common Diffusion Noise Schedules and Sample Steps are Flawed

Common Diffusion Noise Schedules and Sample Steps are Flawed

TL; DR:本文发现了现有生图模型的生图结果都是中等亮度和色彩,整体画面偏灰,并指出这是由于 noise schedule 无法在最后时间步处达到零信噪比,从而训练/推理时数据分布不一致导致的。本文提出了 Zero Terminal SNR、Train with V Prediction、Sample From the Last Step、Rescale CFG 一系列改进来解决这一问题,方法简单直接却有效。


发现问题

许多现有生图模型(尤其是 Stable Diffusion)的 noise schedule 都存在一个问题,就是当加噪强度达到最大时,信噪比仍旧没有达到零。这使得训练/推理时模型的输入不一致,从而导致生图结果中色彩的动态范围很低,都是中等亮度和色彩,整体画面偏灰,无法生成纯黑纯白的图片。

在这里插入图片描述

我们知道,扩散模型在进行训练时,在每步给真实图片添加不同程度的噪声,噪声的强度与步数的关系由一组超参数 β \beta β 控制(这组超参数就可以称为 noise schedule),然后训练模型预测噪声/数据,从噪声中逐步恢复出原始真实图片。一般来说,噪声强度随着步数变大而变大,要求最终达到完全的高斯噪声。在推理生图时,随机采样一个高斯噪声,然后用训练好的模型一步步去噪,即可生成新的图片。

信噪比(Signal-to-Noise Ratio, SNR),用来表征数据中信号与噪声的比值,理想情况下,最后一步的 SNR 应该达到 0,即 SNR ( T ) = 0 \text{SNR}(T)=0 SNR(T)=0。然而,作者发现(如下表所示),现有的很多 noise schedule (尤其是 SD 的 noise schedule)在最后一步并不能达到完全的高斯噪声,SNR 远没有达到 0。

在这里插入图片描述

按照 SD 的 noise schedule,最后一步加噪结果为:
x T = 0.068265 x 0 + 0.997667 ϵ x_T=0.068265x_0+0.997667\epsilon xT=0.068265x0+0.997667ϵ
其中 x 0 , ϵ x_0,\epsilon x0,ϵ 分别表示真实图片和噪声,可以看到,在最后一步的加噪结果中,真实图片的占比远远没有达到可以忽略不计的程度,仍旧含有少量的信号。这些泄露的信号会包含一些低频信息,比如每个通道的均值,训练时模型就会学习到根据这些均值信息来去噪。然而,在采样生图时,我们是从标准的高斯分布中采样第 T T T​ 步的输入,其 SNR 自然为零,也不会包含任何有意义的信息。这就导致训练、推理时的数据分布不一致,最终使得模型的生图结果中色彩和亮度都比较平均。

SD 之外的其他生图模型的 noise schedule 在第 T T T 步的信噪比更低一些,对实际生图效果的影响也更小。但是作者建议,最好直接保证第 T T T 步的 SNR 达到 0(即 Zero Terminal SNR),从而完全避免训练、推理时的数据分布差异。当然,这也意味着必须使用 Variance Preseving(VP)的形式化,因为 Variance Exploding (VE)的形式化是无法做到 Zero Terminal SNR 的。

方法

强制 Zero Terminal SNR

本文提出了一种最简单直接的强制保证 Zero Terminal SNR 的方法,即直接通过对 noise schedule 进行 rescale 来实现。标准的 DDPM 加噪公式:
x t = α ˉ t x 0 + 1 − α ˉ t ϵ ,    ϵ ∼ N ( 0 , I ) x_t=\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon,\ \ \epsilon\sim\mathcal{N}(0,\mathbf{I}) xt=αˉt x0+1αˉt ϵ,  ϵN(0,I)
其中 α ˉ t \sqrt{\bar\alpha_t} αˉt 定义了我们的 noise schedule (实际上,我们直接定义的 noise schedule 是一组 { β } 1 T \{\beta\}_1^T {β}1T,然后 α t = 1 − β t ,   α ˉ t = ∏ i = 1 t α i \alpha_t=1-\beta_t,\ \bar\alpha_t=\prod_{i=1}^t\alpha_i αt=1βt, αˉt=i=1tαi),从而控制着每一步的信号和噪声混合的比例,即控制着每一步的 SNR。我们现在想要保持 α ˉ 1 \sqrt{\bar\alpha_1} αˉ1 不变,而 α ˉ T \sqrt{\bar\alpha_T} αˉT 变为 0,然后中间的 α ˉ t ,   t ∈ [ 2 , T − 1 ] \sqrt{\bar\alpha_t},\ t\in[2,T-1] αˉt , t[2,T1] 则是线性插值得出。之所以在 α ˉ t \sqrt{\bar{\alpha}_t} αˉt 空间进行 rescale,而不是在 SNR ( t ) \text{SNR}(t) SNR(t) 空间进行 rescale,是因为这样能在使得最后一步的 SNR 达到 0 的同时,更好地保持原来的 noise schedule 曲线。具体的 python 代码如下所示。

import torch

def enforce_zero_terminal_snr(betas: torch.Tensor):
    # 将 betas 转换为 alphas_bar_sqrt
    alphas = 1 - betas
    alphas_bar = alphas.cumprod(dim=0)
    alphas_bar_sqrt = alphas_bar.sqrt()

    # 保存原始值
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # 如此变换,使得最后一步的值为0
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # 如此变换,使得第一步的值为原始值
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # 将 alphas_bar_sqrt 转换回 betas
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1: ] / alphas_bar[: -1]
    betas = 1 - alphas
    return betas

对 noise schedule 进行 rescale 之后的 log SNR 和 α ˉ \sqrt{\bar\alpha} αˉ 如下图所示。可以看到,对原 noise schedule 的曲线的改变不大,但确实实现了 Zero Terminal SNR。

在这里插入图片描述

注意这里只有非 cosine 的 noise schedule 需要使用上述 rescale 方式来实现 Zero Terminal SNR,对于 cosine 的 noise schedule 来说,只要不要对 β t \beta_t βt 进行,保证 β T = 1 \beta_T=1 βT=1 即可实现 Zero Terminal SNR。

使用 V Prediction 目标函数进行训练

由于我们已经强行保证了第 T T T 步的 SNR 为 0,也就是第 T T T 步已经是纯噪声了。此时,若仍采用常规的 ϵ \epsilon ϵ-prediction 预测噪声作为目标函数是没有意义的,模型学不到任何有意义的东西。这里,作者改为使用 v-prediction 作为训练目标:
v t = α ˉ t ϵ − 1 − α ˉ t x 0 L = λ t ∣ ∣ v t − v ~ t ∣ ∣ 2 2 v_t=\sqrt{\bar\alpha_t}\epsilon-\sqrt{1-\bar\alpha_t}x_0\\ \mathcal{L}=\lambda_t||v_t-\tilde{v}_t||^2_2 vt=αˉt ϵ1αˉt x0L=λt∣∣vtv~t22
对 noise schedule 进行 resscale 后,在 t = T t=T t=T 时,有 α ˉ T = 0 \bar\alpha_T=0 αˉT=0,从而 v T = x 0 v_T=x_0 vT=x0。因此这一步相当于输入一个纯噪声 ϵ \epsilon ϵ 给模型,需要模型预测真实图片 x 0 x_0 x0 ,这一步其实已经不是 “去噪” 了,因为输入中不包含任何有意义的信号,这一步相当于是在根据 prompt 去预测数据分布的均值。

作者对 noise schedule 进行 rescale 后,使用 λ t = 1 \lambda_t=1 λt=1 的 v loss 对 SD 模型进行微调,发现出图的质量与使用 ϵ \epsilon ϵ loss 很接近,作者在此建议以后使用 v prediction 作为训练目标,并可以通过修改超参数 λ t \lambda_t λt 来调整损失的权重。

从最后一个时间步开始采样

现在有很多新的采样器,在推理生图时只使用很少的步数,就能达到不错的生图结果。现在一般是在训练时对全部 T T T (如 T = 1000 T=1000 T=1000)步训练,在推理时只需要采样 S S S (如 S = 25 S=25 S=25​) 步来生图。调整采样步数可以在推理生图时实现效率和质量的权衡。

然而,现在很多少步数的采样器在时间步的选择上,并没有从最后一步开始(如下表所示)。这样的话,即使已经保证了训练阶段最后一步的 Zero Terminal SNR,但是由于生图时采样了纯高斯噪声但不是从最后一步开始去噪,所以还是会出现训练、推理数据分布不一致的情况,从而导致生图结果的亮度和色彩比较平均。作者认为,在通过 rescale 保证了 Zero Terminal SNR 之后,确保推理生图时从最后一步开始也是极其关键的,因为只有这样才能真正使得模型训练和推理时的起始步都是完全的高斯噪声。

在这里插入图片描述

作者考虑了两种采样步数选择方法来确保这一点。一是 iDDPM 中提出的 Linspace,首先将第一步和最后一步囊括进来,然后根据指定的具体步数,均匀地选择中间步;二是 DPM 中提出的 Trailing,先只选上最后一步,然后按照特定的间隔从后往前选。

作者发现 Trailing 这种方法更高效,尤其是当总的采样步数 S S S 比较小的时候。这是因为前面几步(如 x 0 x_0 x0 x 1 x_1 x1)之间的差别实际非常小,因此不需要在前面选择过多的步数。

Rescale CFG

作者还发现当我们保证了最后一步的 SNR 为 0 时,CFG 会变得非常敏感,并有可能产生过曝的图片。Imagen (采用 cosine 的 noise schedule,非常接近 Zero Terminal SNR)也遇到了类似的问题,Imagen 的作者提出了动态阈值的方式来解决这一问题,但该方法只适合于 pixel-based 生图模型。本文中作者提出了一种 rescale CFG 的方法,称对 pixel-based 和 latent-based 两类生图模型都有效。

常规 CFG 的计算方式为:
x cfg = x neg + w ( x pos − x neg ) x_\text{cfg}=x_\text{neg}+w(x_\text{pos}-x_\text{neg}) xcfg=xneg+w(xposxneg)
其中 x pos , x neg x_\text{pos},x_\text{neg} xpos,xneg 分别指模型在 prompt 和 negative prompt 条件下的预测值。作者发现,当 w w w 很大时, x cfg x_\text{cfg} xcfg 的值会非常大,进而导致过曝的问题。作者提出对 CFG 再进行一次 rescale:
σ pos = std ( x pos ) ,    σ cfg = std ( x cfg ) x rescaled = x cfg σ pos σ cfg \sigma_\text{pos}=\text{std}(x_\text{pos}),\ \ \sigma_\text{cfg}=\text{std}(x_\text{cfg}) \\ x_\text{rescaled}=x_\text{cfg}\frac{\sigma_\text{pos}}{\sigma_\text{cfg}} σpos=std(xpos),  σcfg=std(xcfg)xrescaled=xcfgσcfgσpos
但直接使用这个 rescale 过的值又会导致生成的图片过于简单,因此作者又引入了一个超参数 ϕ \phi ϕ,将该值与原值进行插值,得到最终的结果:
x final = ϕ x rescaled + ( 1 − ϕ ) x cfg x_\text{final}=\phi x_\text{rescaled}+(1-\phi)x_\text{cfg} xfinal=ϕxrescaled+(1ϕ)xcfg
参考伪代码如下所示。

def apply_cfg(pos, neg, weight=7.5, rescale=0.7):
    # 常规 CFG
    cfg = neg + weight * (pos - neg)

    # 计算标准差
    std_pos = pos.std([1, 2, 3], keepdim=True)
    std_cfg = cfg.std([1, 2, 3], keepdim=True)

    # 进行 rescale 和插值, 得到最终结果
    factor = std_pos / std_cfg
    factor = rescale * factor + (1 - rescale)
    return cfg * factor

实验结果

可以看到,本文提出的一系列改进使得 SD 能够生成色彩和亮度动态范围更广,更高质量的图片。

在这里插入图片描述

总结

本文是比较早指出 Noise Schedule 中最后一步 SNR 不能达到零这一问题的工作,提出的方法简单直接,却有效地解决了这一问题。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值