去噪扩散模型

Denoising Diffusion Probabilistic Models

图像扩散模型是一种生成模型,它基于概率扩散过程来生成新的图像。请添加图片描述
核心步骤包括:(1)前向扩散过程;(2)逆向扩散过程

前向扩散过程(正向过程)

  • 这个过程从原始图像开始,逐步向图像添加噪声,直到图像完全变成噪声。这是一个由数据分布逐渐转变为简单噪声分布(通常是高斯分布)的过程。
  • 每一步添加噪声的操作都是可逆的,意味着理论上可以通过逆过程恢复原始图像。
  • 前向扩散过程可以数学上表示为一个马尔可夫链,其中每个状态是向图像添加一定量噪声的结果。

逆向扩散过程(反向过程)

  • 逆向过程的目标是从噪声图像恢复出原始图像。这个过程通过逐步减少噪声来实现,每一步都尝试预测并去除前一步添加的噪声。
  • 逆向过程也是一个马尔可夫链,但是它的转换是学习得到的,目的是逆转前向过程中的噪声添加。

扩散模型的训练涉及到学习逆向过程中的转换。这通常通过最小化原始图像和通过逆向过程生成的图像之间的差异来实现。训练过程中,模型需要学习如何从部分噪声的图像中预测缺失的部分,这通常通过比较噪声图像和去噪图像来完成。

在这里插入图片描述

代码实现

class DDPM(nn.Module):
    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        super(DDPM, self).__init__()
        self.nn_model = nn_model.to(device)

        # 注册ddpm_schedules函数生成的字典中的缓冲区
        # 例如,后续可以访问self.sqrtab
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T  # 总的时间步数
        self.device = device  # 设备(CPU或GPU)
        self.drop_prob = drop_prob  # Dropout概率
        self.loss_mse = nn.MSELoss()  # 均方误差损失函数

    def forward(self, x, c):
        """
        训练时使用此方法,随机采样时间步和噪声
        """
        
        # 随机采样时间步t,对于批次中的每个样本,t在1到n_T(包括两端)之间均匀分布
        _ts = torch.randint(1, self.n_T+1, (x.shape[0],)).to(self.device)  

        # 生成与输入x形状相同的标准正态分布噪声
        noise = torch.randn_like(x)  
        
        # 根据当前时间步和噪声计算x_t
        # x_t是输入x的噪声版本,噪声按时间t的累积产品(1-beta)的倒数的平方根缩放(sqrtab),
        # 输入x按累积产品alpha的平方根缩放(sqrtmab)
        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * noise
        )  
        
        # 模型应从这个x_t预测“误差项”。
        # 这个误差项是实际添加的噪声和模型预测的噪声之间的差异,
        # 将用于计算损失。这个前向方法返回的损失就是我们计算的损失。

        # 为c中的每个元素创建一个以一定概率的dropout上下文掩码。
        # 这个掩码用于随机将c中的一些元素设置为零,这有助于防止过拟合。
        context_mask = torch.bernoulli(torch.zeros_like(c)+self.drop_prob).to(self.device)
        
        # 返回实际噪声和模型预测噪声之间的均方误差(MSE)。
        # 预测的噪声是通过将x_t和c传递给神经网络模型nn_model获得的,
        # 同时传递标准化的时间步_ts / self.n_T和上下文掩码。
        return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))

参考代码链接

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云朵不吃雨

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值