Denoising Diffusion Probabilistic Models(DDPM)简易代码分析(保姆级逐行讲解)

前言

  上篇文章我们介绍了DDPM的基本原理,对原理不熟悉的小伙伴建议先看下基本原理介绍部分:Diffusion Model(DDPM)保姆级解析——附代码实现,不然代码可能会看不懂。今天我们来实现一个简单的DDPM的代码,训练一个扩散模型,让他自己生成一个形状为S的图。代码总共分为两部分:tran.pysample.py,已经在代码的内部给出了逐行注释,正文部分对代码做简要的分析介绍。完整代码可权重文件可以通过文章最后的github链接获取。

1 Train.py

  • 引入头文件数据生成
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve  # 生成S形二维数据点 https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_s_curve.html
import torch
import torch.nn as nn
from tqdm import tqdm

## ----------------------------- 1、生成数据,(10000, 2)的数据点集,组成一个S形 ----------------------------- ##
s_curve, _ = make_s_curve(10 ** 4, noise=0.1)  # 生成10000个数据点,形状为S形并且带有噪声,shape为(10000,3),形状是3维的
s_curve = s_curve[:, [0, 2]] / 10.0 # 选择数据的第一列和第三列,并进行缩放
print("shape of s:", np.shape(s_curve))
dataset = torch.Tensor(s_curve).float()

  上面的数据生成部分实现创建一个S形的散点图,由于make_s_curve直接创建的散点图是三维的,即下面这样的散点图:
在这里插入图片描述
  我们这次要训练的是下面这样的二维的散点图,因此这里数据生成部分做了截取s_curve[:, [0, 2]] / 10.0并进行了缩放,得到了下面这样的二维散点图:
在这里插入图片描述

  • 确定超参数
## ----------------------------- 2、确定超参数的值 ----------------------------- ##
# 采样时间步总长度 t
num_steps = 100
 
# 制定每一步的beta
betas = torch.linspace(-6, 6, num_steps) # 在-6到6之间生成100个等间距的值
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5 # 将betas缩放到合适的范围
 
# 计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1 - betas # 计算每一步的alpha值
alphas_prod = torch.cumprod(alphas, 0) # 每个t时刻的alpha值的累积乘积
# alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)
alphas_bar_sqrt = torch.sqrt(alphas_prod) # 计算累积乘积的平方根
one_minus_alphas_bar_log = torch.log(1 - alphas_prod) # 计算1减去累积乘积的对数
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod) # 计算1减去累积乘积的平方根

这部分的代码根据论文提供的公式直接创建超参数值,具体可以参考下来的train和sample算法。
在这里插入图片描述
在这里插入图片描述

  • 确定扩散前向过程任意时刻的采样值
## ----------------------------- 3、确定扩散前向过程任意时刻的采样值 x[t]: x[0] + t --> x[t] ----------------------------- ##此代码并未使用这个
def q_x(x_0, t):
    """
    x[0] + t --> x[t]
    :param x_0:初始数据
    :param t:任意时刻
    :return:
    """
    noise = torch.randn_like(x_0)
    alphas_t = alphas_bar_sqrt[t]
    alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
    x_t = alphas_t * x_0 + alphas_1_m_t * noise
    return x_t

这部分在训练的时候没用到,直接创建了,训练的时候未使用这个函数。

  • 创建网络模型UNet
## ----------------------------- 4、编写求逆扩散过程噪声的模型U-Net(这里使用的是MLP模拟U-Net,官方使用的是U-Net) x[t] + t --> noise_predict----------------------------- ##预测噪声
class MLPDiffusion(nn.Module):
    def __init__(self, n_steps, num_units=128):
        super(MLPDiffusion, self).__init__()
 
        self.linears = nn.ModuleList(
            [
                nn.Linear(2, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, num_units),
                nn.ReLU(),
                nn.Linear(num_units, 2),
            ]
        )
        self.step_embeddings = nn.ModuleList(
            [
                nn.Embedding(n_steps, num_units),
                nn.Embedding(n_steps, num_units),
                nn.Embedding(n_steps, num_units),
            ]
        )
 
    def forward(self, x, t):
        #  x = x[0]
        for idx, embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t)
            x = self.linears[2 * idx](x)
            x += t_embedding
            x = self.linears[2 * idx + 1](x)
        x = self.linears[-1](x)
 
        return x

这里为了减少训练时间,使用一个简单的MLP网络代替UNet网络。

  • 损失函数
## ----------------------------- 损失函数 = 真实噪声eps与预测出的噪声noise_predict 之间的loss ----------------------------- ##
def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps):
    """对任意时刻t进行采样计算loss"""
    batch_size = x_0.shape[0]
 
    # 对一个batchsize样本生成随机的时刻t, t的形状是torch.Size([batchsize, 1])
    t = torch.randint(0, n_steps, size=(batch_size // 2,)) # 随机生成时间步t,一半时间
    t = torch.cat([t, n_steps - 1 - t], dim=0) # 创建对称的时间步
    t = t.unsqueeze(-1) # 添加一个维度,使t的形状为(batch_size, 1)
 
    ## 1) 根据 alphas_bar_sqrt, one_minus_alphas_bar_sqrt --> 得到任意时刻t的采样值x[t]
    # x0的系数
    a = alphas_bar_sqrt[t] # 获取时间步t对应的alphas_bar_sqrt值
    # 噪声eps的系数
    aml = one_minus_alphas_bar_sqrt[t] # 获取时间步t对应的one_minus_alphas_bar_sqrt值
    # 生成生成与x_0形状相同的随机噪声e
    e = torch.randn_like(x_0)
    # 计算任意时刻t的采样值
    x = x_0 * a + e * aml
 
    ## 2) x[t]送入U-Net模型,得到t时刻的随机噪声预测值,这里是用UNet直接预测噪声,输入网络的参数是加上噪声的图像和时间t,网络返回预测所加的噪声
    output = model(x, t.squeeze(-1))
 
    ## 3)计算真实噪声eps与预测出的噪声之间的loss
    loss = (e - output).square().mean()
    return loss
  • 训练模型
## ----------------------------- 训练模型 ----------------------------- ##

if __name__ == "__main__":
    print('Training model...')
    batch_size = 128
    num_epoch = 4000
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    model = MLPDiffusion(num_steps)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
    for t in tqdm(range(num_epoch),desc="Traing epoch"):
        for idx, batch_x in enumerate(dataloader):
            loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            optimizer.step()
    
        if (t % 100 == 0):
            print(loss)
            torch.save(model.state_dict(), 'model_{}.pth'.format(t))

根据自己电脑的性能调整batch_size的大小,这里我们展示一个batch_size128的散点图看一下:
在这里插入图片描述

Sample.py

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve  # 生成S形二维数据点 https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_s_curve.html
import torch
import torch.nn as nn
from tqdm import tqdm

from train import MLPDiffusion

## ----------------------------- 1、生成数据,(10000, 2)的数据点集,组成一个S形 ----------------------------- ##
s_curve, _ = make_s_curve(10 ** 4, noise=0.1)  # 10000个数据点
s_curve = s_curve[:, [0, 2]] / 10.0
print("shape of s:", np.shape(s_curve))
dataset = torch.Tensor(s_curve).float()

## ----------------------------- 2、确定超参数的值 ----------------------------- ##
# 采样时间步总长度 t
num_steps = 100
 
# 制定每一步的beta
betas = torch.linspace(-6, 6, num_steps)
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
 
# 计算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等变量的值
alphas = 1 - betas
alphas_prod = torch.cumprod(alphas, 0)
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)


def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt):
    """
    从x[t]采样t-1时刻的重构值x[t-1],根据论文中的采样公式计算单步的采样
    :param model:
    :param x: x[T]
    :param t:
    :param betas:
    :param one_minus_alphas_bar_sqrt:
    :return:
    """
    ## 1) 求出 bar_u_t
    t = torch.tensor([t])
    coeff = betas[t] / one_minus_alphas_bar_sqrt[t] # 这里先计算采样公式中的一部分参数,方便后面表示,看不懂的可以直接对着论文公式看
    # 送入U-Net模型,得到t时刻的随机噪声预测值 eps_theta
    eps_theta = model(x, t)
    mean = (1 / (1 - betas[t]).sqrt()) * (x - (coeff * eps_theta))
 
    ## 2) 得到 x[t-1]
    z = torch.randn_like(x)
    sigma_t = betas[t].sqrt()
    sample = mean + sigma_t * z
    return sample

def p_sample_loop(model, noise_x_t, n_steps, betas, one_minus_alphas_bar_sqrt):
    """
    从x[T]恢复x[T-1]、x[T-2]|...x[0] 的循环
    :param model:
    :param shape:数据集的形状,也就是x[T]的形状
    :param n_steps:
    :param betas:
    :param one_minus_alphas_bar_sqrt:
    :return: x_seq由x[T]、x[T-1]、x[T-2]|...x[0]组成, cur_x是从噪声中生成的图片
    """
    # 得到噪声x[T]
    cur_x = noise_x_t # 初始化当前的x为噪声x[T]
    x_seq = [noise_x_t] # 初始化x序列为第一个元素为x[T],也就是纯噪声
    # 从x[T]恢复x[T-1]、x[T-2]|...x[0]
    for i in reversed(range(n_steps)):
        cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt)
        x_seq.append(cur_x)
    return x_seq, cur_x

# 1) 加载训练好的diffusion model
model = MLPDiffusion(num_steps)
model.load_state_dict(torch.load('./checkpoints_cpu/model_3900.pth'))

# 2) 生成随机噪声x[T]
noise_x_t = torch.randn(dataset.shape)

# 3) 根据随机噪声逆扩散为x[T-1]、x[T-2]|...x[0] + 图片x[0]
x_seq, cur_x = p_sample_loop(model, noise_x_t, num_steps, betas, one_minus_alphas_bar_sqrt)

# 4) 绘制并保存图像
def plot_samples(x_seq, cur_x):
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    
    # 绘制 x_seq
    for i, x in enumerate(x_seq):
        if i % 10 == 0:  # 每10个时间步绘制一次
            ax[0].scatter(x.detach().numpy()[:, 0], x.detach().numpy()[:, 1], label=f'Step {i}', alpha=0.5)
    ax[0].legend()
    ax[0].set_title('x_seq')
    
    # 绘制 cur_x
    ax[1].scatter(cur_x.detach().numpy()[:, 0], cur_x.detach().numpy()[:, 1], color='red')
    ax[1].set_title('cur_x')
    
    plt.savefig('samples_plot.png')
    plt.show()

plot_samples(x_seq, cur_x)

  采样部分的代码比较简单,由于我创建了两个文件:train.pysample.py,因此在sample.py文件中重新对超参数进行了一次创建。也可以直接使用train.py中的超参数。完整的代码和权重文件可以在我的github进行下载。
  以上就是对DDPM的代码分析,欢迎各位大佬批评指正。

  • 6
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

I松风水月

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值