前言
上一篇博文我们详细讲解了扩散模型的基石DDPM ,并且给出了代码讲解,有不了解的小伙伴可以跳转到前面先学习一下(DDPM)。今天我们再来介绍下DDPM的改进版本。DDPM虽然对生成任务带来了新得启发,但是他有一个致命的缺点,就是推理速度比较慢,这就导致实际工业应用中很难落地。为了解决这个问题,斯坦福大学提出了他的改进版本,实际上也不算是新的模型,只是一个新的采样方式,加快了模型的推理速度,就是我们今天要介绍的Denoising Diffusion Implicit Models(DDIM),目前流行的一些采样模型,如stable diffusion,midjourney等都是基于DDIM的采样方式。下面给出DDIM的论文和项目地址:
- 论文:Denoising Diffusion Implicit Models
- 代码(pytorch):https://github.com/ermongroup/ddim
- 视频讲解:一个视频看懂DDIM凭什么加速采样|扩散模型相关
前面讲的DDPM是一个马尔科夫的过程,而DDIM是通过去马尔科夫化,但是又和DDPM有一个同样的训练目标,也就是说如果你已经有一个训练好的DDPM模型了,你就可以通过DDIM的这种采样方式来加速已经训练好的DDPM的采样,所以说DDIM更像是提出了一个新的采样方式,因为他的训练过程跟DDPM是一样的,只不过DDIM改进了采样过程。最后作者说DDIM速度会比DDPM要快10到10到50倍,但是DDIM的这种采样方式也有一些缺点,虽然加快了模型的推理速度,但是由于DDIM的整个生成过程是确定性的(deterministic)过程,这就导致了DDIM的多样性相比于DDPM比较差。
问题
DDPM采样慢主要有两个原因:
- T太大了
- 必须逐步采样
~先讲原因:
DDPM既然采样慢是因为T太大了,能不能把T调小?
一般情况下T=1000,在下面的公式中,为什么我们希望 a t {a_t} at接近1且小于1? x t = 1 − a t × E t + a t × x t − 1 x_t = \sqrt{ 1-a_t} × \mathcal{E}_t + \sqrt{a_t} × x_{t-1} xt=1−at×Et+at×xt−1
原因:
(1) 如果 a t = 0 {a_t}=0 at=0,那么 x t x_t xt直接就变成噪声了,我们还是希望上一时刻有所保留,所以不能让 a t {a_t} at太小
(2) 如果 a t = 1 {a_t}=1 at=1,此时 x 0 x_0 x0一直保存
(3) 因为 a ‾ t = a t a t − 1 a t − 2 ⋅ ⋅ ⋅ a 2 a 1 \overline{a}_t = a_t a_{t-1}a_{t-2} ···a_{2}a_{1} at=atat−1at−2⋅⋅⋅a2a1,所以想要 a ‾ t \overline{a}_t at趋近于0,T 要取很大采样慢的第二个原因是必须逐步采样,能不能直接跳着采样? P ( x t − 1 ∣ x t , x 0 ) = P ( x t ∣ x t − 1 , x 0 ) P ( x t − 1 ∣ x 0 ) / P ( x t ∣ x 0 ) P(x_{t-1} | x_{t},x_{0}) = P(x_{t} | x_{t-1},x_{0}) P(x_{t-1} | x_{0}) / P(x_{t} | x_{0}) P(xt−1∣xt,x0)=P(xt∣xt−1,x0)P(xt−1∣x0)/P(xt∣x0)在上面的式子中,我们可以利用马尔科夫性质去掉一个 x 0 x_0 x0,即 P ( x t − 1 ∣ x t , x 0 ) = P ( x t ∣ x t − 1 ) P ( x t − 1 ∣ x 0 ) / P ( x t ∣ x 0 ) P(x_{t-1} | x_{t},x_{0}) = P(x_{t} | x_{t-1}) P(x_{t-1} | x_{0}) / P(x_{t} | x_{0}) P(xt−1∣xt,x0)=P(xt∣xt−1)P(xt−1∣x0)/P(xt∣x0)这也表明 x t x_{t} xt时刻只和 x t − 1 x_{t-1} xt−1时刻有关,和之前没有任何关系,而且上面式子中右侧的3项都是通过马尔科夫性质计算的,因为用到了马尔科夫,所以采样时就必须一步一步推导,否则就步满足马尔科夫性质了。
一、DDIM主要的工作是什么?
如下图所示,
DDIM论文提出了一种使过程非马尔可夫的方法,允许跳过去噪过程中的步骤,而不需要在当前状态之前访问所有过去的状态。DDIM最好的部分是,它们可以在训练模型后应用,因此DDPM模型可以很容易地转换为DDIM,而无需重新训练新模型。
1. 去马尔可夫化
之前我们推导DDPM的时候都是基于马尔科夫的性质,即整个采样是按照下面的公式进行的:
P
(
x
t
−
1
∣
x
t
,
x
0
)
~
M
a
r
k
o
v
(1)
P(x_{t-1} | x_{t},x_{0})~Markov\tag{1}
P(xt−1∣xt,x0)~Markov(1)那么我们能不能找到一个非马尔科夫的采样公式,像下面这样的:
P
(
x
s
∣
x
k
,
x
0
)
~
N
o
n
−
M
a
r
k
o
v
,
s
<
k
−
1
(2)
P(x_{s} | x_{k},x_{0})~Non-Markov ,s<k−1\tag{2}
P(xs∣xk,x0)~Non−Markov,s<k−1(2)根据贝叶斯公式,我们可以得出下面的公式:
P
(
x
s
∣
x
k
,
x
0
)
=
P
(
x
k
∣
x
s
,
x
0
)
P
(
x
s
∣
x
0
)
P
(
x
k
∣
x
0
)
(3)
{P(x_{s} | x_{k},x_{0})} = \frac{P(x_{k} | x_{s},x_{0}) P(x_{s} | x_{0})}{P(x_{k} |x_{0})}\tag{3}
P(xs∣xk,x0)=P(xk∣x0)P(xk∣xs,x0)P(xs∣x0)(3)
- 在DDPM中,根据马尔科夫性质,上面式子中的 P ( x k ∣ x s , x 0 ) , P ( x s ∣ x 0 ) , P ( x k ∣ x 0 ) P(x_{k} | x_{s},x_{0}),P(x_{s} | x_{0}),P(x_{k} |x_{0}) P(xk∣xs,x0),P(xs∣x0),P(xk∣x0)都是已知的。但是,如果此时采样不满足马尔科夫性质了,上面三个概率分布不就都不知道了吗,这时候该怎么办?
- 回想一下DDPM的训练过程,我们通过 x 0 x_0 x0可以一步生成 x t x_t xt,即满足 p ( x t ∣ x 0 ) p(x_t|x_0) p(xt∣x0),也就是说上面两个式子只有一个是未知的,即 P ( x s ∣ x k , x 0 ) P(x_{s} | x_{k},x_{0}) P(xs∣xk,x0),这也就回答了我们为什么在前言里面说DDIM实际上是一个采样方法,他的训练也是DDPM。
- 那么还剩下的一项怎么解决?这一项在DDPM 中对应的是 P ( x t − 1 ∣ x t , x 0 ) P(x_{t-1} | x_{t},x_{0}) P(xt−1∣xt,x0),这一项在训练中根本没用到,训练的时候直接根据 p ( x t ∣ x 0 ) p(x_t|x_0) p(xt∣x0)来加噪,既然这一项没用到,说明他不重要,我们就先不管,那接下来应该怎么办?
- 因为 P ( x k ∣ x s , x 0 ) P(x_{k} | x_{s},x_{0}) P(xk∣xs,x0)这一项是未知的,所以我们可以先设一下 P ( x s ∣ x k , x 0 ) P(x_{s} | x_{k},x_{0}) P(xs∣xk,x0)的形式,然后再通过待定系数法来求得它的表达式
- 假设
P
(
x
s
∣
x
k
,
x
0
)
P(x_{s} | x_{k},x_{0})
P(xs∣xk,x0)为正态分布:
P ( x s ∣ x k , x 0 ) ~ N ( k x 0 + m x k , σ 2 I ) (4) P(x_{s} | x_{k},x_{0})~N(kx_0+mx_k, \sigma^2I)\tag{4} P(xs∣xk,x0)~N(kx0+mxk,σ2I)(4) - 上面的分布有三个未知数,我们现在要做的就是求出他们,根据参数化技术,采样一个
x
s
x_s
xs
x s = ( k x 0 + m x k ) + σ E (5) x_s = (kx_0+mx_k) + \sigma\mathcal{E}\tag{5} xs=(kx0+mxk)+σE(5) - 因为这项是满足
x
t
=
1
−
a
‾
t
×
E
+
a
‾
t
x
0
x_t =\sqrt{1 - \overline{a}_t} × \mathcal{E} + \sqrt{\overline{a}_t}x_{0}
xt=1−at×E+atx0的,所以可以用
x
0
x_0
x0表示
x
k
x_k
xk
x s = k x 0 + m ( 1 − a ‾ k × E ′ + a ‾ k x 0 ) + σ E (6) x_s = kx_0+m(\sqrt{1 - \overline{a}_k} × \mathcal{E}' + \sqrt{\overline{a}_k}x_{0}) + \sigma\mathcal{E}\tag{6} xs=kx0+m(1−ak×E′+akx0)+σE(6) - 合并同类项,有
x s = ( k + m a ‾ k x 0 ) x 0 + ( m 1 − a ‾ k × E ′ + σ E ) (7) x_s = (k + m\sqrt{\overline{a}_k}x_{0})x_0+(m\sqrt{1 - \overline{a}_k} × \mathcal{E}' + \sigma\mathcal{E}) \tag{7} xs=(k+makx0)x0+(m1−ak×E′+σE)(7) - 其中后面括号中的两项都满足正态分布,分别是
m
1
−
a
‾
k
×
E
′
∽
N
(
0
,
m
2
(
1
−
a
‾
k
)
)
m\sqrt{1 - \overline{a}_k} × \mathcal{E}'∽N(0,m^2({1 - \overline{a}_k}))
m1−ak×E′∽N(0,m2(1−ak))和
σ
E
∽
N
(
0
,
σ
2
)
\sigma\mathcal{E}∽N(0,\sigma^2)
σE∽N(0,σ2)由于正态分布的可加性,可以得到后面括号中的这项也是符合正态分布的即
N
(
0
,
m
2
(
1
−
a
‾
k
)
+
σ
2
)
N(0,m^2({1 - \overline{a}_k})+\sigma^2)
N(0,m2(1−ak)+σ2),所以上面的式子可以继续改写为
x s = ( k + m a ‾ k x 0 ) x 0 + m 2 ( 1 − a ‾ k ) + σ 2 E (8) x_s = (k + m\sqrt{\overline{a}_k}x_{0})x_0+\sqrt{m^2({1 - \overline{a}_k})+\sigma^2}\mathcal{E}\tag{8} xs=(k+makx0)x0+m2(1−ak)+σ2E(8) - 接着我们需要求
k
,
m
,
σ
k,m,\sigma
k,m,σ,因为
x
s
x_s
xs必须满足
x
s
=
a
‾
s
x
0
+
1
−
a
‾
s
×
E
x_s =\sqrt{\overline{a}_s}x_{0} + \sqrt{1 - \overline{a}_s} × \mathcal{E}
xs=asx0+1−as×E,此时我们可以看到如果令
k + m a ‾ k = a ‾ s m 2 ( 1 − a ‾ k ) + σ 2 = 1 − a ‾ s (9) k + m\sqrt{\overline{a}_k} = \sqrt{\overline{a}_s}\quad\quad m^2({1 - \overline{a}_k})+\sigma^2 = 1 - \overline{a}_s\tag{9} k+mak=asm2(1−ak)+σ2=1−as(9) - 此时有3个未知量,2个等式,就是说有一个一定是自由变量,我们令
σ
\sigma
σ为自由变量,先把m,k求出来,因为m,k都在公式(4)的均值里,我们先再均值求出来,经过计算,得到
m = 1 − a ‾ s − σ 2 1 − a ‾ k k = a ‾ s − ( 1 − a ‾ s − σ 2 ) 1 − a ‾ k a ‾ k (10) m = \frac{\sqrt{1 - \overline{a}_s-\sigma^2}}{\sqrt{1 - \overline{a}_k}}\quad\quad k= \sqrt{\overline{a}_s} - \frac{({\sqrt{1- \overline{a}_s-\sigma^2}})}{\sqrt{1 - \overline{a}_k}}{\sqrt{\overline{a}_k}}\tag{10} m=1−ak1−as−σ2k=as−1−ak(1−as−σ2)ak(10) - 因为均值为
μ
=
k
x
0
+
m
x
k
\mu = kx_0+mx_k
μ=kx0+mxk,所以将
m
,
k
m,k
m,k带入得到新的正态分布
P ( x s ∣ x k , x 0 ) ~ N ( a ‾ s x 0 + 1 − a ‾ s − σ 2 1 − a ‾ k ( x k − a ‾ k x 0 ) , σ 2 I ) (11) P(x_{s} | x_{k},x_{0})~N(\sqrt{\overline{a}_s}x_0 + \frac{\sqrt{1 - \overline{a}_s-\sigma^2}}{\sqrt{1 - \overline{a}_k}}(x_k-\sqrt{\overline{a}_k}x_0), \sigma^2I)\tag{11} P(xs∣xk,x0)~N(asx0+1−ak1−as−σ2(xk−akx0),σ2I)(11) - 这就是得到的新的反向生成分布,也就是要去拟合的 “终极目标”。
- 到这里,不知道给位小伙伴们有没有什么疑问,我们假设的DDIM方法,实际上对应的前向加噪过程已经变了,为什么还能用呢?核心就是因为模型在训练的时候没有用到每一步的训练方式,直接一步到位(直接使用 x 0 → x t x_0→x_t x0→xt的公式,没有使用马尔科夫约束推导了),也就是说即使之前的等式 p ( x k ∣ x s , x 0 ) p(x_k|x_s,x_0) p(xk∣xs,x0)已经发生了变化,但是 p ( x t ∣ x 0 ) p(x_t|x_0) p(xt∣x0)没有变化,模型仍然是能用的,即可以使用DDIM这种方式加速。
2. 标准差的选取
有了上面的概率分布,我们就可以采样了:
x
s
=
a
‾
s
x
0
+
1
−
a
‾
s
−
σ
2
1
−
a
‾
k
(
x
k
−
a
‾
k
x
0
)
+
σ
E
(12)
x_s = \sqrt{\overline{a}_s}x_0 + \frac{\sqrt{1 - \overline{a}_s-\sigma^2}}{\sqrt{1 - \overline{a}_k}}(x_k-\sqrt{\overline{a}_k}x_0) + \sigma\mathcal{E}\tag{12}
xs=asx0+1−ak1−as−σ2(xk−akx0)+σE(12)此时唯一一个还未知的量是
σ
\sigma
σ,因此这个结果对应于一组解,通过规定不同的方差,可以得到不同的采样过程,论文中提出:
- σ = 0 \sigma = 0 σ=0,这时候就变成了确定性的过程了,就变成了固定的生成过程了。
- σ = η 1 − a ‾ t − 1 1 − a ‾ t β t , η ∈ [ 0 , 1 ] \sigma = η\sqrt{\frac{1 - \overline{a}_{t-1}}{1 - \overline{a}_t}{\beta_t}},η∈[0,1] σ=η1−at1−at−1βt,η∈[0,1],如果令 σ = 1 − a ‾ t − 1 1 − a ‾ t β t \sigma = \sqrt{\frac{1 - \overline{a}_{t-1}}{1 - \overline{a}_t}{\beta_t}} σ=1−at1−at−1βt,即 η = 1 η=1 η=1,这个过程就变成了马尔科夫的过程了,有兴趣的可以自己证明一下(视频讲解在此视频34分钟),即证明上面的采样公式的均值和DDPM的均值一样即可。当 η = 0 \eta=0 η=0时,此时生成过程不再添加随机噪声项,唯一带有随机性的因素就是采样初始的 x T ∽ N ( 0 , 1 ) x_T ∽ N(0,1) xT∽N(0,1) ,因此采样的过程是确定的,每个 x T x_T xT对应唯一的 x 0 x_0 x0,就是DDIM,所以说DDIM仅仅是一个采样方式
3. 新参数η
η是文中提出的一个新参数,这里我们主要看两个情况,可以看到随着η的减小,效果是在提升的,评价指标叫做FID,这是一个生成模型的常用指标,指标值越小,说明生成效果越好
- η = 1:DDPM(此时不跳步的话是DDPM,如果跳步了,其实也是DDIM的一种),此时如果跳步很大的话,效果是非常差的
- η = 0:DDIM 跳步很大和很小之间的差距是比较小的
即使都是1,000步,DDIM依然要优于DDPM,因为DDIM牺牲了一定的多样性来提高了图片的质量,例如我们已经知道了一个最优的方向,我们就往这个最优的方向去,大概率它的最终的达到的结果会很好,但因为我们一直奔着那一个方向去,所以说肯定会错过一些事件。
作者还提到了一点:50步的时候生成的图像和100步时生成的图像已经相似性非常高了,这说明它其实是有一个路径的,就是从t=T到t=0时,当我们选出这个 x T x_T xT的时候,这张图是什么,在DDIM中其实就已经定死了。
这里面值得关注的是,由于当 η = 0 \eta=0 η=0时,每个 x T x_T xT对应唯一的 x 0 x_0 x0,这有点类似GAN和VAE,那我们可以认为此时的 x T x_T xT就是一个 h i g h − l e v e l high-level high−level的图像编码向量,里面可能蕴涵了大量的信息特征,也许可以用于其他下游任务
4. 采样加速
我们知道 DDIM 的反向过程并不依赖于马尔可夫假设,因此去噪的过程并不需要在相邻的时间步之间进行,也就是跳过一些中间的步骤。形式化地来说,DDPM 的采样时间步应当是 [ T , T − 1 , . . . , 2 , 1 ] [ T , T − 1 , . . . , 2 , 1 ] [T,T−1,...,2,1] ,而 DDIM 可以直接从其中抽取一个子序列 [ τ S , τ S − 1 , . . . , τ 2 , τ 1 ] [\tau_S,\tau_{S-1},...,\tau_2,\tau_1] [τS,τS−1,...,τ2,τ1] 进行采样。
在 DDIM 论文的附录中,给出了两种子序列的选取方式:
- 线性选取:令 T i = [ c i ] T_i = [ci] Ti=[ci]
- 二次方选取:令 T i = [ c i 2 ] T_i = [ci^2] Ti=[ci2]
其中 c 是一个常量,制定这个常量的规则是让 T − 1 T −1 T−1 也就是最后一个采样时间步尽可能与 T T T 接近。在原文的实验中,CIFAR10 使用的是二次方选取,其他数据集都使用的是线性选取方式。
5. DDIM 区别于 DDPM 的两个特性
-
采样一致性:我们知道 DDIM 的采样过程是确定的,它不会引入额外随机性,生成结果只受 x T \mathbf{x}_T xT影响。作者经过实验发现对于同一个 x T \mathbf{x}_T xT使用不同的采样过程,最终生成的 x 0 \mathbf{x}_0 x0比较相近,这表明它能够表征最终图像的核心特征。因此 x T \mathbf{x}_T xT实际上相当于潜在空间中的一个“嵌入点”,对生成图像的语义内容进行了独特表征。
由于扩散过程通过多步噪声添加来达到最终的 𝑥 𝑇 𝑥_𝑇 xT,这种逐步增加噪声的过程会逐步“去掉”图像的细节,但保留其底层的语义信息(如对象形状、姿态等)。这种特性使得 𝑥 𝑇 𝑥_𝑇 xT对图像的结构性信息具有较高保真度。因此,在反向扩散过程中,无论采用不同的步数或调度方式,从 𝑥 𝑇 𝑥_𝑇 xT出发生成的图像 𝑥 0 𝑥_0 x0总会保持高度的语义一致性。这种语义一致性使得 𝑥 𝑇 𝑥_𝑇 xT能够很好地作为一种语义“嵌入”或编码,反映了最终生成的图像 𝑥 0 𝑥_0 x0的特征。
将 𝑥 𝑇 𝑥_𝑇 xT视为图像的语义嵌入还为控制生成过程提供了理论支持。例如,通过对 𝑥 𝑇 𝑥_𝑇 xT的微小扰动或插值,生成的 𝑥 0 𝑥_0 x0会在语义上表现出相似的内容,而并非完全不同的图像。这种特性使得 𝑥 𝑇 𝑥_𝑇 xT成为一种有潜力的控制变量,可在潜在空间中通过修改 𝑥 𝑇 𝑥_𝑇 xT实现图像内容上的连续变化。
-
语义插值效应:根据上一条性质, 𝑥 𝑇 𝑥_𝑇 xT可以看作 𝑥 0 𝑥_0 x0 的嵌入,那么它可能也具有其他隐概率模型所具有的语义差值效应(即在不同生成状态间产生一种平滑的过渡效果。这种效果意味着在从一个图像生成另一个图像的过程中,逐步改变图像内容的同时,保留其语义特征。)。作者首先选取两个隐变量 𝑥 𝑇 ( 0 ) 𝑥_𝑇^{(0)} xT(0)和 𝑥 𝑇 ( 1 ) 𝑥_𝑇^{(1)} xT(1) ,对其分别采样得到结果,然后使用球面线性插值(是一种在两个向量之间进行平滑插值的方法,通常用于在球面空间中的平滑插值。)得到一系列中间隐变量,这个插值定义为:
𝑥 𝑇 ( α ) = s i n ( 1 − α ) θ s i n θ 𝑥 𝑇 ( 0 ) + s i n α θ s i n θ 𝑥 𝑇 ( 1 ) 𝑥_𝑇^{(α)} = \frac{sin(1-α)θ}{sinθ}{𝑥_𝑇^{(0)}} + \frac{sinαθ}{sinθ}{𝑥_𝑇^{(1)}} xT(α)=sinθsin(1−α)θxT(0)+sinθsinαθxT(1)
其中, θ = a r c c o s ( ( 𝑥 𝑇 ( 0 ) ) T 𝑥 𝑇 ( 1 ) ∣ ∣ 𝑥 𝑇 ( 0 ) ∣ ∣ ∣ ∣ 𝑥 𝑇 ( 1 ) ∣ ∣ ) θ = arccos(\frac{(𝑥_𝑇^{(0)})^T 𝑥_𝑇^{(1)}}{||𝑥_𝑇^{(0)}||~||𝑥_𝑇^{(1)}||}) θ=arccos(∣∣xT(0)∣∣ ∣∣xT(1)∣∣(xT(0))TxT(1))。最终也在 DDIM 上观察到了语义插值效应,我们下面也将复现这一实验。
语义插值效应的特点:
- 平滑的语义变化:当输入的潜在向量(通常是噪声向量)在两个状态之间线性插值时,DDIM生成的输出会产生平滑的语义过渡。例如,如果起始图像是“白天的城市景观”,目标图像是“夜晚的城市景观”,插值过程可以让生成的图像逐渐从白天变为夜晚,而不引入额外的噪声。
- 保持内容一致性:与随机生成相比,DDIM插值更关注图像的结构性和语义内容的连贯性。插值生成的每一步图像都较清晰且在语义上合理,如同将原始图像分阶段逐渐转换为目标图像。
- 可控性:通过调节插值的比例,可以生成不同程度的“混合图像”,从而实现对图像内容的细微控制。例如,可以通过插值系数生成出50%属于白天场景、50%属于夜晚场景的图像。
~如何实现语义插值:
- 定义初始和目标潜在向量:使用 DDIM 在图像生成过程中,定义两个不同的潜在向量(噪声向量)对应于起始和目标图像。
- 线性插值:通过逐步插值,将初始向量逐渐转变为目标向量。
- 使用调度器生成插值图像:在生成过程中,DDIM调度器能够逐步减少噪声,生成出逐渐变化的图像。每一步生成的图像在语义上接近于前一个图像,并逐步向目标图像过渡。
二、DDIM 的代码实现
从上面的推导过程可以发现,DDIM 假设的前向过程和 DDPM 相同,只有采样过程不同。因此想把 DDPM 改成 DDIM 并不需要重新训练,只要修改采样过程就可以了。在上一篇文章中我们已经训练好了一个 DDPM 模型(训练的蝴蝶数据集,效果不是很好),这里我们继续用这个训练好的模型来构造 DDIM 的采样过程。
我们把训练好的 DDPM 模型的权重加载进来用作噪声预测网络:
from diffusers import UNet2DModel
model = UNet2DModel.from_pretrained('ddpm-butterflies-128').cuda()
1. 核心代码
首先我们依然是定义一系列常量, α 、 β \alpha、\beta α、β 等都和 DDPM 相同,只有采样的时间步不同。我们在这里直接线性选取 20 个时间步,最大的为 999,最小的为 0:
import torch
class DDIM:
def __init__(
self,
num_train_timesteps:int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
sample_steps: int = 20,
):
self.num_train_timesteps = num_train_timesteps
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.timesteps = torch.linspace(num_train_timesteps - 1, 0, sample_steps).long()
然后是实现采样过程,和 DDPM 一样,我们把需要的公式复制到这里,然后对照着实现:
x
s
=
a
‾
s
x
0
+
1
−
a
‾
s
−
σ
2
1
−
a
‾
k
(
x
k
−
a
‾
k
x
0
)
+
σ
E
(12)
x_s = \sqrt{\overline{a}_s}x_0 + \frac{\sqrt{1 - \overline{a}_s-\sigma^2}}{\sqrt{1 - \overline{a}_k}}(x_k-\sqrt{\overline{a}_k}x_0) + \sigma\mathcal{E}\tag{12}
xs=asx0+1−ak1−as−σ2(xk−akx0)+σE(12)
σ
=
η
1
−
a
‾
t
−
1
1
−
a
‾
t
β
t
\sigma = η\sqrt{\frac{1 - \overline{a}_{t-1}}{1 - \overline{a}_t}{\beta_t}}
σ=η1−at1−at−1βt
import math
from tqdm import tqdm
class DDIM:
...
@torch.no_grad()
def sample(
self,
unet: UNet2DModel,
batch_size: int,
in_channels: int,
sample_size: int,
eta: float = 0.0,
):
alphas = self.alphas.to(unet.device)
alphas_cumprod = self.alphas_cumprod.to(unet.device)
timesteps = self.timesteps.to(unet.device)
images = torch.randn((batch_size, in_channels, sample_size, sample_size), device=unet.device)
for t, tau in tqdm(list(zip(timesteps[:-1], timesteps[1:])), desc='Sampling'):
pred_noise: torch.Tensor = unet(images, t).sample
# sigma_t
if not math.isclose(eta, 0.0):
one_minus_alpha_prod_tau = 1.0 - alphas_cumprod[tau]
one_minus_alpha_prod_t = 1.0 - alphas_cumprod[t]
one_minus_alpha_t = 1.0 - alphas[t]
sigma_t = eta * (one_minus_alpha_prod_tau * one_minus_alpha_t / one_minus_alpha_prod_t) ** 0.5
else:
sigma_t = torch.zeros_like(alphas[0])
# first term of x_tau
alphas_cumprod_tau = alphas_cumprod[tau]
sqrt_alphas_cumprod_tau = alphas_cumprod_tau ** 0.5
alphas_cumprod_t = alphas_cumprod[t]
sqrt_alphas_cumprod_t = alphas_cumprod_t ** 0.5
sqrt_one_minus_alphas_cumprod_t = (1.0 - alphas_cumprod_t) ** 0.5
first_term = sqrt_alphas_cumprod_tau * (images - sqrt_one_minus_alphas_cumprod_t * pred_noise) / sqrt_alphas_cumprod_t
# second term of x_tau
coeff = (1.0 - alphas_cumprod_tau - sigma_t ** 2) ** 0.5
second_term = coeff * pred_noise
epsilon = torch.randn_like(images)
images = first_term + second_term + sigma_t * epsilon
images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
return images
最后执行采样过程:
ddim = DDIM()
images = ddim.sample(model, 32, 3, 64)
from diffusers.utils import make_image_grid, numpy_to_pil
image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=8)
image_grid.save('ddim-sample-results.png')
2. 结果展示
在20个时间步下,效果如下,完全达不到DDPM上的效果(虽然DDPM上的效果也不太好),应该是数据比较少和训练次数少的问题,得到的结果如下图所示:
脸部模型效果
3. 语义插值效应复现
语义插值效应也比较简单,只需要修改初始化的
x
T
\mathbf{x}_T
xT即可。根据上文的叙述,我们首先实现球面线性插值:
import torch
def slerp(
x0: torch.Tensor,
x1: torch.Tensor,
alpha: float,
):
theta = torch.acos(torch.sum(x0 * x1) / (torch.norm(x0) * torch.norm(x1)))
w0 = torch.sin((1.0 - alpha) * theta) / torch.sin(theta)
w1 = torch.sin(alpha * theta) / torch.sin(theta)
return w0 * x0 + w1 * x1
我们这次要实现的和原论文不同,原论文的插值只在一行内部,我们希望实现一个二维的插值,也就是在一个图片网格中,从左上角到右下角存在一个渐变效果。为此,我们需要先构建一个二维的图片网格,然后按以下的步骤完成二维插值:
- 初始化网格四角的 x T ∼ N ( 0 , 1 ) \mathbf{x}_T\sim\mathcal{N}(0,1) xT∼N(0,1)
- 在网格的最左侧和最右侧两列中进行插值,例如最左侧的一列由左上角与左下角两个样本插值得到、最右侧的一列由右上角与右下角的两个样本插值得到;
- 遍历所有行,把每行中间的元素用该行最左侧与最右侧的元素进行插值,完成全部 x T \mathbf{x}_T xT的初始化。
具体的直接看代码就好:
## 生成一个二维图像网格并用slerp插值来构造网格边界和内部的噪声。
## 生成角点噪声:随机生成网格的四个角点噪声,保证生成图像在网格上的连续变化。
## 插值生成边缘噪声:根据四个角点,通过球形插值生成网格的边界列和边界行噪声。
## 插值内部噪声:根据边界噪声插值生成内部的噪声图像。
def interpolation_grid(
rows: int,
cols: int,
in_channels: int,
sample_size: int,
):
images = torch.zeros((rows * cols, in_channels, sample_size, sample_size), dtype=torch.float32)
images[0, ...] = torch.randn_like(images[0, ...]) # top left
images[cols - 1, ...] = torch.randn_like(images[0, ...]) # top right
images[(rows - 1) * cols, ...] = torch.randn_like(images[0, ...]) # bottom left
images[-1] = torch.randn_like(images[0, ...]) # bottom right
for row in range(1, rows - 1): # interpolate left most column and right most column
alpha = row / (rows - 1)
images[row * cols, ...] = slerp(images[0, ...], images[(rows - 1) * cols, ...], alpha)
images[(row + 1) * cols - 1, ...] = slerp(images[cols - 1, ...], images[-1, ...], alpha)
for col in range(1, cols - 1): # interpolate others
alpha = col / (cols - 1)
images[col::cols, ...] = slerp(images[0::cols, ...], images[cols - 1::cols, ...], alpha)
return images
最后把 images 的初始化从 torch.randn 改成调用 interpolation_grid:
images = interpolation_grid(rows, cols, in_channels, sample_size).to(unet.device)
结果:
参考:【笔记】扩散模型(二):DDIM 理论推导与代码实现
完整代码:https://github.com/LittleNyima/code-snippets/tree/master/ddim-tutorial