扩散模型(Diffusion Model)原理深度解读及代码实现

1. 简 介

扩散模型(Diffusion Model)的起源可以追溯到概率图模型和统计物理学领域。它最初的灵感来自于对热扩散和布朗运动等物理现象的研究,这些过程描述了系统如何从一个高能量、不均匀的状态(如高温区)逐步过渡到一个低能量、平衡的状态(如均匀温度分布),后来被引入机器学习和生成模型领域。

扩散模型的早期形式可以追溯到多种研究工作,Sohl-Dickstein 等人(2015年) 在论文《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》中首次将扩散过程引入深度生成模型的框架,奠定了扩散模型的理论基础。

近年来,扩散模型得到了快速发展。尤其是 Ho 等人(2020年)在论文《Denoising Diffusion Probabilistic Models (DDPM)》中提出了一种高效的扩散模型框架,这一工作开辟了扩散模型的新方向,使其在图像生成、语音合成和其他生成任务中表现优异。DDPM 将扩散过程分为两个阶段,一是前向过程(Forward Process), 向数据中逐步添加噪声,直到数据接近高斯分布。二是反向过程(Reverse Process),学习逐步去噪,最终从随机噪声中生成目标数据。如下图所示,随机噪声经过去噪模块1000次去噪之后,生成了一个猫的图像。

就像是米开朗基罗说的:“塑像就在石头里,我只是把不需要的部分去掉”。

2. 原 理

2.1 前向过程

前向过程也就是我们常说的扩散过程,它模拟真实数据逐渐被噪声污染的过程。其做法通常是从高斯分布里面采样一组噪声添加到正常的图片当中,产生有点噪声的图像,然后从高斯分布中再采样一次,再得到更加噪声的图片,以此类推,最后整张图片就看不出来原来是什么东西,也就是整张图片变成了一个接近高斯分布的噪声。做完这个扩散过程以后,就有去噪模块的训练数据了。

前向过程是一个固定的、不可学习的马尔可夫链。从初始数据分布  开始,逐步向数据中添加高斯噪声,使数据分布逐渐接近标准高斯分布 。公式表示为:

其中  是当前时间步的数据, 是前一时间步的数据, 是一个预定义的时间步长参数,表示每一步添加噪声的强度。 是单位矩阵,表示每个数据维度上的噪声是独立且均匀的。 是高斯分布,表示  的条件分布。

添加噪声的过程是一个马尔可夫链,在每一步中,数据根据高斯分布从  生成 ,噪声的均值为 ,方差为 。前向过程将原始数据逐渐加噪,最终在  步后,使其分布接近标准高斯分布 。

2.2 反向过程

反向过程的目的学习如何从完全随机的噪声逐步还原出目标数据。反向过程是需要通过神经网络来学习的,该网络的输入是一张有噪声的图,输出是一张滤掉一点噪声的图像,去噪越做越多,最终就能看到一张清晰的图片。如下图所示

 

通常,这个去噪的模型里面实际上是一个噪声预测器(noise predictor),它会预测图片里面的噪声。这个噪声预测器的输入是去噪的图片和噪声现在的严重程度(也就是我们现在进行到去噪的第几个步骤的代号),如下图所示。它预测在这张图片里面噪声应该长什么样子,再在去噪的图片中减去它预测的噪声,就产生去噪以后的结果,即输出一张噪声少了一点的图。

 

要训练这样的噪声预测器,要用到之前我们在扩散过程中产生的训练数据。即,扩散过程中产生的一张加完噪声的图片跟现在是第几次加噪声,是网络的输入,而加入的这个噪声就是网络应该要预测的输出。比如,噪声预测器的输入是一张加了第2次噪声的猫的图像以及 2 这个数字,接下来的输出应该是第2步加的噪声,如下图所示:

 

但有些时候我们不仅想要产生图片,还想产生与我们文字描述一样的图片,对于这样的情况,我们只需要在训练数据中增加对图片的描述,同时在去噪的每一个步骤中让噪声预测器多一个额外的输入,也就是描述的这段文字,如下图所示:

 

反向过程假设数据的逆演化仍然是一个马尔可夫过程,模型需要学习如何从  预测  ,逐步还原出无噪声的原始数据:

其中  是当前时间步的带噪声数据, 是目标时间步的数据(去噪后)。 是由神经网络预测的均值。 是由神经网络预测的方差(通常固定为常数以简化计算)。通过训练神经网络,使其能够预测每一步中的噪声成分。

扩散模型的训练目标是通过最大化似然估计来优化反向过程。训练过程等价于一个降噪任务,模型学习在给定加噪数据  的情况下,如何估计当前时间步的噪声,扩散模型的训练损失函数如下:

其中  是前向过程中实际添加的噪声, 是模型预测的噪声。 是当前时间步的损失,表示实际噪声与预测噪声之间的均方误差。通过最小化这个损失,模型学习如何在每一步准确预测噪声,从而能够反向还原数据。

如果你还想学习更多的AI大模型知识,这里我也贴心的为大家准备了一份学习资料。无偿分享给大家,VX扫描以下二维码即可领取

👉CSDN大礼包🎁:[全网最全《LLM大模型入门+进阶学习资源包》免费分享](安全链接,放心点击)👈 https://blog.csdn.net/2401_85325726/article/details/147295734?spm=1001.2014.3001.5501

3. 代 码

下面我们以生成动漫人脸图像为目标来训练Diffusion Model

动漫人脸数据集下载链接:

https://www.kaggle.com/datasets/b07202024/diffusion/download?datasetVersionNumber=1

本代码遵循典型的 DDPM(Denoising Diffusion Probabilistic Model)框架,整体分为 U-Net 模型(用于去噪)、GaussianDiffusion 类(提供前向扩散和反向采样逻辑)以及数据集和训练器等部分。U-Net 负责在不同尺度下对图像特征进行编码与解码,以预测在每个时间步中加入的噪声;GaussianDiffusion 封装了核心公式与超参数,包括 beta 调度、采样/训练流程及损失函数;Trainer 则管理训练过程,如数据加载、梯度累积、EMA(指数移动平均)等。这种架构将“加噪”和“去噪”分离并封装在模型和调度器中,使得训练和推理流程更加清晰易懂,也能方便地进行扩展或替换不同的网络结构与超参数策略。以下是完整代码(引自《李宏毅深度学习》):

import math  # 引入数学相关函数和常量
import copy  # 提供浅拷贝和深拷贝功能
from pathlib import Path  # 方便进行路径操作
from random import random  # 随机数生成函数
from functools import partial  # 可用于对函数进行部分参数绑定
from collections import namedtuple  # 提供类似结构体的命名元组
from multiprocessing import cpu_count  # 获取当前机器的 CPU 核心数

import torch  # 引入 PyTorch 库
from torch import nn, einsum  # nn 用于神经网络相关模块,einsum 可进行爱因斯坦求和
import torch.nn.functional as F  # 提供常用函数式神经网络操作
from torch.utils.data import Dataset, DataLoader  # 数据集基类和数据加载器

from torch.optim import Adam  # Adam 优化器

import torchvision  # PyTorch 视觉工具包
from torchvision import transforms as T, utils  # 图像预处理 transforms,utils 提供图像显示保存等功能

from einops import rearrange, reduce, repeat  # einops 提供灵活的张量变换函数
from einops.layers.torch import Rearrange  # einops 在 PyTorch 中的 Layer 实现

from PIL import Image  # Python Image Library,用于图像读写等操作
from tqdm.auto import tqdm  # 进度条库,auto 会根据环境自动选择合适的进度条
from ema_pytorch import EMA  # 指数滑动平均库,用于模型权重的 EMA
                              # (Exponential Moving Average)

from accelerate import Accelerator  # 用于分布式训练加速的库
import matplotlib.pyplot as plt  # 常用的绘图库
import os  # 与操作系统相关的功能

# 设置 cuDNN 的自动优化为 True,可以在某些情况下提升卷积运算速度
torch.backends.cudnn.benchmark = True

# 设置随机数种子,保证可重复性
torch.manual_seed(4096)

# 如果有 GPU 可用,则为 GPU 设置相同的随机数种子
if torch.cuda.is_available():
  torch.cuda.manual_seed(4096)


def linear_beta_schedule(timesteps):
    """
    linear schedule, proposed in original ddpm paper
    线性 Beta 时间调度函数,用于扩散模型(DDPM)中定义 Beta 参数在每个时间步的取值。
    """
    scale = 1000 / timesteps  # 用 timesteps 对原始范围进行缩放
    beta_start = scale * 0.0001  # 线性 Beta 的起始值
    beta_end = scale * 0.02      # 线性 Beta 的结束值
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
    # 在 [beta_start, beta_end] 范围上,生成 timesteps 个等差数列。


def extract(a, t, x_shape):
    """
    从向量 a 中取出与时间步 t 对应的值,并 reshape 成 x_shape 的形状。
    """
    b, *_ = t.shape  # 获取 batch 大小 b,忽略后面维度
    out = a.gather(-1, t)  # 在最后一维上按照索引 t 收集元素
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
    # reshape 成 (b, 1, 1, ..., 1) 的形式,以便后续进行广播操作


class Dataset(Dataset):
    """
    自定义数据集,用于加载指定文件夹下的 .jpg 图像文件。
    """
    def __init__(
        self,
        folder,      # 图像文件所在的文件夹路径
        image_size   # 需要 resize 到的图像尺寸
    ):
        self.folder = folder        # 保存文件夹路径
        self.image_size = image_size  # 保存图像尺寸
        self.paths = [p for p in Path(f'{folder}').glob(f'**/*.jpg')]
        # 递归搜索 folder 目录下的所有 .jpg 文件,并将路径保存到 self.paths 列表中

        #################################
        ## TODO: Data Augmentation ##
        #################################
        self.transform = T.Compose([
            T.Resize(image_size),  # 调整图像大小到 image_size x image_size
            T.ToTensor()           # 将图像转换为 PyTorch 张量,并将像素值归一化到 [0,1]
        ])

    def __len__(self):
        # 返回数据集的大小(图像的数量)
        return len(self.paths)

    def __getitem__(self, index):
        # 根据索引返回对应的图像数据
        path = self.paths[index]  # 获取图像路径
        img = Image.open(path)    # 打开图像
        return self.transform(img)  # 应用预处理并返回

def exists(x):
    return x is not None
    # 如果 x 不是 None,则返回 True,否则返回 False

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d
    # 如果 val 存在,则返回 val;否则,如果 d 可调用,则返回 d(),否则返回 d
    # 这个函数通常用于参数的默认值设置

def identity(t, *args, **kwargs):
    return t
    # 一个简单的恒等函数,原样返回输入 t,本质上不对数据做任何处理

def cycle(dl):
    while True:
        for data in dl:
            yield data
    # 这是一个生成器,循环迭代给定的 dataloader (dl),使其可无限次迭代

def has_int_squareroot(num):
    return (math.sqrt(num) ** 2) == num
    # 判断 num 的平方根是否为整数
    # math.sqrt(num) ** 2 如果等于 num,则 num 的平方根是整数

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
    # 将 num 拆分成大小为 divisor 的若干组,最后如果有余数 remainder,则将它作为一组附加到数组的末尾
    # 举例:num=10, divisor=3,则返回 [3, 3, 4]

# normalization functions

def normalize_to_neg_one_to_one(img):
    return img * 2 - 1
    # 将图像像素值 [0, 1] 的范围映射到 [-1, 1]

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5
    # 将图像像素值 [-1, 1] 的范围映射回 [0, 1]

# small helper modules

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        # 初始化时传入一个 nn.Module,然后在 forward 中进行残差连接

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
        # 前向传播时,将输入 x 经过 self.fn,再加回原始 x 实现残差结构

def Upsample(dim, dim_out = None):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
    )
    # 上采样模块:
    # 1) 将特征图放大 2 倍(最近邻插值)
    # 2) 卷积将通道数从 dim 映射到 dim_out(如果 dim_out 没传,则仍为 dim)

def Downsample(dim, dim_out = None):
    return nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    )
    # 下采样模块:
    # 1) 通过 einops 将高宽各自分辨率乘以 2 的模式展平到通道维度上(相当于像素重排)
    #    使得通道数扩大 4 倍 (p1=2, p2=2 => 2*2=4)
    # 2) 1x1 卷积将通道数映射到 dim_out(默认保持不变)

class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        # 根据数据类型选择不同的数值稳定性常数 eps

        weight = self.weight
        mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
        var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False))
        # 计算卷积核在输出通道维度 o 上的均值 mean 和方差 var

        normalized_weight = (weight - mean) * (var + eps).rsqrt()
        # 对卷积核进行标准化

        return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        # 使用标准化后的卷积核进行卷积操作

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        # 可学习的缩放参数 g,初始化为 1

    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * (var + eps).rsqrt() * self.g
        # 沿着通道维度(=1)做 layer norm,并使用 g 来进行缩放

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)
        # 在执行 fn 之前,先进行 LayerNorm

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)
        # 前向过程:先归一化,再执行传入的 fn

# sinusoidal positional embeds

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # 记录嵌入维度

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # 生成等比数列,以用于构建正余弦频率

        emb = x[:, None] * emb[None, :]
        # 将输入时间步 x 与频率 emb 相乘

        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        # 拼接正弦和余弦部分

        return emb
        # 返回正余弦位置编码

class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """

    def __init__(self, dim, is_random = False):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
        # 如果 is_random 为 False,则这些频率是可学习的;否则是随机固定

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        # 将输入 x reshape 成 (batch, 1)

        freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        # 计算随机或可学习的频率 freqs

        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
        # 拼接正弦和余弦

        fouriered = torch.cat((x, fouriered), dim = -1)
        # 再将原始 x 与正余弦部分合并

        return fouriered
        # 返回包含输入和正余弦编码的结果

# building block modules

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()
        # 一个卷积 -> GroupNorm -> SiLU 激活的基本模块
        # 卷积使用 WeightStandardizedConv2d,便于搭配 GroupNorm

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
            # 如果从时间嵌入得到 scale_shift,则对特征图进行缩放和偏移

        x = self.act(x)
        return x
        # 输出经过标准化和激活的张量

class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        # 如果传入了 time_emb_dim,则对时间嵌入进行线性映射得到 scale 和 shift
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2)
        ) if exists(time_emb_dim) else None

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
        # 如果 dim != dim_out,就用 1x1 卷积在残差分支中对通道数进行调整

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)
            # 将 time_emb 拆分为 (scale, shift)

        h = self.block1(x, scale_shift = scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
        # 最终输出为正常流 (h) + 残差分支

class LinearAttention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        # 一次性生成 q, k, v
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1),
            LayerNorm(dim)
        )
        # 输出层(卷积 + LayerNorm)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        # 将通道维分成 q, k, v
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)
        # 分别对 q 的通道维(-2)和 k 的序列维(-1)做 softmax

        q = q * self.scale
        v = v / (h * w)
        # 缩放 q,以及对 v 做归一化

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        # 先将 k 和 v 做乘积,得到上下文 context

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        # 再和 q 做乘积以得到输出

        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        # reshape 回原始形状

        return self.to_out(out)
        # 卷积 + LayerNorm 得到最终结果

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
        # 自注意力机制:先获取 q, k, v,再做注意力加权求和,最后映射回 dim

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)

        q = q * self.scale
        # 缩放 q

        sim = torch.einsum('b h d i, b h d j -> b h i j', q, k)
        # 相似度矩阵 sim (b, heads, i, j)

        attn = sim.softmax(dim = -1)
        # 沿着最后一维做 softmax,得到注意力分布

        out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)
        # 加权求和得到输出

        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        # reshape 回原始分辨率

        return self.to_out(out)
        # 最后再用 1x1 卷积映射回 dim 维度

# model

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        resnet_block_groups = 8,
        learned_sinusoidal_cond = False,
        random_fourier_features = False,
        learned_sinusoidal_dim = 16
    ):
        super().__init__()

        # determine dimensions

        self.channels = channels

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding = 3)
        # 输入通道 -> init_dim, 使用 7x7 卷积做初始特征提取

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        # 比如 dim=64, dim_mults=(1,2,4,8), 则 dims=[64, 64*1, 64*2, 64*4, 64*8]
        # in_out 就是 [(64,64),(64,128),(128,256),(256,512)]

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)
        # 使用部分函数 partial,将 ResnetBlock 的 groups 参数固定

        # time embeddings

        time_dim = dim * 4
        # 时间嵌入的维度,一般设置为 4 倍 base dim

        self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features

        if self.random_or_learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim)
            fourier_dim = dim
        # 根据需要选择使用随机/可学习的正弦嵌入,或使用经典的正弦嵌入

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )
        # 时间嵌入先经过正弦嵌入,然后用两个全连接层(中间激活为 GELU),维度转为 time_dim

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            # 判断是否是最后一个分辨率

            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
            ]))
            # down 阶段:
            # 1) ResnetBlock(dim_in -> dim_in)
            # 2) 再一个 ResnetBlock(dim_in -> dim_in)
            # 3) Residual(PreNorm(LinearAttention))
            # 4) 如果不是最后层,用 Downsample;否则用 3x3 卷积保持分辨率

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        # 中间层(U-Net 最底部):ResnetBlock -> 自注意力 -> ResnetBlock

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)
            # 倒序遍历 in_out,用于 up 阶段

            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else  nn.Conv2d(dim_out, dim_in, 3, padding = 1)
            ]))
            # up 阶段的逻辑与 down 类似,只是要先拼接 skip connection

        self.out_dim = default(out_dim, channels)
        # 最终输出通道数,默认与输入通道一致

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
        # 最后一步和初始输入拼接后,再过一个 ResnetBlock

        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
        # 通过 1x1 卷积将维度映射到 out_dim

    def forward(self, x, time):
        x = self.init_conv(x)
        # 初始卷积提取特征
        r = x.clone()
        # 保存初始特征用于最后拼接

        t = self.time_mlp(time)
        # 将时间步 time 通过 time_mlp 得到时间嵌入 t

        h = []
        # 用于保存每层的输出,以便在解码器阶段做 skip connection

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)
            # 依次执行 block1 -> block2 -> attn -> downsample
            # 并存储中间输出 h

        # mid
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)
        # U-Net 中间层的处理

        # upsample
        for block1, block2, attn, upsample in self.ups:
            # pop 出下采样时存储的输出,进行 skip connection
            x = torch.cat((x, h.pop()), dim = 1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim = 1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        # final
        x = torch.cat((x, r), dim = 1)
        # 跟最初的输入特征 r 拼接

        x = self.final_res_block(x, t)
        return self.final_conv(x)
        # 最终输出一个跟输入维度相匹配的特征图

model = Unet(64)
# 实例化一个 U-Net 模型,基本通道数 dim = 64

class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        model,                    # 传入的 U-Net 等模型,用于预测噪声
        *,
        image_size,              # 图像大小(宽和高)
        timesteps = 1000,        # 扩散过程的总时间步数
        beta_schedule = 'linear',# beta 的调度方式;此处仅支持 'linear'
        auto_normalize = True    # 是否自动将图像 [0,1] 归一化到 [-1,1]
    ):
        super().__init__()
        # 继承自 nn.Module

        assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
        # 如果是 GaussianDiffusion 类本身,则要求 model 的输入通道和输出通道一致,否则会出错
        assert not model.random_or_learned_sinusoidal_cond
        # 在本实现里,不允许网络使用随机或可学习的正弦位置编码

        self.model = model
        # 保存传入的模型(通常是一个 U-Net)

        self.channels = self.model.channels
        # 模型的通道数量(图像的通道,默认为 3)

        self.image_size = image_size
        # 保存图像大小

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')
        # 根据传入的 beta_schedule 字符串选择 beta 调度函数
        # 目前只支持 'linear',否则抛出异常

        # calculate beta and other precalculated parameters
        betas = beta_schedule_fn(timesteps)
        # 计算在每个时间步上的 beta 值(线性递增)

        alphas = 1. - betas
        # α_t = 1 - β_t
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # 累乘得到 α_1 * α_2 * ... * α_t
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
        # 向前偏移一个时间步,便于在计算 q(x_{t-1}|x_t, x_0) 时使用
        # 第一个时间步补 1,使 α_cumprod_prev 的长度与 alphas_cumprod 一致

        timesteps, = betas.shape
        # 获取时间步数(1000)
        self.num_timesteps = int(timesteps)
        # 将其保存为整型

        # sampling related parameters
        self.sampling_timesteps = timesteps
        # 采样时使用的步数,默认和训练步数相同

        # helper function to register buffer from float64 to float32
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
        # 定义一个小函数,用于将各种张量注册为 buffer,并转换为 float32 类型

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        # 将以上计算好的 beta、alpha 累乘、以及前一个时间步的 alpha 累乘注册为 buffer
        # 这些值是训练和推理都会用到,但不会被训练的参数

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        # sqrt(累乘α_t)
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        # sqrt(1 - 累乘α_t)
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        # 记录 log(1 - 累乘α_t)
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        # sqrt(1 / 累乘α_t)
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
        # sqrt(1 / 累乘α_t - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # q(x_{t-1} | x_t, x_0) 的后验方差
        # 根据公式: posterior_variance_t = β_t * (1 - α_{t-1}累乘) / (1 - α_t累乘)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)
        # 注册后验方差

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        # 取对数时夹紧最小值防止数值溢出
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        # 后验均值系数1
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
        # 后验均值系数2

        # derive loss weight
        # snr - signal noise ratio
        snr = alphas_cumprod / (1 - alphas_cumprod)
        # SNR = α_t累乘 / (1 - α_t累乘)

        # https://arxiv.org/abs/2303.09556
        maybe_clipped_snr = snr.clone()
        # 这里可以对 snr 做一些裁剪操作,如果需要的话

        register_buffer('loss_weight', maybe_clipped_snr / snr)
        # 用于加权损失的系数

        # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False
        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
        # 根据 auto_normalize 决定是否对数据进行 [-1,1] <-> [0,1] 的转换

    def predict_start_from_noise(self, x_t, t, noise):
        """
        通过 x_t 和噪声,反推 x_0 的预测值
        x_0 = 1 / sqrt(alpha_cumprod) * x_t - sqrt(1 / alpha_cumprod - 1) * noise
        """
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """
        通过 x_t 和对 x_0 的预测值,反推噪声的预测值
        noise = (1 / sqrt(alpha_cumprod) * x_t - x_0) / sqrt(1 / alpha_cumprod - 1)
        """
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def q_posterior(self, x_start, x_t, t):
        """
        计算后验分布 q(x_{t-1} | x_t, x_0) 的均值和方差
        """
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        # 后验分布的均值
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        # 后验分布的方差
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        # 后验分布方差的对数(已做 clip)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, clip_x_start = False, rederive_pred_noise = False):
        """
        给定当前噪声图 x 和时间步 t,通过模型预测噪声 pred_noise,并得到对 x_0 的估计 x_start
        """
        model_output = self.model(x, t)
        # 模型输出,通常是预测噪声

        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
        # 如果需要对预测出的 x_0 做裁剪,则 partial(torch.clamp);否则恒等函数

        pred_noise = model_output
        # 这里把模型输出视为噪声预测
        x_start = self.predict_start_from_noise(x, t, pred_noise)
        x_start = maybe_clip(x_start)
        # 对 x_0 进行 [-1,1] 裁剪(可选)

        if clip_x_start and rederive_pred_noise:
            # 如果 x_0 被裁剪,为了更准确,需要重新计算一次噪声
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return pred_noise, x_start

    def p_mean_variance(self, x, t, clip_denoised = True):
        """
        计算从扩散过程中 p(x_{t-1} | x_t) 的均值和方差,用于反向采样
        """
        noise, x_start = self.model_predictions(x, t)
        # 模型预测噪声和 x_0

        if clip_denoised:
            x_start.clamp_(-1., 1.)
            # 默认会把 x_0 的范围裁剪到 [-1,1]

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start = x_start,
            x_t = x,
            t = t
        )
        # 计算后验分布的均值和方差
        # 这里的后验分布相当于 q(x_{t-1}|x_t, x_0)

        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int):
        """
        在反向扩散的某一个时间步 t,从 p(x_{t-1} | x_t) 采样
        """
        b, *_, device = *x.shape, x.device
        batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
        # 构造与批大小相同的时间张量

        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x = x,
            t = batched_times,
            clip_denoised = True
        )
        # 根据 x_t 计算后验均值和方差

        noise = torch.randn_like(x) if t > 0 else 0.
        # 如果 t > 0 则在采样时加噪声;如果 t=0,则不再加噪声

        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        # 采样公式: x_{t-1} = 均值 + 标准差 * 噪声

        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, shape, return_all_timesteps = False):
        """
        从纯噪声开始,逐步反向采样还原图像
        """
        batch, device = shape[0], self.betas.device
        # batch 大小, 使用存储在 buffer 中的 betas 的设备

        img = torch.randn(shape, device = device)
        # 初始从标准正态分布采样

        imgs = [img]
        # 用于保存采样过程中每个时间步的结果
        x_start = None

        ###########################################
        ## TODO: plot the sampling process ##
        ###########################################
        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            # 从 T-1 到 0 逐步反向采样
            img, x_start = self.p_sample(img, t)
            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
        # 如果 return_all_timesteps=True, 返回整个采样序列;否则只返回最终生成的图像

        ret = self.unnormalize(ret)
        # 将图像从 [-1,1] 转回 [0,1]
        return ret

    @torch.no_grad()
    def sample(self, batch_size = 16, return_all_timesteps = False):
        """
        对外提供的采样接口
        """
        image_size, channels = self.image_size, self.channels
        sample_fn = self.p_sample_loop
        # 默认使用 p_sample_loop 进行逐步采样
        return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)

    def q_sample(self, x_start, t, noise=None):
        """
        前向扩散:从 x_0 得到 x_t 的采样
        x_t = sqrt(α_cumprod) * x_0 + sqrt(1-α_cumprod) * noise
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        # 如果不指定噪声,则生成一个和 x_start 形状相同的高斯噪声

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    @property
    def loss_fn(self):
        return F.mse_loss
        # 训练时使用的损失函数,默认是 MSE

    def p_losses(self, x_start, t, noise = None):
        """
        在给定 x_0 以及随机的时间步 t 时,计算训练时的损失
        """
        b, c, h, w = x_start.shape
        noise = default(noise, lambda: torch.randn_like(x_start))

        # noise sample
        x = self.q_sample(x_start = x_start, t = t, noise = noise)
        # 前向扩散,将 x_0 添加噪声到 x_t

        # predict and take gradient step
        model_out = self.model(x, t)
        # 模型对 x_t 进行估计噪声

        loss = self.loss_fn(model_out, noise, reduction = 'none')
        # 计算 MSE 损失 (逐元素)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')
        # 在除 batch 之外的所有维度取平均 (即每个样本的损失)
        loss = loss * extract(self.loss_weight, t, loss.shape)
        # 乘以权重 (与 SNR 相关)

        return loss.mean()
        # 返回对整个 batch 的平均损失

    def forward(self, img, *args, **kwargs):
        """
        模块的前向调用接口,一般在训练时调用
        """
        b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        # 解包图像形状、设备以及定义的图像大小
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'

        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
        # 随机采样一个时间步 t 用于训练

        img = self.normalize(img)
        # 如果开启了 auto_normalize,则把 [0,1] 的图片映射到 [-1,1]

        return self.p_losses(img, t, *args, **kwargs)
        # 调用 p_losses 计算训练损失

path = './faces/faces'
# 数据所在的文件路径,这里假设所有训练图像都在 ./faces/faces 目录中

IMG_SIZE = 64
# 设置图像尺寸为 64x64

batch_size = 16
# 设置训练时的批大小为 16 张图像

train_num_steps = 10000
# 训练的总步数,指优化器更新(iteration)次数

lr = 1e-3
# 学习率 (learning rate),这里设置为 0.001

grad_steps = 1
# 梯度累积步数;若设置大于 1 则表示每累积一定次数的反向传播再进行一次优化更新

ema_decay = 0.995
# 指数移动平均 (EMA) 的衰减率,常用于在训练过程中平滑模型权重

channels = 16
# U-Net 的基础通道数,即第一个卷积层的通道数

dim_mults = (1, 2, 4)
# 用来指定 U-Net 不同下采样 / 上采样阶段的通道扩展倍数,  
# 最终网络结构中的通道数将按 (channels, 2*channels, 4*channels, ...) 的形式逐步增加

timesteps = 100
# 扩散过程中加噪声的时间步数 T;比如在 DDPM 中可以是 1000,这里设置为 100

beta_schedule = 'linear'
# beta 的调度方式(表示在扩散过程中 beta 的变化),此处设置为线性

model = Unet(
    dim = channels,
    dim_mults = dim_mults
)
# 实例化一个 U-Net 模型对象,输入的基本通道数为 16,  
# 会根据 dim_mults 逐步在网络层中增加通道数

diffusion = GaussianDiffusion(
    model,
    image_size = IMG_SIZE,
    timesteps = timesteps,
    beta_schedule = beta_schedule
)
# 将 U-Net 模型封装到 GaussianDiffusion 类中,  
# 并设置扩散过程中的一些参数(如图像大小、时间步数等)。  
# 该类会负责前向扩散(加噪)和反向扩散(去噪)的具体实现。

trainer = Trainer(
    diffusion,
    path,
    train_batch_size = batch_size,
    train_lr = lr,
    train_num_steps = train_num_steps,
    gradient_accumulate_every = grad_steps,
    ema_decay = ema_decay,
    save_and_sample_every = 1000
)
# 实例化一个 Trainer 类来管理训练流程:  
# - 使用 diffusion 模型进行前向与反向传播  
# - 每个 batch 的大小为 16  
# - 使用学习率 1e-3  
# - 总训练步数为 10000  
# - 每个 step 都更新梯度(grad_steps=1)  
# - EMA 衰减因子为 0.995  
# - 每 1000 步保存一次模型并进行一次采样

trainer.train()
# 开始训练,Trainer 内部会执行循环读取数据、前向计算、损失反传、优化器更新等流程。

运行环境:

accelerate                1.0.1
einops                    0.8.0
ema-pytorch               0.7.7 
matplotlib                3.5.1
multiprocess              0.70.15
numpy                     1.24.4
python                    3.8.19
pytorch                   2.4.0
pytorch-cuda              12.1
tqdm                      4.66.5

下图为模型在完成训练之后生成的动漫人脸图像:

 

从该结果可以看出,模型成功地学习到了二次元人脸的整体特征与色彩分布,生成的人像在发型、五官、配色等方面都有一定的多样性,说明扩散模型在此任务中具备一定的泛化能力。不过图像中仍存在一定程度的模糊、面部细节缺失或扭曲等现象,表明训练规模与网络容量可能还需要进一步优化,以获得更精细、更稳定的生成质量。

总结一下,扩散模型(Diffusion Model)通过在前向过程逐步向图像添加噪声、在反向过程逐步去噪的方式实现图像生成,具有相对稳定的训练过程和良好的生成多样性。它在高分辨率图像生成、条件生成(文本、语音、语义分割等)方面表现不错,且与自回归、GAN 等其他生成方法形成互补。未来发展方向包括更高效的采样策略、更灵活的条件控制、多尺度或多模态的融合,以及在更广泛的数据类型(视频、3D 等)上的应用和研究。

如果你还想学习更多的AI大模型知识,这里我也贴心的为大家准备了一份学习资料。无偿分享给大家,VX扫描以下二维码即可领取

👉CSDN大礼包🎁:[全网最全《LLM大模型入门+进阶学习资源包》免费分享](安全链接,放心点击)👈 https://blog.csdn.net/2401_85325726/article/details/147295734?spm=1001.2014.3001.5501

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值