结合代码详细讲解DDPM的训练和采样过程

本篇文章结合代码讲解Denoising Diffusion Probabilistic Models(DDPM),首先我们先不关注推导过程,而是结合代码来看一下训练和推理过程是如何实现的,推导过程会在别的文章中讲解;首先我们来看一下论文中的算法描述。DDPM分为扩散过程和反向扩散过程,也就是训练过程和采样过程;
代码来自https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-

请添加图片描述

1. 训练(扩散)过程

首先我们来逐个看一下训练过程中的所有符号的含义:

x 0 x_0 x0是真实图像;

t 是扩散的步数,取值范围从1到T;

ϵ \epsilon ϵ是从标准正态分布中采样的噪声;

ϵ θ \epsilon_\theta ϵθ是模型,用于预测噪声,其输入是 x t x_t xt和 t;

x t x_t xt的表达式如下:

在这里插入图片描述

x t x_t xt x 0 x_0 x0加噪获得,其中 α t ‾ \overline{\alpha_{t}} αt是常数
因此训练过程总结成一句话就是,向真实图像 x 0 x_0 x0中加噪,获得加噪后的图像 x t x_t xt;然后将 x t x_t xt和t输入到网络中,得到预测的噪声,通过使得网络预测的噪声和真实加入的噪声更接近,完成网络的训练。
从另一个角度,我们也可以这么理解:向 x 0 x_0 x0中加噪的过程,可以理解成是编码的过程,加噪之后获取到了图像的中间表示 x t x_t xt;而预测噪声的过程则是从 x t x_t xt解码的过程,只是并没有选择直接解码出 x 0 x_0 x0,而是解码出加入的噪声,也就是残差。请添加图片描述

下面来看一下代码,跟上面讲解的过程是一一对应的,首先在初始化函数中我们需要准备好每个时刻t所需要的常数量 α t ‾ \sqrt{\overline{\alpha_{t}}} αt 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1αt 。这些参数最原始来源于一个超参数 β t \beta_t βt,这个参数为加入噪声的方差。他们的关系如下:

[图片]

所以很容易理解代码中的sqrt_alphas_bar就是 α t ‾ \sqrt{\overline{\alpha_{t}}} αt ,sqrt_one_minus_alphas_bar 就是 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1αt
接着在forward函数中,首先从[0,T]中随机选取一个时刻t,然后从标准正态分布中采样一个噪声,shape和 x 0 x_0 x0一致,接着获取 x t x_t xt

x_t = (
extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)

然后将然后将 x t x_t xt和t输入到网络中,得到预测的噪声:

self.model(x_t, t)

计算Loss函数:

loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')

训练过程的完整代码:

class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
    # 每次forward时,给每个样本随机取一个t,并采样一个高斯噪声,然后根据t从sqrt_alphas_bar和sqrt_one_minus_alphas_bar中取出对应的系数,然后根据x_0和采样的高斯噪声生成x_t。然后将x_t和t输入到噪声预测网络中,得到预测的噪声。预测出的噪声输入到网络中,计算loss,从而实现model的训练。
    def forward(self, x_0):
        """
        Algorithm 1.
        """
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) # 给batch中每个样本取一个t,取值范围是[0, 1000]
        noise = torch.randn_like(x_0) # 采样高斯噪声,shape与x_0一致
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
        return loss

2. 推理(反向)过程

首先我们来明确一下,反向过程的目标是什么。反向过程的目标是逐步从一张噪声图像 x T x_T xT中恢复出一张图像,表示成 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt),我们没法推导出 p ( x t − 1 ∣ x t ) p(x_{t-1}|x_t) p(xt1xt),但是 p ( x t − 1 ∣ x t , x 0 ) p(x_{t-1}|x_t, x_0) p(xt1xt,x0)是可以用贝叶斯公式推导出来的,其也是一个高斯分布,并且可以把 x 0 x_0 x0化简掉。最终 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)分布的均值为:
请添加图片描述

方差为 β t \beta_t βt
因此我们可以从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)分布中采样出一个 x t − 1 x_{t-1} xt1
请添加图片描述
这种采样方式叫做重参数技巧,如果不了解可以看如下介绍:
在这里插入图片描述
注意:是标准差与标准正态分布相乘,而不是方差;

因为DDPM的方差固定为 β t \beta_t βt,所以反向过程的重点就是学习出这个分布的方差,从上面的表达式可以看出分布的均值与 x t x_t xt和当前时刻加入的噪声 ϵ t \epsilon_t ϵt有关,而我们的模型可以完成对 ϵ t \epsilon_t ϵt的预测,只要将 x t x_t xt和 t 输入进去模型中即可。代码中描述的过程与此一一对应。

注意代码中存在三个噪声,其中eps是模型预测出来的,其和分布的均值计算相关;forward函数中的noise也是噪声,但是它是从标准正态分布中采样的,用于从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt1xt)采样;forward函数中的 x T x_T xT是整个反向过程的输入,也是从标准正态分布中采样的。

# 反向过程是从纯噪声x_T开始逐步去噪以生成样本,此过程也是一个高斯分布,均值和x_t以及预测出的噪声相关,方差在ddpm中没有进行学习,直接使用的是后验分布q(x_t-1|x_t,x_0)的方差。
class GaussianDiffusionSampler(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]

        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))

        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        # below: only log_variance is used in the KL computations
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        eps = self.model(x_t, t)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)

        return xt_prev_mean, var

    def forward(self, x_T):
        """
        Algorithm 2.
        """
        x_t = x_T # 输入是一个标准正态分布噪声
        # 从T到1进行reverse过程
        for time_step in reversed(range(self.T)):
            print(time_step)
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            mean, var= self.p_mean_variance(x_t=x_t, t=t) 
            # no noise when t == 0
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise # 从q(x_t-1|x_t)中采样
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
DDPM(Denoising Diffusion Probabilistic Model)是一种生成式模型,可以用于图像生成图像去噪等任务。下面是用PyTorch框架训练自己数据集的DDPM代码示例: ``` import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms import torchvision.datasets as datasets from torch.utils.data import DataLoader # 定义DDPM模型 class DDPM(nn.Module): def __init__(self, in_channels, out_channels): super(DDPM, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.conv1 = nn.Conv2d(in_channels, 128, 3, stride=1, padding=1, bias=False) self.conv2 = nn.Conv2d(128, 128, 3, stride=1, padding=1, bias=False) self.conv3 = nn.Conv2d(128, 128, 3, stride=1, padding=1, bias=False) self.conv4 = nn.Conv2d(128, out_channels, 3, stride=1, padding=1, bias=False) self.register_buffer('eps', torch.tensor(1e-8)) def forward(self, x): noise = torch.randn_like(x) out = self.conv1(noise) out = F.relu(out) out = self.conv2(out) out = F.relu(out) out = self.conv3(out) out = F.relu(out) out = self.conv4(out) out = out / torch.sqrt(torch.mean(out**2, dim=[1,2,3], keepdim=True) + self.eps) return x + out # 定义训练函数 def train(model, train_loader, optimizer, criterion, device): model.train() for i, (input, _) in enumerate(train_loader): input = input.to(device) optimizer.zero_grad() output = model(input) loss = criterion(output, input) loss.backward() optimizer.step() if i % 10 == 0: print('Step [{}/{}], Loss: {:.4f}'.format(i, len(train_loader), loss.item())) # 定义数据集数据加载器 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) train_dataset = datasets.ImageFolder(root='./train', transform=transform) train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) # 定义模型、损失函数、优化器设备 model = DDPM(1, 1).to('cuda') criterion = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 训练模型 num_epochs = 10 for epoch in range(num_epochs): train(model, train_loader, optimizer, criterion, 'cuda') # 保存模型 torch.save(model.state_dict(), 'ddpm.pth') ``` 在代码中,我们定义了一个DDPM模型,包含四个卷积层一个标准差归一化层,用于对输入数据进行处理。然后,我们使用PyTorch自带的`ImageFolder`类加载训练集数据,并使用`DataLoader`类构建数据加载器。接着,我们定义了一个训练函数`
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值