扩散模型主要分为前向过程与后向过程,其中前向过程主要是通过不断给数据集中的图像加噪声,并记录下整个过程,后向过程是一个逐渐去噪的过程,这里的关键就是预测一下每一步减去的噪声,从而获取一步去噪后的分布情况。
DDPM
一、前向过程
- 首先令:
α
t
=
1
−
β
t
\alpha_t=1-\beta_t
αt=1−βt
其中 β t \beta_t βt随着t的增加越来越大,论文中是从0.0001到0.002。 - 令 x t = a t x t − 1 + 1 − α t z t ( 1 ) x_t=\sqrt{a_t}x_{t-1}+\sqrt{1-\alpha_t}z_t (1) xt=atxt−1+1−αtzt(1) 此公式的意思是,t时刻的图像是有前一时刻图像加上一个从正态分布中采样的噪声加权相加的过程。其中 z ∼ N ( 0 , 1 ) z\sim\mathcal{N}(0,\mathbf{1}) z∼N(0,1)
- 由式(1)递推,可得到:
x
t
=
a
t
(
a
t
−
1
x
t
−
2
+
1
−
α
t
−
1
z
t
−
1
)
+
1
−
α
t
z
t
x_t=\sqrt{a_t}\big(\sqrt{a_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}z_{t-1}\big)+\sqrt{1-\alpha_t}z_t
xt=at(at−1xt−2+1−αt−1zt−1)+1−αtzt
=
a
t
a
t
−
1
x
t
−
2
+
(
a
t
(
1
−
α
t
−
1
)
z
t
−
1
+
1
−
α
t
z
t
)
=\sqrt{a_ta_{t-1}}x_{t-2}+(\sqrt{a_t(1-\alpha_{t-1})}z_{t-1}+\sqrt{1-\alpha_t}z_t)
=atat−1xt−2+(at(1−αt−1)zt−1+1−αtzt) 此时根据高斯分布的性质,可得:
=
a
t
a
t
−
1
x
t
−
2
+
1
−
α
t
α
t
−
1
z
‾
t
=\sqrt{a_ta_{t-1}}x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\overline{z}_{t}
=atat−1xt−2+1−αtαt−1zt 其中
z
‾
t
∼
N
(
0
,
1
)
\overline{z}_{t}\sim\mathcal{N}(0,\mathbf{1})
zt∼N(0,1)
继续迭代,可得最终前向公式: x t = α ‾ t x 0 + 1 − α ‾ t z ‾ t x_t=\sqrt{\overline{\alpha}_t}x_0+\sqrt{1-\overline{\alpha}_t}\overline{z}_t xt=αtx0+1−αtzt 其中, α ‾ t \overline{\alpha}_t αt是表示累乘。
二、反向过程
原理主要是通过后一张的分布预测处前一张图像的分布情况。
- 首先根据贝叶斯公式,可得: q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t,\mathbf{x}_0)=q(\mathbf{x}_t|\mathbf{x}_{t-1},\mathbf{x}_0)\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)} q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0) 其中q代表的意思是该图像的分布情况。
- 由前向过程我们可得:
上面三个式子对应的概率密度函数,分别为:
再带入到贝斯公式中:
整理成正态分布的概率密度形式:
1
2
π
(
1
−
a
t
1
−
a
ˉ
t
−
1
1
−
a
ˉ
t
)
e
[
−
(
x
t
−
1
−
(
a
t
(
1
−
a
ˉ
t
−
1
)
1
−
a
ˉ
t
x
t
+
a
t
−
1
(
1
−
a
t
)
1
−
a
ˉ
t
x
0
)
)
2
2
(
1
−
a
t
1
−
a
ˉ
t
−
1
1
−
a
ˉ
t
)
2
]
\frac1{\sqrt{2\pi}\left(\color{red}{\frac{\sqrt{1-a_t}\sqrt{1-\bar{a}_{t-1}}}{\sqrt{1-\bar{a}_t}}}\right)}e^{\left[-\frac{\left(x_{t-1}-\left(\color{red}\frac{\sqrt{a_t}(1-\bar{a}_{t-1})}{1-\bar{a}_t}x_t+\frac{\sqrt{a_{t-1}}(1-a_t)}{1-\bar{a}_t}x_0\right)\right)^2}{2\left(\color{red}\frac{\sqrt{1-a_t}\sqrt{1-\bar{a}_{t-1}}}{\sqrt{1-\bar{a}_t}}\right)^2}\right]}
2π(1−aˉt1−at1−aˉt−1)1e
−2(1−aˉt1−at1−aˉt−1)2(xt−1−(1−aˉtat(1−aˉt−1)xt+1−aˉtat−1(1−at)x0))2
其中,x0我们可以通过前向过程公式得到: x 0 = 1 α ˉ t ( x t − 1 − α ˉ t z t ) \mathbf{x}_0=\frac1{\sqrt{\bar{\alpha}_t}}(\mathbf{x}_t-\sqrt{1-\bar{\alpha}_t}\mathbf{z}_t) x0=αˉt1(xt−1−αˉtzt) 现在获取Zt是获得前一张图像分布的关键。
最终分布:
其中
ϵ
\epsilon
ϵ就是
z
t
z_t
zt一个逆向噪声。
上式还可以化简,最后可得分布的均值,方差为:
三、重采样
首先任何正态分布都可以从标准正态分布变换得到,正态分布转标准正态分布式减均值除标准差,因此标准正态分布转其他分部只需要乘标准差加均值即可,因此最终采样公式如下:
四、预测噪声
现在后向过程的关键就是获取一次抽样的噪声了,只能通过模型来进行预测。
通常采用U-net模型来进行,其中模型的输入参数有两个,分别是当前时刻的分布以及当前时刻t。
训练以及采样流程:
预测噪声的部分代码:
噪声模型:
import torch
import torch.nn as nn
class MLPDiffusion(nn.Module) :
def __init__(self,n_steps,num_units=128):
super(MLPDiffusion,self).__init__()
self.linears = nn.ModuleList(
[
nn.Linear(2,num_units),
nn.ReLU(),
nn.Linear(num_units,num_units),
nn.ReLU(),
nn.Linear(num_units,num_units),
nn.ReLU(),
nn.Linear(num_units,2),
]
)
self.step_embeddings = nn.ModuleList(
[
nn.Embedding(n_steps,num_units),
nn.Embedding(n_steps,num_units),
nn.Embedding(n_steps,num_units),
]
)
def forward(self,x,t):
# x = x_0
for idx,embedding_layer in enumerate(self.step_embeddings):
t_embedding = embedding_layer(t)
x = self.linears[2*idx](x)
x += t_embedding
x = self.linears[2*idx+1](x)
x = self.linears[-1](x) #输出噪声 shape:128*2
return x
损失函数:
def diffusion_loss_fn(model,x_0,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,n_steps):
"""对任意时刻t进行采样计算loss"""
batch_size = x_0.shape[0]
#对一个batchsize样本生成随机的时刻t
t = torch.randint(0,n_steps,size=(batch_size//2,)) #0-100中间随机生成时刻
t = torch.cat([t,n_steps-1-t],dim=0)
t = t.unsqueeze(-1) #为了让获取的系数与x_0形状类似以至于可以广播相乘
#x0的系数
a = alphas_bar_sqrt[t]
#eps的系数
aml = one_minus_alphas_bar_sqrt[t]
#生成随机噪音eps
e = torch.randn_like(x_0)
#构造模型的输入
x = x_0*a+e*aml
#送入模型,得到t时刻的随机噪声预测值
output = model(x,t.squeeze(-1))
#与真实噪声一起计算误差,求平均值
return (e - output).square().mean()
注意:这里的e就是正向过程中加进去的噪声,只不过这里没有单独走一遍前向过程,而是在反向过程的途中再算正向过程的那个噪声,这样就不用存每次加的噪声了。
逆采样过程:
def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):
"""从x[T]恢复x[T-1]、x[T-2]|...x[0]"""
cur_x = torch.randn(shape)
x_seq = [cur_x]
for i in reversed(range(n_steps)):
cur_x = p_sample(model,cur_x,i,betas,one_minus_alphas_bar_sqrt)
x_seq.append(cur_x)
return x_seq
def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):
"""从x[T]采样t时刻的重构值"""
t = torch.tensor([t])
coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
eps_theta = model(x,t)
mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))
z = torch.randn_like(x)
sigma_t = betas[t].sqrt()
sample = mean + sigma_t * z
return (sample)
DDIM
用来加速DDPM的反向过程,使其反向过程不遵循马尔可夫链。
- 假设
P
(
x
t
−
1
∣
x
t
,
x
0
)
P(x_{t-1}|x_t, x_0)
P(xt−1∣xt,x0)满足以下正态分布:
- 因为加载过程满足公式:
- 代入得:
- 同样,通过正向过程,也可得出
x
t
−
1
x_{t-1}
xt−1的图像:
- 根据系数相同,可得:
- 可得,
P
(
x
t
−
1
∣
x
t
,
x
0
)
P(x_{t-1}|x_t, x_0)
P(xt−1∣xt,x0)分布为:
- 同样采用如下替换:
- 可得
x
t
−
1
x_{t-1}
xt−1重采样为:
9.
关于DDIM也可参考这篇博文https://kexue.fm/archives/9181
总结
代码在资源中。