正文
我们先来看一个简单的case。
有一组坐落在x轴的点集,最小和最大的数值为-4和4。我用浅绿色将这些点标记,记作
X
0
X_0
X0
X
0
∈
{
(
−
4
,
0
)
,
(
−
3
,
0
)
,
(
−
2
,
0
)
,
(
−
1
,
0
)
,
(
0
,
0
)
,
(
1
,
0
)
,
(
2
,
0
)
,
(
3
,
0
)
,
(
4
,
0
)
}
X_0 \in \{(-4,0), (-3,0),(-2,0),(-1,0),(0,0),(1,0),(2,0),(3,0),(4,0) \}
X0∈{(−4,0),(−3,0),(−2,0),(−1,0),(0,0),(1,0),(2,0),(3,0),(4,0)}
很明显,
X
0
X_0
X0分布的特点是9个点都坐落在X轴上,并且有大小范围约束。
那么,如果我们想将
X
0
X_0
X0代表的线段分布变成半圆线段,该如何做呢?
用
X
1
X_1
X1记作半圆线段对应的分布,学过高中数学的同学会想到圆形公式:
x
0
2
+
x
1
2
=
4
2
x
1
=
4
2
−
x
0
2
x_0^2 + x_1^2 = 4^2 \\ x_1 = \sqrt {4^2 - x_0^2}
x02+x12=42x1=42−x02
这里我们只考虑正半轴的情况。因此,定义
f
(
x
)
=
4
2
−
x
2
f(x)=\sqrt {4^2 - x^2}
f(x)=42−x2 是将分布
X
0
X_0
X0转为
X
1
X_1
X1的精准映射函数
用红色的点集表示分布
X
1
X_1
X1。
然而现实问题会更加复杂,我们往往找不到一个精准映射的函数,更多的问题是已知
X
0
X_0
X0和
X
1
X_1
X1,需要找到
f
f
f。因此考虑一种复杂的情况,已知X和Y,但不知道
f
f
f,如何让X分布映射到Y上。
有的同学可能想到了,我们可以设计一条轨迹,或者叫路径,让
X
0
X_0
X0逐渐往
X
1
X_1
X1上迁移,这个轨迹可能有很多步,我们假设第0步为0,最后一步为1。0-1之间的任意步骤都是轨迹上的中间态
X
t
X_t
Xt。
那我们可以设计一个最简单的路径,路径上的中间态
X
t
X_t
Xt为
X
t
=
(
1
−
t
)
×
X
0
+
t
×
X
1
X_t = (1-t) \times X_0 + t \times X_1
Xt=(1−t)×X0+t×X1
t
t
t表示0-1之间的任意一步,当t为0,即轨迹的起点,公式最终得到的是
X
0
X_0
X0;反之当t越大,
X
t
X_t
Xt越接近
X
1
X_1
X1
但就像之前说的,实际情况往往更加复杂,假设X0是一个非常复杂的分布,比如真实图像;X1是个很简单的分布,比如标准高斯噪声,就像DDPM做图像生成任务一样。
我们发现,从X0到X1是简单的,使用以上设计的路径依然成立,即我们可以将任何来自真实图像分布的数据变成随机标准正态分布;但从X1到X0是复杂的,我们无法使用这么简单的路径将随机噪声变成真实图像。
首先约定,从
X
0
X_0
X0到
X
1
X_1
X1的过程为正向过程;从
X
1
X_1
X1到
X
0
X_0
X0的过程为反向过程。
t
t
t的每一步变化长度最小为
d
t
dt
dt
如果没办法使用前向路径的反向公式变换,实现反向过程,我们就设计一个映射函数,帮助我们实现反向过程。
在前向过程中,根据公式
X
t
=
(
1
−
t
)
×
X
0
+
t
×
X
1
X_t = (1-t) \times X_0 + t \times X_1
Xt=(1−t)×X0+t×X1,我们可以得到任意
x
t
x_t
xt,当然也包括
x
t
−
d
t
x_{t-dt}
xt−dt。因此我们就可以得到训练pair数据
(
x
t
,
x
t
−
d
t
)
(x_t, x_{t-dt})
(xt,xt−dt),用于训练一个映射模型
f
(
x
t
,
t
)
f(x_t, t)
f(xt,t),得到轨迹中的
t
t
t时刻前一时刻
t
−
d
t
t-dt
t−dt的状态
x
t
−
d
t
x_{t-dt}
xt−dt。
那么,再细想一下,映射模型的拟合对象该如何设计?
根据公式
X
t
=
(
1
−
t
)
×
X
0
+
t
×
X
1
X_t = (1-t) \times X_0 + t \times X_1
Xt=(1−t)×X0+t×X1,我们已知
x
t
x_t
xt是模型的输入,得到
X
0
X_0
X0可以推导出
X
1
X_1
X1的有偏估计,反之得到
X
1
X_1
X1也能推导出
X
0
X0
X0的有偏估计,通过
X
t
−
d
t
=
(
1
−
(
t
−
d
t
)
)
×
X
0
+
(
t
−
d
t
)
×
X
1
X_{t-dt}= (1-(t-dt)) \times X_0 + (t-dt) \times X_1
Xt−dt=(1−(t−dt))×X0+(t−dt)×X1,我们就能得到前一个状态的估计了,也就是
x
t
−
d
t
x_{t-dt}
xt−dt。
因此
f
f
f的拟合对象有3个选择:
- 直接拟合 x t − d t x_{t-dt} xt−dt,毕竟我们有了训练数据pair对,我们直接拟合前一步的状态值即可。
- 拟合 X 0 X_0 X0
- 拟合 X 1 X_1 X1
然而,论文DDPM中证明了这三种在原理上是等价的(经过一系列的公式换算可以等价,本篇文章目的是使用简单的方式介绍DDPM,因此不进行展开描述)。同时作者经过实验,认为拟合
X
1
X_1
X1效果较好。因此
x
1
e
s
t
=
f
(
x
t
,
t
)
x
0
e
s
t
=
x
t
−
t
×
x
1
1
−
t
x
t
−
d
t
e
s
t
=
(
1
−
(
t
−
d
t
)
)
×
x
0
e
s
t
+
(
t
−
d
t
)
×
x
1
e
s
t
x_1^{est} = f(x_t, t) \\ x_0^{est} = \frac{x_t - t \times x_1}{1-t} \\ x_{t-dt}^{est} = (1-(t-dt)) \times x_0^{est} + (t-dt) \times x_1^{est}
x1est=f(xt,t)x0est=1−txt−t×x1xt−dtest=(1−(t−dt))×x0est+(t−dt)×x1est
首先模型估计出
x
1
e
s
t
x_1^{est}
x1est,利用公式变换形式,进而估计出
x
0
e
s
t
x_0^{est}
x0est;最后仍然是根据公式得到
x
t
−
d
t
e
s
t
x_{t-dt}^{est}
xt−dtest。接着这个过程只要重复
t
/
d
t
t / dt
t/dt次,我们就可以得到将分布
X
1
X_1
X1变成
X
0
X_0
X0的轨迹,实现了完整的反向过程。
接着,我们以X1和X0的点集数据为例,训练一个 f f f模型,同时观察测试集上的轨迹变化,是否符合我们的预期。
记 X 1 X_1 X1为在半圆上的点, X 0 X_0 X0为x轴上的点,
定义公式 X t = ( 1 − t ) × X 0 + t × X 1 X_t = (1-t) \times X_0 + t \times X_1 Xt=(1−t)×X0+t×X1
def get_x_t(t, x0, x1):
return x0 * (1-t) + x1 *
公式变换,定义 x 0 x_0 x0的有偏估计
def get_x0(xt, t, x1):
return (xt - t * x1) / (1 - t + 1e-7)
定义 f ( x t , t ) f(x_t, t) f(xt,t),因为我们的任务很简单,使用一个简单的4层mlp足够了
class mlp(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.nn = torch.nn.Sequential(
torch.nn.Linear(2+1, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 2),
)
def forward(self, xt, t):
x_t = torch.cat([xt, t], dim=1)
return self.nn(x_t)
model = mlp()
model.cuda()
model.train()
定义反向采样过程
x
1
e
s
t
=
f
(
x
t
,
t
)
x
0
e
s
t
=
x
t
−
t
×
x
1
1
−
t
x
t
−
d
t
e
s
t
=
(
1
−
(
t
−
d
t
)
)
×
x
0
e
s
t
+
(
t
−
d
t
)
×
x
1
e
s
t
x_1^{est} = f(x_t, t) \\ x_0^{est} = \frac{x_t - t \times x_1}{1-t} \\ x_{t-dt}^{est} = (1-(t-dt)) \times x_0^{est} + (t-dt) \times x_1^{est}
x1est=f(xt,t)x0est=1−txt−t×x1xt−dtest=(1−(t−dt))×x0est+(t−dt)×x1est
class DDPM():
def __init__(self, model, total_step=11) -> None:
self.total_step = total_step
self.model = model
@torch.no_grad()
def sample(self, x1):
step = torch.linspace(0.0, 0.95, self.total_step).flip(0).to(x1.device)
self.model.eval()
x1[:, 1] = x1[:, 1] * 0.95 # 消除当t为1时,get_x0中的分母影响
bs = x1.shape[0]
traj = []
xt = x1
traj.append(xt)
for step_idx in range(self.total_step):
# step从0.95变到0
x1 = self.model(xt, step[step_idx].view(1, 1).expand(bs, -1))
x0 = get_x0(xt, step[step_idx].item(), x1)
if step_idx < (self.total_step - 1):
x_t_1 = get_x_t(step[step_idx + 1], x0, x1)
# 将计算的前一时刻状态重新赋值给x_t
xt = x_t_1
traj.append(xt)
# 最终的x0是我们所需要的反向过程的最终输出
traj.append(x0)
return x0, traj
ddpm_sample = DDPM(mlp, total_step=100)
定义训练过程
def train_loop():
optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
# 训练2000步
loss_list = []
for idx in tqdm(range(2000)):
# 随机生成一些数据
x_0_data_x = np.random.uniform(-4, 4, (1000,))
x_0_data = np.stack([x_0_data_x, np.zeros_like(x_0_data_x)], axis=1) # 1000, 2
x_1_data_x = np.random.uniform(-4, 4, (1000,))
x_1_data_y = (16 - x_1_data_x ** 2) **0.5
x_1_data = np.stack([x_1_data_x, x_1_data_y], axis=1) # 1000,2
x_0_data = torch.from_numpy(x_0_data).float().cuda()
x_1_data = torch.from_numpy(x_1_data).float().cuda()
n_data = x_1_data.shape[0]
# 随机生成一些时刻
time_data = torch.rand((n_data, 1)).to(x_0_data.device)
x_t = get_x_t(time_data, x_0_data, x_1_data)
target = x_1_data # 拟合对象为X1
pred = model(x_t, time_data)
loss = torch.nn.functional.mse_loss(pred, target)
# print(f'loss:{loss:.3f}, {pred[:10]}')
optim.zero_grad()
loss.backward()
optim.step()
loss_list.append(loss.item())
return loss_list, model
# 开始训练
loss_list, model = train_loop()
plt.plot(np.arange(len(loss_list)), loss_list)
plt.savefig('loss_curve.jpg')
定义测试过程
ddpm_sample = DDPM(model, total_step=100)
# 测试, 重新生成一批X1, 一共20个点
x_1_data_x = np.random.uniform(-4, 4, (20,))
x_1_data_y = (16 - x_1_data_x ** 2) **0.5
x_1_data = np.stack([x_1_data_x, x_1_data_y], axis=1) # 20,2
x_1_data = torch.from_numpy(x_1_data).float().cuda()
x0, traj = ddpm_sample.sample(x_1_data)
figure = plt.figure()
for t in traj[-1:]:
t = t.cpu().numpy()
plt.scatter(t[:, 0], t[:, 1])
x_1_data = x_1_data.cpu().numpy()
plt.scatter(x_1_data[:, 0], x_1_data[:, 1], c='r')
figure.savefig("trajectory.jpg")
loss曲线
下面是轨迹图,最上面的红色点是分布
X
1
X_1
X1,都在一个半圆上面。顺着轨迹上的100个中间状态,慢慢变成了最下面的蓝色点。蓝色点虽然不完全在X轴上,但都大致离X轴接近,并且数值范围在-4到4,满足
X
0
X_0
X0的分布特点。观察轨迹符合我们的预期,模型训练成功。
回到图像生成DDPM
DDPM的前向公式为
其实就是
x
t
=
a
‾
t
x
0
+
(
1
−
a
‾
t
I
x_t = \sqrt{\overline{a}_t} x_0 + (1 - {\overline{a}_t} I
xt=atx0+(1−atI
我们把
I
I
I当成
X
1
X_1
X1,那么DDPM前向公式的形式就和我之前介绍的一致了。
再看DDPM中如何得到
x
t
−
1
x_{t-1}
xt−1
你会发现其实就是两项相加,第一项是关于
x
0
x_0
x0和
x
t
x_t
xt的加权,这个也和我们的推导
x
t
−
d
t
e
s
t
=
(
1
−
(
t
−
d
t
)
)
×
x
0
e
s
t
+
(
t
−
d
t
)
×
x
1
e
s
t
x_{t-dt}^{est} = (1-(t-dt)) \times x_0^{est} + (t-dt) \times x_1^{est}
xt−dtest=(1−(t−dt))×x0est+(t−dt)×x1est类似,只是他还有第三项
β
t
\beta_t
βt,而这一项是已知的数值。
你可能会好奇,这个前向公式是如何得来的呢?
你还可能会好奇,建立在马尔科夫链假设上的ddpm,为何优化目标可以被简化到直接对x_1$进行拟合呢?
这些内容,在未来继续分享。
本文总结
本文从一个简化的问题入手,用两个不同分布的点集这种简单的数据类型作为样例,讲解了DDPM问题的建模过程,整个建模过程的核心是设计前向公式,并围绕着前向公式变换为推理过程,进而引导读者思考模型在推理过程中起到的作用。 并用python代码做了训练和测试的实验,最终的结果也符合我们的预期。从理论和实践上较为完整的介绍了DDPM的核心思想和使用方法。
本文为作者原创,转载请注明出处