【论文阅读】Consistency Models

Introduction

  • 相比于单步生成的模型(例如 GANs, VAEs, normalizing flows),扩散模型的迭代式生成过程需要 10 到 2000 步计算来采样,导致推理速度低,实时性应用受限.

  • 本文的目的是创造高效、单步的生成,同时不牺牲迭代采样的优势。在数据到噪声的 PF-ODE 轨迹上,学习轨迹上任意点到轨迹起点的映射,对这些映射的建模成为 consistency model.
    在这里插入图片描述

  • 两种训练 consistency model的方法

    1. 使用 numerical ODE solver 和预训练的扩散模型在 PF-ODE 轨迹上生成若干相邻点对,通过最小化模型输出点对间的距离(相似度),蒸馏出 consistency model.
    2. 不依赖预训练扩散模型,独立训练一个 consistency model.
  • 在一些数据集上测试.

Diffusion Models

使用 p d a t a ( x ) p_{data}(\mathrm{x}) pdata(x)表示数据分布,扩散模型使用如下随机微分公式对服从原分布的数据进行扩散:

d x t = μ ( , x t , t ) + σ ( t ) d w t \large \mathrm{dx}_t = \mu(\mathrm,{x}_t, t) + \sigma(t)\mathrm{dw}_t dxt=μ(,xt,t)+σ(t)dwt

其中 t t t为时间步,范围是 0 0 0 T T T μ ( ⋅ , ⋅ ) \mu(·,·) μ(⋅,⋅) σ ( ⋅ ) \sigma(·) σ()分别是布朗运动中的漂移系数和扩散系数, x t \mathbf{x}_t xt服从分布 p t ( x ) p_{t}(\mathrm{x}) pt(x) x 0 \mathrm{x}_0 x0服从分布 p d a t a ( x ) p_{data}(\mathrm{x}) pdata(x). 该方程的一个重要属性是,其存在一个 PF-ODE 方程:

d x t = [ μ ( x t , t ) − 1 2 σ ( t ) 2 ∇ log ⁡ p t ( x t ) ] d t \large\mathrm{dx}_t = \left[ \mu(\mathrm{x}_t, t)-\frac{1}{2}\sigma(t)^2 \nabla\log{p_t(\mathrm{x}_t)} \right]\mathrm{d}t dxt=[μ(xt,t)21σ(t)2logpt(xt)]dt

其中 ∇ log ⁡ p t ( x ) \nabla\log{p_t(\mathrm{x})} logpt(x) p t ( x ) p_t(\mathrm{x}) pt(x)的 score function.
在 SDE 中,令漂移系数 μ ( x , t ) = 0 \mu(\mathrm{x}, t) = 0 μ(x,t)=0, 扩散系数 σ ( t ) = 2 t \sigma(t) = \sqrt{2t} σ(t)=2t . 使用得分匹配的方式训练模型 s ϕ ( x , t ) ≈ ∇ log ⁡ p t ( x ) s_{\phi}(\mathrm{x},t) \approx \nabla\log{p_t(\mathrm{x})} sϕ(x,t)logpt(x),代入 PF-ODE 方程,得到 empirical PF-ODE:

d x t d t = − t s ϕ ( x t , t ) \large \frac{\mathrm{dx}_t}{\mathrm{d}t}=-ts_{\phi}(\mathrm{x}_t,t) dtdxt=tsϕ(xt,t

采样时,使用 x ^ T ∼ N ( 0 , T 2 I ) \hat{\mathrm{x}}_T\sim\mathcal{N}(0, T^2I) x^TN(0,T2I)初始化,再使用 numerical ODE solver(例如 Euler, Heun)按时间步倒推出 x ^ 0 \hat{x}_0 x^0. 为了防止数值不稳定,会在 t = ϵ t=\epsilon t=ϵ是提前终止, ϵ \epsilon ϵ为一个正小数,同时将 x ^ ϵ \hat{\mathrm{x}}_{\epsilon} x^ϵ作为结果.

扩散模型的瓶颈在于采样速度慢, ODE solver 利用得分模型 s ϕ ( x , t ) s_{\phi}(\mathrm{x},t) sϕ(x,t)迭代求解,消耗算力多. 目前存在一些更快的 ODE solver,但是仍然需要大于 10 10 10 步的采样. 也存在一些蒸馏方法,但是大多数方法需要从扩散模型中采集巨大的数据集,同样消耗算力多.

Consistency Models

Definition

根据 PF-ODE 得到一条解路径 { x t } t ∈ [ ϵ , T ] \{\mathrm{x}_t\}_{t\in[\epsilon, T]} {xt}t[ϵ,T],将 consistency function 定义为:

f : ( x t , t ) ↦ x ϵ \large f:(\mathrm{x}_t, t) \mapsto \mathrm{x}_{\epsilon} f:(xt,t)xϵ

对于该路径上的任意点 ( x t , t ) (\mathrm{x}_t, t) (xt,t),其输出是一致的. 对于任意的 t , t ′ ∈ [ ϵ , T ] t, t' \in [\epsilon, T] t,t[ϵ,T],有 f ( x t , t ) = f ( x t ′ , t ′ ) f(\mathrm{x}_t, t) =f(\mathrm{x}_{t'}, t') f(xt,t)=f(xt,t)恒成立.
在这里插入图片描述

Parameterization

F θ ( x , t ) F_{\theta}(\mathrm{x}, t) Fθ(x,t)表示任意形式的神经网络,使用 sikp connection 可以将模型表示为:

f θ ( x , t ) = c s k i p ( t ) x + c o u t ( t ) F θ ( x , t ) \large f_{\theta}(\mathrm{x}, t)=c_{skip}(t)\mathrm{x}+c_{out}(t)F_{\theta}(\mathrm{x},t) fθ(x,t)=cskip(t)x+cout(t)Fθ(x,t)

其中边界条件为 c s k i p ( ϵ ) = 1 c_{skip}(\epsilon)=1 cskip(ϵ)=1 c o u t ( ϵ ) = 0 c_{out}(\epsilon)=0 cout(ϵ)=0.
具体为:

c s k i p ( t ) = σ d a t a 2 ( t − ϵ ) 2 + σ d a t a 2 \large c_{skip}(t)=\frac{\sigma_{data}^2}{(t-\epsilon)^2+\sigma_{data}^2} cskip(t)=(tϵ)2+σdata2σdata2

c o u t ( t ) = σ d a t a ( t − ϵ ) σ d a t a 2 + t 2 \large c_{out}(t)=\frac{\sigma_{data}(t-\epsilon)}{\sqrt{\sigma_{data}^2+t^2}} cout(t)=σdata2+t2 σdata(tϵ)

σ d a t a \sigma_{data} σdata取值 0.5 0.5 0.5.

Sampling

有了一个训练好的 consistency model f θ ( ⋅ , ⋅ ) f_{\theta}(·, ·) fθ(⋅,⋅)之后,从高斯噪声 N ( 0 , T 2 I ) \mathcal{N}(0, T^2I) N(0,T2I)采样 x ^ T \hat{\mathrm{x}}_T x^T,再代入模型一步推出 x ^ ϵ = f θ ( x T ^ , T ) \hat{\mathrm{x}}_{\epsilon}=f_{\theta}(\hat{\mathrm{x}_T}, T) x^ϵ=fθ(xT^,T).为了提高质量,也可以进行多步采样,算法如下:

在这里插入图片描述

Training Consistency Models via Distillation

作者的第一个方法是在预训练的得分模型 s ϕ ( x , t ) s_{\phi}(\mathrm{x},t) sϕ(x,t)上蒸馏.

首先考虑将 ϵ \epsilon ϵ T T T的时间离散化成 N − 1 N-1 N1 个间隔,也即 t 1 = ϵ < t 2 < t 3 < . . . < t N = T t_1=\epsilon<t_2<t_3<...<t_N=T t1=ϵ<t2<t3<...<tN=T. 在实践中,使用如下公式:

t i = ( ϵ 1 / ρ + i − 1 N − 1 ( T 1 / ρ − ϵ 1 / ρ ) ) ρ \large t_i=\left(\epsilon^{1/\rho} + \frac{i-1}{N-1}\left(T^{1/\rho}-\epsilon^{1/\rho}\right) \right)^{\rho} ti=(ϵ1/ρ+N1i1(T1/ρϵ1/ρ))ρ

其中 ρ = 7 \rho=7 ρ=7. 当 N N N充分大时,可以获得 x t n \mathrm{x}_{t_n} xtn x t n + 1 \mathrm{x}_{t_{n+1}} xtn+1的准确估计,于是 x ^ t n ϕ \hat{\mathrm{x}}_{t_n}^{\phi} x^tnϕ可以定义为:

x ^ t n ϕ = x t n + 1 + ( t n − t n + 1 ) Φ ( x t n + 1 , t n + 1 ; ϕ ) \large \hat{\mathrm{x}}_{t_n}^{\phi}=\mathrm{x}_{t_{n+1}} + (t_n-t_{n+1})\Phi(\mathrm{x}_{t_{n+1}}, t_{n+1};\phi) x^tnϕ=xtn+1+(tntn+1)Φ(xtn+1,tn+1;ϕ)

Φ ( . . . ; ϕ ) \Phi(...;\phi) Φ(...;ϕ)为 one-step ODE solver(比如Euler).

从数据集中采样 x \mathrm{x} x,通过 SDE 加噪 N ( x , t n + 1 2 I ) \mathcal{N}(\mathrm{x}, t_{n+1}^2I) N(x,tn+12I)得到 x t n + 1 \mathrm{x}_{t_{n+1}} xtn+1, 然后使用 ODE solver 求解出 x ^ t n ϕ \hat{\mathrm{x}}_{t_n}^{\phi} x^tnϕ,通过最小化在 x ^ t n ϕ \hat{\mathrm{x}}_{t_n}^{\phi} x^tnϕ x t n + 1 \mathrm{x}_{t_{n+1}} xtn+1计算结果的差距训练模型.

Definition 1
consistency distillation loss (CD)表示为:

L C D N ( θ , θ − ; ϕ ) = E [ λ ( t n ) d ( f θ ( x t n + 1 , t n + 1 ) , f θ − ( x ^ t n ϕ , t n ) ] \large \mathcal{L}_{CD}^{N}(\theta, \theta^-;\phi)=\mathbb{E}\left[\lambda(t_n)d(f_{\theta}(\mathrm{x}_{t_{n+1}},t_{n+1}),f_{\theta^-}(\hat{\mathrm{x}}_{t_n}^{\phi}, t_n) \right] LCDN(θ,θ;ϕ)=E[λ(tn)d(fθ(xtn+1,tn+1),fθ(x^tnϕ,tn)]

其中, λ ( ⋅ ) ∈ R + \lambda(·)\in\mathbb{R}^+ λ()R+是正权重函数, θ − \theta^- θ θ \theta θ在优化过程中历史值的均值. d ( ⋅ , ⋅ ) d(·,·) d(⋅,⋅)是一个度量函数,满足当且仅当两个输入相等时为 0 0 0,其余情况大于 0 0 0.

作者考虑 d ( ⋅ , ⋅ ) d(·,·) d(⋅,⋅) 使用 l 1 l_1 l1 以及 l 2 l_2 l2,在实验中 λ ( t n ) ≡ 1 \lambda(t_n) \equiv1 λ(tn)1表现较好. θ − \theta^- θ使用 EMA 更新,计算公式如下:

θ − ← s t o p g a r d ( μ θ − + ( 1 − μ ) θ ) \large \theta^- \leftarrow \mathrm{stopgard}(\mu\theta^-+(1-\mu)\theta) θstopgard(μθ+(1μ)θ)

其中 0 ≤ μ < 1 0\le\mu<1 0μ<1. 使用 EMA 可以使训练更稳定,同时能提高模型的表现.
模型训练算法如下:
在这里插入图片描述

Training Consistency Models in Isolation

consistency model 可以不依赖预训练扩散模型训练,使用如下无偏估计替换 ∇ log ⁡ p t ( x ) \nabla\log{p_t(\mathrm{x})} logpt(x)

∇ log ⁡ p t ( x ) = − E [ x t − x t 2 ∣ x t ] \large \nabla\log{p_t(\mathrm{x})}=-\mathbb{E}\left[\left.\frac{\mathrm{x}_t-\mathrm{x}}{t^2}\right|\mathrm{x}_t \right] logpt(x)=E[t2xtx xt]

consistency training loss (CT)表示为:

L C D N ( θ , θ − ) = E [ λ ( t n ) d ( f θ ( x + t n + 1 z , t n + 1 ) , f θ − ( x + t n z , t n ) ] \large \mathcal{L}_{CD}^{N}(\theta, \theta^-)=\mathbb{E}\left[\lambda(t_n)d(f_{\theta}(\mathrm{x}+t_{n+1}\mathrm{z},t_{n+1}),f_{\theta^-}(\mathrm{x}+t_{n}\mathrm{z},t_{n}) \right] LCDN(θ,θ)=E[λ(tn)d(fθ(x+tn+1z,tn+1),fθ(x+tnz,tn)]

其中 z ∼ N ( 0 , I ) \mathrm{z}\sim\mathcal{N}(0,I) zN(0,I). 损失函数的计算依赖于 f θ f_{\theta} fθ f θ − f_{\theta^-} fθ,且与扩散模型的无关.

为了提升模型效果,使用 schedule function N ( ⋅ ) N(·) N()控制 N N N 增长. 直觉上,当 N N N 小的时候,使用 consistency distillation loss 模型在一开始收敛更快,同时方差小、偏差大. 反之,在训练结束时,应当使 N N N 大,这样方差大、偏差小。同时,使用 schedule function μ ( ⋅ ) \mu(·) μ()替换 μ \mu μ,让它随着 N N N 增长而变化.
N ( ⋅ ) N(·) N() μ ( ⋅ ) \mu(·) μ()具体为

N ( k ) = ⌈ k K ( ( s 1 + 1 ) 2 − s 0 2 ) + s 0 2 − 1 ⌉ + 1 \large N(k)= \left\lceil\sqrt{\frac{k}{K}((s_1+1)^2-s_0^2)+s_0^2}-1 \right\rceil+1 N(k)= Kk((s1+1)2s02)+s02 1 +1

μ ( k ) = exp ⁡ ( s 0 log ⁡ μ 0 N ( k ) ) \large \mu(k)=\exp\left(\frac{s_0\log{\mu_0}}{N(k)}\right) μ(k)=exp(N(k)s0logμ0)

K K K表示整体训练步数, s 0 s_0 s0表示开始的离散化步数.

训练算法如下:
在这里插入图片描述

Experiment

关于 CD ,作者分别使用 l 1 l_1 l1, l 2 l_2 l2, L P I P S \mathrm{LPIPS} LPIPS作为度量函数,使用一阶Euler和二阶Heun座位 ODE solver, N N N { 9 , 12 , 18 , 36 , 50 , 60 , 80 , 120 } \{9,12,18,36,50,60,80,120\} {9,12,18,36,50,60,80,120},使用相应的预训练扩散模型做初始化. 使用 CT 训练的模型则随机初始化.
在这里插入图片描述

(a) 对比不同的度量函数在 CD 上的表现,其中 LPIPS 的效果最好.
(b, c) 对不不同 ODE solver 和 N N NCD 上的表现,使用 Heun 且 N N N 18 18 18时效果最好.在取相同的 N N N时,二阶Heun的表现优于一阶Euler,因为高阶的 ODE solver 的估计误差更小. 当 N N N充分大时,模型对 N N N变得不敏感.
(d) 根据之前的结论,关于 CT 的实验使用 LPIPS 作为度量函数. 更小的 N N N收敛更快,但是采样结构更差;使用自适应的 N ( ⋅ ) N(·) N() μ ( ⋅ ) \mu(·) μ()效果最好.

对比 CDprogressive disillation(PD) 在不同数据集上的效果,CD 的表现普遍比 PD 好.
在这里插入图片描述

对比 CT 和其它生成模型,仅使用一步或两步生成.
在这里插入图片描述

Zero-Shot Image Editing

  • 24
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
Consistency models are used in distributed computing systems to define the level of consistency that is maintained across different copies of the same data. These models determine how updates to the data are propagated to other copies and how conflicts are resolved when multiple copies are updated simultaneously. There are several consistency models, including: 1. Strong consistency: In this model, all copies of the data are updated synchronously and all updates are visible to all nodes at the same time. This model guarantees that all nodes see the same version of the data at the same time. 2. Weak consistency: In this model, updates are not propagated synchronously and different nodes may have different views of the data at any given time. This model allows for faster updates but may result in temporary inconsistencies. 3. Eventual consistency: In this model, updates are propagated asynchronously and nodes eventually converge to a consistent view of the data. This model allows for high availability and scalability but may result in temporary inconsistencies. 4. Causal consistency: In this model, updates are propagated in a causally consistent manner, meaning that updates that are causally related are propagated in the same order to all nodes. This model provides a compromise between strong and eventual consistency. 5. Read-your-writes consistency: In this model, a node always reads its own writes. This model guarantees that a node will always see its own writes, but may not see the writes of other nodes immediately. Each consistency model has its own trade-offs between performance, availability, and consistency. The choice of consistency model depends on the specific requirements of the application and the underlying distributed system.

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值