从另一种简单的形式理解扩散模型原理和代码实践

正文

我们先来看一个简单的case。

有一组坐落在x轴的点集,最小和最大的数值为-4和4。我用浅绿色将这些点标记,记作 X 0 X_0 X0
在这里插入图片描述
X 0 ∈ { ( − 4 , 0 ) , ( − 3 , 0 ) , ( − 2 , 0 ) , ( − 1 , 0 ) , ( 0 , 0 ) , ( 1 , 0 ) , ( 2 , 0 ) , ( 3 , 0 ) , ( 4 , 0 ) } X_0 \in \{(-4,0), (-3,0),(-2,0),(-1,0),(0,0),(1,0),(2,0),(3,0),(4,0) \} X0{(4,0),(3,0),(2,0),(1,0),(0,0),(1,0),(2,0),(3,0),(4,0)}
很明显, X 0 X_0 X0分布的特点是9个点都坐落在X轴上,并且有大小范围约束。
那么,如果我们想将 X 0 X_0 X0代表的线段分布变成半圆线段,该如何做呢?
X 1 X_1 X1记作半圆线段对应的分布,学过高中数学的同学会想到圆形公式:
x 0 2 + x 1 2 = 4 2 x 1 = 4 2 − x 0 2 x_0^2 + x_1^2 = 4^2 \\ x_1 = \sqrt {4^2 - x_0^2} x02+x12=42x1=42x02
这里我们只考虑正半轴的情况。因此,定义 f ( x ) = 4 2 − x 2 f(x)=\sqrt {4^2 - x^2} f(x)=42x2 是将分布 X 0 X_0 X0转为 X 1 X_1 X1的精准映射函数
在这里插入图片描述
用红色的点集表示分布 X 1 X_1 X1

然而现实问题会更加复杂,我们往往找不到一个精准映射的函数,更多的问题是已知 X 0 X_0 X0 X 1 X_1 X1,需要找到 f f f。因此考虑一种复杂的情况,已知X和Y,但不知道 f f f,如何让X分布映射到Y上。
有的同学可能想到了,我们可以设计一条轨迹,或者叫路径,让 X 0 X_0 X0逐渐往 X 1 X_1 X1上迁移,这个轨迹可能有很多步,我们假设第0步为0,最后一步为1。0-1之间的任意步骤都是轨迹上的中间态 X t X_t Xt
那我们可以设计一个最简单的路径,路径上的中间态 X t X_t Xt
X t = ( 1 − t ) × X 0 + t × X 1 X_t = (1-t) \times X_0 + t \times X_1 Xt=(1t)×X0+t×X1
t t t表示0-1之间的任意一步,当t为0,即轨迹的起点,公式最终得到的是 X 0 X_0 X0;反之当t越大, X t X_t Xt越接近 X 1 X_1 X1

但就像之前说的,实际情况往往更加复杂,假设X0是一个非常复杂的分布,比如真实图像;X1是个很简单的分布,比如标准高斯噪声,就像DDPM做图像生成任务一样。
我们发现,从X0到X1是简单的,使用以上设计的路径依然成立,即我们可以将任何来自真实图像分布的数据变成随机标准正态分布;但从X1到X0是复杂的,我们无法使用这么简单的路径将随机噪声变成真实图像。
首先约定,从 X 0 X_0 X0 X 1 X_1 X1的过程为正向过程;从 X 1 X_1 X1 X 0 X_0 X0的过程为反向过程。 t t t的每一步变化长度最小为 d t dt dt
如果没办法使用前向路径的反向公式变换,实现反向过程,我们就设计一个映射函数,帮助我们实现反向过程。
在前向过程中,根据公式
X t = ( 1 − t ) × X 0 + t × X 1 X_t = (1-t) \times X_0 + t \times X_1 Xt=(1t)×X0+t×X1,我们可以得到任意 x t x_t xt,当然也包括 x t − d t x_{t-dt} xtdt。因此我们就可以得到训练pair数据 ( x t , x t − d t ) (x_t, x_{t-dt}) (xt,xtdt),用于训练一个映射模型 f ( x t , t ) f(x_t, t) f(xt,t),得到轨迹中的 t t t时刻前一时刻 t − d t t-dt tdt的状态 x t − d t x_{t-dt} xtdt
那么,再细想一下,映射模型的拟合对象该如何设计?
根据公式 X t = ( 1 − t ) × X 0 + t × X 1 X_t = (1-t) \times X_0 + t \times X_1 Xt=(1t)×X0+t×X1,我们已知 x t x_t xt是模型的输入,得到 X 0 X_0 X0可以推导出 X 1 X_1 X1的有偏估计,反之得到 X 1 X_1 X1也能推导出 X 0 X0 X0的有偏估计,通过 X t − d t = ( 1 − ( t − d t ) ) × X 0 + ( t − d t ) × X 1 X_{t-dt}= (1-(t-dt)) \times X_0 + (t-dt) \times X_1 Xtdt=(1(tdt))×X0+(tdt)×X1,我们就能得到前一个状态的估计了,也就是 x t − d t x_{t-dt} xtdt
因此 f f f的拟合对象有3个选择:

  • 直接拟合 x t − d t x_{t-dt} xtdt,毕竟我们有了训练数据pair对,我们直接拟合前一步的状态值即可。
  • 拟合 X 0 X_0 X0
  • 拟合 X 1 X_1 X1

然而,论文DDPM中证明了这三种在原理上是等价的(经过一系列的公式换算可以等价,本篇文章目的是使用简单的方式介绍DDPM,因此不进行展开描述)。同时作者经过实验,认为拟合 X 1 X_1 X1效果较好。因此
x 1 e s t = f ( x t , t ) x 0 e s t = x t − t × x 1 1 − t x t − d t e s t = ( 1 − ( t − d t ) ) × x 0 e s t + ( t − d t ) × x 1 e s t x_1^{est} = f(x_t, t) \\ x_0^{est} = \frac{x_t - t \times x_1}{1-t} \\ x_{t-dt}^{est} = (1-(t-dt)) \times x_0^{est} + (t-dt) \times x_1^{est} x1est=f(xt,t)x0est=1txtt×x1xtdtest=(1(tdt))×x0est+(tdt)×x1est
首先模型估计出 x 1 e s t x_1^{est} x1est,利用公式变换形式,进而估计出 x 0 e s t x_0^{est} x0est;最后仍然是根据公式得到 x t − d t e s t x_{t-dt}^{est} xtdtest。接着这个过程只要重复 t / d t t / dt t/dt次,我们就可以得到将分布 X 1 X_1 X1变成 X 0 X_0 X0的轨迹,实现了完整的反向过程。

接着,我们以X1和X0的点集数据为例,训练一个 f f f模型,同时观察测试集上的轨迹变化,是否符合我们的预期。

X 1 X_1 X1为在半圆上的点, X 0 X_0 X0为x轴上的点,

定义公式 X t = ( 1 − t ) × X 0 + t × X 1 X_t = (1-t) \times X_0 + t \times X_1 Xt=(1t)×X0+t×X1

def get_x_t(t, x0, x1):
    return x0 * (1-t) + x1 * 

公式变换,定义 x 0 x_0 x0的有偏估计

def get_x0(xt, t, x1):
    return (xt - t * x1) / (1 - t + 1e-7)

定义 f ( x t , t ) f(x_t, t) f(xt,t),因为我们的任务很简单,使用一个简单的4层mlp足够了

class mlp(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(2+1, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 2),
        )
      
    def forward(self, xt, t):
        x_t = torch.cat([xt, t], dim=1)
        return self.nn(x_t)
    
model = mlp()
model.cuda()
model.train()

定义反向采样过程
x 1 e s t = f ( x t , t ) x 0 e s t = x t − t × x 1 1 − t x t − d t e s t = ( 1 − ( t − d t ) ) × x 0 e s t + ( t − d t ) × x 1 e s t x_1^{est} = f(x_t, t) \\ x_0^{est} = \frac{x_t - t \times x_1}{1-t} \\ x_{t-dt}^{est} = (1-(t-dt)) \times x_0^{est} + (t-dt) \times x_1^{est} x1est=f(xt,t)x0est=1txtt×x1xtdtest=(1(tdt))×x0est+(tdt)×x1est

class DDPM():
    def __init__(self, model, total_step=11) -> None:
        self.total_step = total_step
        self.model = model
       
    @torch.no_grad()
    def sample(self, x1):
        step = torch.linspace(0.0, 0.95, self.total_step).flip(0).to(x1.device)
        self.model.eval()
        x1[:, 1] = x1[:, 1] * 0.95  # 消除当t为1时,get_x0中的分母影响
        bs = x1.shape[0]
        traj = []
        xt = x1
        traj.append(xt)
        for step_idx in range(self.total_step):
            # step从0.95变到0
            x1 = self.model(xt, step[step_idx].view(1, 1).expand(bs, -1))
            x0 = get_x0(xt, step[step_idx].item(), x1)
            if step_idx < (self.total_step - 1):
                x_t_1 = get_x_t(step[step_idx + 1], x0, x1)
                # 将计算的前一时刻状态重新赋值给x_t
                xt = x_t_1
                traj.append(xt)
        # 最终的x0是我们所需要的反向过程的最终输出
        traj.append(x0)
        return x0, traj
    
ddpm_sample = DDPM(mlp, total_step=100)

定义训练过程


def train_loop():
   
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
   
    # 训练2000步
    loss_list = []
    for idx in tqdm(range(2000)):
        # 随机生成一些数据
        x_0_data_x  = np.random.uniform(-4, 4, (1000,))
        x_0_data = np.stack([x_0_data_x, np.zeros_like(x_0_data_x)], axis=1) # 1000, 2
     
        x_1_data_x = np.random.uniform(-4, 4, (1000,))
        x_1_data_y = (16 - x_1_data_x ** 2) **0.5
        x_1_data = np.stack([x_1_data_x, x_1_data_y], axis=1) # 1000,2
        
        x_0_data = torch.from_numpy(x_0_data).float().cuda()
        x_1_data = torch.from_numpy(x_1_data).float().cuda()
        n_data = x_1_data.shape[0]
        # 随机生成一些时刻
        time_data = torch.rand((n_data, 1)).to(x_0_data.device)
        x_t  =  get_x_t(time_data, x_0_data, x_1_data)
        target = x_1_data  # 拟合对象为X1
        
        pred = model(x_t, time_data)
        loss = torch.nn.functional.mse_loss(pred, target)
        # print(f'loss:{loss:.3f}, {pred[:10]}')
        optim.zero_grad()
        loss.backward()
        optim.step()
        loss_list.append(loss.item())
    return loss_list, model
# 开始训练
loss_list, model = train_loop()
plt.plot(np.arange(len(loss_list)), loss_list)
plt.savefig('loss_curve.jpg')

定义测试过程

ddpm_sample = DDPM(model, total_step=100)
# 测试, 重新生成一批X1, 一共20个点
x_1_data_x = np.random.uniform(-4, 4, (20,))
x_1_data_y = (16 - x_1_data_x ** 2) **0.5
x_1_data = np.stack([x_1_data_x, x_1_data_y], axis=1) # 20,2
x_1_data = torch.from_numpy(x_1_data).float().cuda()

x0, traj = ddpm_sample.sample(x_1_data)
figure = plt.figure()
for t in traj[-1:]:
    t = t.cpu().numpy()
    plt.scatter(t[:, 0], t[:, 1])
x_1_data = x_1_data.cpu().numpy()
plt.scatter(x_1_data[:, 0], x_1_data[:, 1], c='r')
figure.savefig("trajectory.jpg") 

loss曲线
在这里插入图片描述
下面是轨迹图,最上面的红色点是分布 X 1 X_1 X1,都在一个半圆上面。顺着轨迹上的100个中间状态,慢慢变成了最下面的蓝色点。蓝色点虽然不完全在X轴上,但都大致离X轴接近,并且数值范围在-4到4,满足 X 0 X_0 X0的分布特点。观察轨迹符合我们的预期,模型训练成功。
在这里插入图片描述

回到图像生成DDPM

DDPM的前向公式为
在这里插入图片描述
其实就是
x t = a ‾ t x 0 + ( 1 − a ‾ t I x_t = \sqrt{\overline{a}_t} x_0 + (1 - {\overline{a}_t} I xt=at x0+(1atI
我们把 I I I当成 X 1 X_1 X1,那么DDPM前向公式的形式就和我之前介绍的一致了。

再看DDPM中如何得到 x t − 1 x_{t-1} xt1
在这里插入图片描述
你会发现其实就是两项相加,第一项是关于 x 0 x_0 x0 x t x_t xt的加权,这个也和我们的推导 x t − d t e s t = ( 1 − ( t − d t ) ) × x 0 e s t + ( t − d t ) × x 1 e s t x_{t-dt}^{est} = (1-(t-dt)) \times x_0^{est} + (t-dt) \times x_1^{est} xtdtest=(1(tdt))×x0est+(tdt)×x1est类似,只是他还有第三项 β t \beta_t βt,而这一项是已知的数值。

你可能会好奇,这个前向公式是如何得来的呢?
你还可能会好奇,建立在马尔科夫链假设上的ddpm,为何优化目标可以被简化到直接对x_1$进行拟合呢?
这些内容,在未来继续分享。

本文总结

本文从一个简化的问题入手,用两个不同分布的点集这种简单的数据类型作为样例,讲解了DDPM问题的建模过程,整个建模过程的核心是设计前向公式,并围绕着前向公式变换为推理过程,进而引导读者思考模型在推理过程中起到的作用。 并用python代码做了训练和测试的实验,最终的结果也符合我们的预期。从理论和实践上较为完整的介绍了DDPM的核心思想和使用方法。

本文为作者原创,转载请注明出处

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值