DDPM | 扩散模型代码详解【较为详细细致!!!】

1、UNet网络结构

UNet网络的总体框架如下,右边是UNet网络的整体框架,左边是residual网络和attention网络,

在这里插入图片描述

下面是UNet网络的详解结构图,左边进行有规律地残差、下采样、attention,右边也是有规律地残差、上采样、attention,相关的代码在图中给出,

在这里插入图片描述

1.1 residual网络和attention网络的细节

熟悉CNN的同学应该能看懂下图中的大部分过程。其中的 t 是时间从0到1000的随机值,假如是888,经过Positional Embedding输出长度是128的向量,下面再经过全连接层和silu层等,下面会详细讲解Positional Embedding、residual网络和attention网络,

在这里插入图片描述

1.2 t 的作用

1、和原图像一起,计算出 t 时刻的图像 x t = 1 − α t ‾ ϵ + α t ‾ x 0 x_t=\sqrt{1-\overline{\alpha_t}}\epsilon+\sqrt{\overline{\alpha_t}}x_0 xt=1αt ϵ+αt x0
2、将 t 进行编码,编码后,加到模型中,使模型学习到当前在哪个时刻

在这里插入图片描述

1.3 DDPM 中的 Positional Embedding 的使用

左图是Transformer的Positional Embedding,行索引代表第几个单词,列索引代表每个单词的特征向量,右图是DDPM的Positional Embedding,DDPM的Positional Embedding和Transformer的Positional Embedding的区别是DDPM的Positional Embedding并不是给每个词位置编码的,只需要在1000行中随机取出一行就可以了;另一个区别是DDPM的Positional Embedding并没有按照奇数位和偶数位进行拼接,而是按照前后的sin和cos进行拼接的,虽然拼接方式不同,但是最终的效果是一样的。如下图所示,
位置编码只要能保证每一行的唯一性,以及每一行和其他行的关系性就可以了。

在这里插入图片描述

1.4 DDPM 中的 Positional Embedding 代码

代码:

class PositionalEmbedding(nn.Module):
    __doc__ = r"""..."""
    def init (self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim  # 特征向量
        self.scale = scale  # 正弦函数和余弦函数的周期不做调整

    def forward(self, x):  # x:表示t,从0-1000中随机出来的一个数值,因为设置batch-size=2,所以假设x:tensor([645,958])
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * - emb)
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

代码解释:

下图中的 e m b 2 × 64 emb_{2\times64} emb2×64中的2表示batch-size等于2,

在这里插入图片描述

使用位置:

在这里插入图片描述

1.5 residual block

原代码:

class ResidualBlock(nn.Module):
    __doc__ = r"""Applies two conv blocks with resudual connection. Adds time and class conditioning by adding bias after first convolution.

    Input:
        x: tensor of shape (N, in_channels, H, W)
        time_emb: time embedding tensor of shape (N, time_emb_dim) or None if the block doesn't use time conditioning
        y: classes tensor of shape (N) or None if the block doesn't use class conditioning
    Output:
        tensor of shape (N, out_channels, H, W)
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        time_emb_dim (int or None): time embedding dimension or None if the block doesn't use time conditioning. Default: None
        num_classes (int or None): number of classes or None if the block doesn't use class conditioning. Default: None
        activation (function): activation function. Default: torch.nn.functional.relu
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
        use_attention (bool): if True applies AttentionBlock to the output. Default: False
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout,
        time_emb_dim=None,
        num_classes=None,
        activation=F.relu,
        norm="gn",
### DDPM(去噪扩散概率模型)用于图像增强的代码实现 #### 使用PyTorch实现DDPM进行图像增强 为了利用DDPM进行图像增强,可以基于已有的框架并做适当调整。下面是一个简化版的Python代码示例,展示了如何使用PyTorch构建一个基本的DDPM架构来进行图像处理。 ```python import torch from torchvision import datasets, transforms from torch.utils.data import DataLoader import numpy as np class MLPDiffusion(torch.nn.Module): def __init__(self, input_dim, hidden_dims, output_dim, time_steps): super().__init__() layers = [] current_dim = input_dim + 1 # Add one dimension for the timestep embedding for hdim in hidden_dims: layers.append(torch.nn.Linear(current_dim, hdim)) layers.append(torch.nn.ReLU()) current_dim = hdim layers.append(torch.nn.Linear(current_dim, output_dim)) self.network = torch.nn.Sequential(*layers) self.time_embedding = torch.nn.Embedding(time_steps, input_dim) def forward(self, x, timesteps): t_emb = self.time_embedding(timesteps).view(-1, 1) # Shape (batch_size, 1) xt = torch.cat([x.view(x.shape[0], -1), t_emb], dim=-1) return self.network(xt) def get_dataloader(batch_size=128): transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) return dataloader def sample_noise_like(data_shape, device="cpu"): return torch.randn(data_shape, device=device) def p_sample(model, img_shape, beta_start=0.0001, beta_end=0.02, T=1000, device="cpu"): betas = torch.linspace(beta_start, beta_end, T).to(device) alphas = 1. - betas alpha_bars = torch.cumprod(alphas, axis=0) imgs = sample_noise_like(img_shape, device=device) with torch.no_grad(): for i in reversed(range(T)): t = torch.full((img_shape[0], ), i, dtype=torch.long, device=device) noise_pred = model(imgs, t) if i > 0: z = torch.randn_like(imgs) else: z = 0 prev_alpha_bar = 1 if i == 0 else alpha_bars[i-1] sqrt_recip_alphas = torch.sqrt(1 / alphas[i]) sqrt_one_minus_alphas = torch.sqrt(1 - alphas[i]) imgs = ( sqrt_recip_alphas * (imgs - sqrt_one_minus_alphas * noise_pred) + torch.sqrt(betas[i]) * z ) return imgs.clip(-1., 1.) model = MLPDiffusion(input_dim=784, hidden_dims=[512, 512], output_dim=784, time_steps=1000).cuda() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) dataloader = get_dataloader() for epoch in range(num_epoch): # num_epoch defined elsewhere or set here directly for step, (images, labels) in enumerate(dataloader): images = images.cuda() # Move data to GPU # Forward pass and loss calculation would go here... ... ``` 此段代码定义了一个简单的MLP网络作为基础结构,并实现了前向传播过程以及采样函数`p_sample()`来逐步去除噪声[^1]。注意这里的例子是以MNIST手写数字数据集为例说明,实际应用时需根据具体需求调整参数设置和预处理方式。 对于想要进一步探索或优化该模型的研究者来说,可以从以下几个方面入手: - 尝试不同的神经网络架构; - 改变beta值的时间表以影响每一步加上的噪音量级; - 探索更复杂的损失函数形式; - 应用更加高效的优化算法。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值