DDPM代码复现

主要参考这个博主的代码https://lixudong.ink/2023/03/17/%e3%80%90%e7%bf%bb%e8%af%91%e3%80%91the-annotated-diffusion-model/(https://lixudong.ink/2023/03/17/%E3%80%90%E7%BF%BB%E8%AF%91%E3%80%91the-annotated-diffusion-model/)
上面链接如果不行,可以去他的主页:https://lixudong.ink/

# https://lixudong.ink/2023/03/17/%E3%80%90%E7%BF%BB%E8%AF%91%E3%80%91the-annotated-diffusion-model/
import math
from inspect import isfunction
from functools import partial

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F


# 代码块2:U-Net网络中的模块实现
# exists函数用来判断输入是否非None,是则输出true,否则输出false
def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# 若输入的val不是None,则直接返回val的值;
# 若输入的val是None: 若d是一个函数,那么返回d();若d不是一个函数,那么直接返回d的值.

def num_to_groups(num, divisor):
    groups = num // divisor  # //为整除符号,返回商的整数部分,向下取整。例如12//5 = 2 ; 20//3 = 6。groups为组数
    remainder = num % divisor  # %为取余符号,这里返回分组后的余数remainder
    arr = [divisor] * groups  # 数组操作,排列成一个元素值为divisor,共有groups个数的一维数组
    if remainder > 0:
        arr.append(remainder)  # 如果无法除尽,将余数连到上述数组的末尾。
    return arr


# 本函数是用来把一个数(num)变成尽量多的某个数的和,加上一个余数。
# 例如,num = 20 , divisor = 4, arr = [4,4,4,4];
# 再比如,num = 25 , divisor = 6 , arr = [6,6,6,6,1]
# 再比如,num = 17 , divisor = 2 , arr = [2,2,2,2,2,2,2,2,1]

class Residual(nn.Module):
    # 定义nn.Model的子类Residual

    def __init__(self, fn):
        # 定义子类Residual的初始化函数,self是子类的一个对象,fn是传入的参数
        super().__init__()
        # super()方法是通用的抽象调用父类的方法,这里是使用父类(nn.Module)的初始化方法__init__()来初始化子类对象
        self.fn = fn
        # 子类比父类多了一个fn,父类的__init__()函数无法初始化,因此用传入的参数fn来初始化

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
    # forward函数:定义了一个前向过程,使用残差块连接,即将输出值self.fn和输入值x相加,来避免梯度消失和梯度爆炸
    # *args是位置参数,**kwargs是关键字参数,位置参数是通过相对位置来对应传参的参数,而关键字参数是通过指定参数名来传参的参数
    # *args 用来接收任意数量的位置参数,它将这些参数打包成一个元组 args
    # **kwargs 用来接收任意数量的关键字参数,它将这些参数打包成一个字典kwargs, 即以“参数-参数值”的形式存放
    # 例如:def add_numbers(x, y): return x + y
    # 想通过位置参数来调用这个函数,则:add_numbers(2, 3)
    # 想通过关键字参数来调用这个函数,则:add_numbers(x = 2, y = 3)
    # 在这里采用*args 和 **kwargs的形式是为了通用性,即在不确定输入参数个数和形式的情况下依然可以使用这个函数


def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"), # scale_factor表示放大多少倍,mode为采样模式
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),
    )             # default这个用法感觉很有用


# 定义上采样函数Upsample,默认的输出通道数参数为None
# nn.Sequential():表示上采样的链式过程,先经过nn.Upsample(),再经过nn.Conv2d()
# nn.Upsample():内置的上采样函数,scale_factor为放大倍数(放大2倍),mode为采样模式(最近邻采样)
# nn.Conv2d():内置的二维卷积函数,dim是输入的通道数
# default表示如果dim_out非None(即传入了相应的参数,而非使用默认值)则dim_out为输出的通道数
# 否则dim作为输出的通道数

# 3为卷积核的大小(kernel_size),padding为要填充的数,这里表示填1。
# 因为卷积操作常常导致输出尺寸小于输入尺寸,为了保持输出尺寸与输入尺寸相同,
# 我们可以在输入的边界上添加一些额外的像素(通常是0,这里是1)来进行填充(padding)

def Downsample(dim, dim_out=None):
    # No More Strided Convolutions or Pooling
    # 这里的dim参数为输入图像的通道数,即channels
    # 将通道数c变为4倍,然后长和宽都降为一半
    '''
                  输入形状 - "b c (h p1) (w p2)":
                      b, c, h, w 分别代表批大小(batch size)、通道数(channels)、高度(height)和宽度(width)
                      (h p1) 和 (w p2) 表示将原始的高度和宽度维度分别分割成多个小块,p1 和 p2 是分割的大小
                      例如,如果 p1=2 和 p2=2,那么高度和宽度每两个像素分割成一块
                  输出形状 - "b (c p1 p2) h w":
                      b(批大小)维度保持不变,
                      (c p1 p2) 表示新的通道维度,原来的通道数 c 被扩展为 c * p1 * p2,
                      这是因为每个块里面的元素都成为h/p1,w/p2的一部分,h 和 w 维度变成了原来的一半,因为这里是以2x2 的块进行分割的

                  例子:如果输入是[1, 1, 6, 6],数据是1到36
                  1  2  3  4  5  6
                  7  8  9  10 11 12
                  13 14 15 16 17 18
                  19 20 21 22 23 24
                  25 26 27 28 29 30
                  31 32 33 34 35 36
                  [[[ 1,  3,  5],
                    [13, 15, 17],
                    [25, 27, 29]],

                   [[ 2,  4,  6],
                    [14, 16, 18],
                    [26, 28, 30]],

                   [[ 7,  9, 11],
                    [19, 21, 23],
                    [31, 33, 35]],

                   [[ 8, 10, 12],
                    [20, 22, 24],
                    [32, 34, 36]]]
              ''',
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),  # 分割图片并重排
        # 输入维度*4,因为上面分割成了4倍
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )

# 定义下采样函数Downsample,默认的输出通道数参数为None
# nn.Conv2d:完成图片分割之后,将结果进行卷积操作。


#代码块3:Position Embeddings ,将时间t,也就要加噪的程度进行位置编码
class SinusoidalPositionEmbeddings(nn.Module):
#Sinusiodal:正弦的
#意思是采用正弦函数对位置进行编码嵌入
    def __init__(self, dim):
        super().__init__()#调用父类的初始化函数
        self.dim = dim    # 嵌入维度(在unet里是通道数,为了方便估计就即表示输入图片维度,又表示通道数量了)

    def forward(self, time):
        # time是一个输入张量,time.device代表这个张量所在的设备类型(即是CPU还是GPU)
        # 这句话的意义在于获取time的设备类型作为后续计算的设备,保证一致性
        device = time.device
        # 方便运算的参数,确定half_dim是传入维度的一半(向下取整)
        # 一会做sin cos会把另一半补全
        half_dim = self.dim // 2
        # 下面两行是将传入的0-half_dim序列值映射到0-1之间
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        # time是[m,1]张量,embedding是[1,n]张量,形成一个[m,n]张量
        embeddings = time[:, None] * embeddings[None, :]
        # 将每个元素都sin和cos并拼接,对每一行的所有元素都先做sin,再做cos,然后把cos接在这一行sin的后面
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)

        return embeddings

# U-Net 模型的核心构建块 ResNet block
# 这个类其实就是重写了Conv2d类,使得它的卷积操作的权重更加标准化。
# 调用这个类其实就是做了一个更加标准化的卷积操作
class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """
    # 它继承自 nn.Conv2d(PyTorch 的标准卷积层类),它会自动继承 nn.Conv2d 的所有属性和方法
    # 在 WeightStandardizedConv2d 类中,没有明显地定义自己的构造函数(__init__),因此它直接使用了父类 nn.Conv2d 的构造函数
    # 重写了Conv2d的forward方法
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        # 获得权重的均值
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        # 获得权重的方差
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        # 对其标准化,确保权重具有0均值和单位方差
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        # WeightStandardizedConv2d这个类继承了Conv2d类,没有重新__init__方法,所以继承了Conv2d的方法
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        # 是将卷积操作中的通道维度进行分组,然后在每个组内进行归一化操作
        # 它的分组和batch无关,它会将batch中的每一个样本都进行操作
        self.norm = nn.GroupNorm(groups, dim_out)
        # Sigmoid Linear Unit 它的返回值是x*sigmoid(x)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)
        # 对张量x中的值进行缩放和平移操作
        # 综合来看,这一步貌似是引入时间信息
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""
    # 在函数的参数列表中使用 *,表示它后面的参数必须以关键字参数(keyword arguments)的形式传递
    # dim 和 dim_out 是位置参数(positional arguments),这意味着在调用函数时,可以直接按顺序传递值给这些参数
    # 后面的 time_emb_dim 和 groups 是关键字参数。这意味着在调用这个函数(或者说,在创建这个类的实例)时
    # 必须明确指定这些参数的名字,如 time_emb_dim=value 和 groups=value。
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            # 这是 Python 中的条件表达式(也称为三元操作符),它的基本结构是 A if condition else B
            # 如果 condition 为真,则表达式的值是 A,否则是 B。
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        # 如果有时间信息嵌入,就处理
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            # 将张量分割成特定数量的块。这里,它把 time_emb 沿着第二个维度(dim=1,即通道维度)分割成两个块
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)



# 注意力模块
# 第一个是常规的多头自注意力(是多头注意力机制)
# 第二个是线性注意力变体,它的时间和内存需求与序列长度成线性比例,这和常规注意力的二次方不同
class Attention(nn.Module):
    # heads是指原本的QKV被分为了几个小头
    # dim_head是指每个头的向量长度
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        # **是幂运算符,也就是计算dim_head的-0.5次方
        # scale是一个缩放因子,它的作用是对计算得到的查询向量 Q 进行缩放。在自注意力机制中,
        # 计算 Q 和 K 之间的相似度通常涉及到它们的点积。点积会随着维度的增加而增长,
        # 可能导致非常大的数值,这会使得 softmax 函数的梯度变得非常小(梯度消失问)
        self.scale = dim_head ** -0.5
        self.heads = heads
        # 每个QKV通道数(特征维度)
        hidden_dim = dim_head * heads
        # 一个卷积层,输入通道输是给的dim,输出通道数是hidden_dim*3
        # 用于将输入转换为查询(Q)、键(K)和值(V),每个Q、K、V都有hidden_个通道
        # 这里是先弄成QKV维度,等会再拆
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        # 将通道数再转为正常通道数
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        # 将通道数分割为3块
        # 张量最外层都有一个[],比如[[1,2,3],[4,5,6]],dim=0是代表[1,2,3],[4,5,6]这一层
        # 也就是batch所在,dim=1代表1,2,3这一层,也就是通道
        # chunk已经将to_qkv的返回值根据第一维度划分为3个张量了,只不过封装起来给了qkv,具体可以运行chunk.py文件
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # t从qkv取张量,然后再通过rearrange操作将通道拆分为头数,和每个头的通道数
        # 然后再讲x,y二维矩阵,压缩为一个一维向量,保持整个张量形状不变
        # 所以q,k,v张量的意思就是:b是batch,h是这个样本分为了多少头,c是每个头对应了多少通道,(x,y)是特征数量
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        # 对q进行缩放,防止qk乘积过大
        q = q * self.scale
        # torch.einsum 函数,它是一个强大的工具,用于指定和执行张量之间的特定操作,特别是那些涉及多维数据的操作
        # 在这个代码中它计算了查询(Q)和键(K)之间的点积,用于生成注意力得分
        # b是batch,h是有多少个头,d是每个头有多少通道,i和j是每个通道的特征数量
        # 对于每个批次 b 和每个头 h,q 的第 i 个位置和 k 的第 j 个位置上的向量在 d 维度上进行点积
        # q的b和h,和k的b和h相同的位置进行ij的相乘,einsum操作有它的公式
        # einsum 的行为完全由提供给它的公式字符串决定,b h d i, b h d j -> b h i j:这个公式明确指定了在哪些维度上执行操作
        # 在这个公式中,点积发生在 d 维度上,因为 d 是两个输入张量共有的唯一维度标识符,且在输出中并未出现
        # 根据公式,生成的是ij,所以要把d i转向,变成i d,这样i d 点乘 d j 才是i j
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # 执行的是一个数值稳定性操作,用于调整注意力得分矩阵 sim
        # sim.amax寻找最后一维的最大值,detach创建一个新的张量,但不需要梯度,sim再减去自己所对应的每部分里面的最大值
        # 这样保证没有很大的正数,保证了数值的稳定性,这样再用softmax函数比较好
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        # 对最后一维用softmax,让数值都变为0-1之间,且对应小团体内和为1
        attn = sim.softmax(dim=-1)
        '''
            下面这个张量操作比较复杂,它实现了注意力机制中的每个权重和对应的v里的每个元素相乘,然后再将所有的v相加操作
            首先它两个张量的b和h都是一样的,在对应的b和h维度进行操作
            先解释一下它最后形成的矩阵[b h i d]是啥意思,b h还是老意思,i d在图片上可以理解为每一个像素点i在和其它
            像素点做完注意力机制后的值,d其实是以前的通道数,在这里可以理解为在b h维度上i像素点去和其它像素点做了多少遍注意力
            i j和d j如何操作到了i d呢,它其实是首先固定i,然后让[i j]的j去和[d j]每一个对应位置的j相乘,然后把每一个d对应的所有乘积的值
            再加到一块儿形成一个值,这样就形成了i d,文字不好描述,看下面的过程
            当 j=0 时:
                results[0][1][2][0] += attn[0][1][2][0] * v[0][1][0][0]
                results[0][1][2][1] += attn[0][1][2][0] * v[0][1][1][0]
                ...
                results[0][1][2][9] += attn[0][1][2][0] * v[0][1][9][0]
            当 j=1 时:
                results[0][1][2][0] += attn[0][1][2][1] * v[0][1][0][1]
                results[0][1][2][1] += attn[0][1][2][1] * v[0][1][1][1]
                ...
                results[0][1][2][9] += attn[0][1][2][1] * v[0][1][9][1]
            其实就是模拟,每个权重和对应的v里的每个元素相乘,然后再将所有的v相加操作
        '''
        # 所以经过这个操作以后out为b h i d,i是图片的所有像素,d是每个像素做了多少遍注意力机制
        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        # 将所有像素i为原来的x y,再将 h d每一个batch样本有几个头,每个头有几个通道合并起来,也就是变为hidden_dim(每个batch样本的所有通道数)
        # 经过这个操作out变为 b hidden_dim x y
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        # 最后通过一个卷积欢迎为原始x的通道数
        return self.to_out(out)
        '''
            总结一下这个多头注意力操作:首先在__init__中定义了heads和dim_head,以及它们的乘积hidden_dim(QKV的通道数)
            然后定义了一个生成qkv的卷积操作,就是将输入通道变为了hidden_dim通道数,当然这里乘了个3,在forward中再分别分给QKV,然后又定义了个
            将hidden_dim通道数转为输入通道数的卷积操作
            接下来是forward过程,首先获得qkv,然后将q进行缩放,防止QK乘积过大,导致梯度出现问题,然后用einsum操作进行QK乘积,再softmax
            然后再将softmax的结果和v相乘,最后再将通道数转回
        '''
# 线性注意力机制的一个变体
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        # 对倒数第二维做softmax会使分母下面变成倒数第二维度下的所有数
        # softmax_weight[i][j][h] = exp(q[i][j][h]) / sum(exp(q[i][j]))
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        # 进行KV计算,
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        # 然后与Q做点积
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        # 输出
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)



# DDPM 作者将 U-Net 的卷积层/注意力层与 group normalization 交织在一起
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


# Unet网络
class Unet(nn.Module):
    def __init__(
            self,
            dim,  # 结合上下文来看,这个dim貌似是通道数量
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8), # 用来改变通道数量
            channels=3,  # 输入图像的通道数
            self_condition=False,
            resnet_block_groups=4, # 指定ResNet块的组数
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)
        # 如果init_dim存在就用它,不存在就返回dim
        init_dim = default(init_dim, dim)
        # 将输入的通道数变为init_dim通道数
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)  # changed to 1 and 0 from 7,3
        # 将dim变为它的1,2,4,8倍,在Python中,星号(*)操作符可以将一个可迭代对象(如列表、元组等)展开为位置参数列表
        # 在这个例子中,它将map()函数的结果转换为列表的一部分,Python内置函数,接受一个函数和一个或多个可迭代对象作为输入
        # 并返回一个由该函数应用于这些可迭代对象的每个元素后得到的新值组成的可迭代对象
        # 所以这个句话的意思就是取dim_mults中的元素为m,然后和dim相乘,最后用*展开
        # 若dim为28,则dims为[28, 28, 56, 112, 224]
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        # zip是接受一个或多个可迭代对象作为参数,并返回一个元组组成的迭代器,其中每个元组包含了来自各个可迭代对象的一个元素
        # dims[:-1]提取dims列表除最后一个元素外的所有元素。例如,如果dims是[8, 16, 32, 64],则dims[:-1]为[8, 16, 32]
        # dims[1:]提取dims列表从第二个元素开始的所有元素。例如,如果dims是[8, 16, 32, 64],则dims[1:]为[16, 32, 64]
        # 将zip()返回的迭代器转换为列表
        # 根据上下文,如果dim为28,则in_out为[(28, 28), (28, 56), (56, 112), (112, 224)]
        in_out = list(zip(dims[:-1], dims[1:]))
        # partial的功能就是将一个类或者函数的参数预先固定,形成一个有了一部分参数的那个类或函数
        # 这是 partial 函数返回的新函数,它实际上是一个预配置了 groups 参数的 ResnetBlock 类构造器
        # 每次调用 block_klass 时,都相当于调用 ResnetBlock,但不需要再次指定 groups 参数,因为它已经被 partial 预设了
        # 总的来说,这行代码创建了一个定制版本的 ResnetBlock 构造器,其中 groups 参数已经设置好了
        # 这样做的好处是简化了后续对 ResnetBlock 的多次调用,使代码更加清晰和简洁
        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings

        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            # 时间位置编码,先传入张量大小,使用时再传入时间
            SinusoidalPositionEmbeddings(dim),
            # 传入的是[m,dim],nn.linear会让[m,dim]乘以权重矩阵[dim,time_dim]张量,最后得到[m,time_dim],当然还有偏置向量[time_dim]
            nn.Linear(dim, time_dim),
            # 激活函数,相比于relu,它对于负数输出一个-1到0之间的值,这样会有更平滑的梯度
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # 下采样和上采样层
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        # 用来在下面判断循环是否是最后一次
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            # 判断循环是否是最后一次
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        # 两个残差块,貌似就是Unet在下采样和上采样之前,这一层要做的两次卷积?
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        # 线性注意力
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        # 进行下采样,
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )
        # 将dims最后一个元素给mid_dim
        mid_dim = dims[-1]

        # unet最下面过程
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        # 上采样
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):

            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [   # 上采样过程,把对应下采样层加进来
                        # 在上面已经执行了Unet最下面两层了,这里不应该先执行Upsample吗?
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )
        # 最后输出的通道数
        self.out_dim = default(out_dim, channels)
        # 这两步是要干什么?
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)
        # 先将通道数变为Unet需要的通道数
        x = self.init_conv(x)
        r = x.clone()
        # 获得时间向量
        t = self.time_mlp(time)

        h = []
        # Unet下采样
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)
        # Unet最底层
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            # hop移除并返回列表的最后一个元素,将下采样对应层叠加过来
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        # 最后又做了一遍残差
        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)




# 定义前向扩散过程中的β,下面是四种定义,选一种用
## 使用余弦衰减的时间调度策略
def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

# 线性时间调度策略
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

# 平方根时间调度策略
def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2

# Sigmoid 时间调度策略
def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

# 定义一些必要的β、α相关变量
timesteps = 300

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas
alphas = 1. - betas
# α累乘
alphas_cumprod = torch.cumprod(alphas, axis=0)
# alphas_cumprod[:-1]取alphas_cumprod前n-1个数,F.pad对张量进行扩充,(1, 0)在张量左侧添加一个元素,右侧不添加,添加的是1.0
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# α开方
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
# 累乘开方
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
# 1-α累乘
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
# 通过Xt求Xt-1公式部分
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# a是传进来的α相关变量
def extract(a, t, x_shape):
    # 获取有多少t
    batch_size = t.shape[0]
    # 使用 torch.gather 从 a 中沿最后一个维度(-1)提取 t 中指定的元素
    out = a.gather(-1, t.cpu())
    # (1,):这是一个只包含一个元素 1 的元组,len(x_shape) 是 x_shape 的长度,即数据 x_shape 的维度数
    # 将上述单元素元组 (1,) 重复 len(x_shape) - 1 次,生成一个新的元组
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)



# 从网上获取一张图片
from PIL import Image
import requests
import matplotlib.pyplot as plt
# url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# image = Image.open(requests.get(url, stream=True).raw)
# plt.imshow(image)
# plt.show()

# 将图片转为张量,并且让它的值在像素值在-1到1之间
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

image_size = 128
transform = Compose([
    Resize(image_size),
    CenterCrop(image_size),
    ToTensor(),  # turn into Numpy array of shape HWC, divide by 255
    # 将张量中的像素值从[0,1]转为[-1,1]
    Lambda(lambda t: (t * 2) - 1),

])
# x_start = transform(image).unsqueeze(0)
# print(x_start.shape)

# 逆向变换,它接收一个值在 [−1,1] 之间的 PyTorch 张量,并转为 PIL 张量
import numpy as np

reverse_transform = Compose([
    Lambda(lambda t: (t + 1) / 2),
    Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
    Lambda(lambda t: t * 255.),
    Lambda(lambda t: t.numpy().astype(np.uint8)),
    ToPILImage(),
])
# plt.imshow(reverse_transform(x_start.squeeze()))
# plt.show()


# 定义前向扩散过程
# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    # 从sqrt_alphas_cumprod中提取下标为t的一些量,并整形成x_start的形状
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    # 同上
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    # 加噪
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

# 测试加噪
def get_noisy_image(x_start, t):
    # add noise
    x_noisy = q_sample(x_start, t=t)

    # turn back into PIL image
    noisy_image = reverse_transform(x_noisy.squeeze())

    return noisy_image
t = torch.tensor([100])
# x = get_noisy_image(x_start, t)
# plt.imshow(x)
# plt.show()


# 模型损失
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)
    # 绝对误差损失
    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    # 均方误差损失
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    # Huber 损失(平滑 L1 损失)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss


from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset("fashion_mnist")
print(dataset)
''' dataset打印
外层字典:DatasetDict 对象包含了多个键值对,其中键通常是数据集的不同部分,这里是"train" 和 "test"
train和test每个键对应的值是一个 Dataset 对象,Dataset 对象本身也像一个字典,它包含了数据集的特征(如 image 和 label)作为键
"image" 键对应的值是一个包含所有图像数据的列表,"label" 键对应的值是一个包含所有标签数据的列表
当访问例如 dataset["train"]["image"] 时,实际上是在访问 train 分割中的所有图像数据
访问 dataset["train"]["label"] 将给您 train 分割中的所有标签数据
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})
'''
# 查看 train 数据集的前 10 个数据
first_ten_train_data = dataset["train"][:10]

# 打印前 10 个数据
for i in range(10):
    image, label = first_ten_train_data["image"][i], first_ten_train_data["label"][i]
    # print(f"Image {image}:")
    # print(f"Label: {label}")
    # 这里可以添加代码来显示图像,例如使用 matplotlib
    # 但需要先将图像从 PIL/Image 格式转换为 numpy 数组

image_size = 28
channels = 1  # 这个数据集是灰度图像,所以通道是1,数据集本身图像规模为[1,28,28]
batch_size = 128


from torchvision import transforms
from torch.utils.data import DataLoader

# define image transformations (e.g. using torchvision)
transform = Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1)
])

# define function
def transforms(examples):
    # 将键为examples的图像转换为灰度图,然后再赋值给新键pixel_values
    examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
    # 删除键名为imgae的数据
    del examples["image"]

    return examples

# with_transform 是一个方法,用于将一个转换函数应用于整个数据集
# 移除列名为label的列
# with_transform这个函数会自动将dataset里的数据也就每一个键对应的值都传递给transforms
# 所以examples就是DatasetDict中的每一个Dataset
transformed_dataset = dataset.with_transform(transforms).remove_columns("label")
# print(transformed_dataset)
''' 经过处理后
DatasetDict({
    train: Dataset({
        features: ['image'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['image'],
        num_rows: 10000
    })
})
'''
# create dataloader
# transformed_dataset["train"]是将Dataset传入给了DataLoader,它能自动规划分配batch_size
# print(transformed_dataset["train"].__len__())
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
batch = next(iter(dataloader))
# print(batch.keys())


# 采样过程
# 单步反向传播
@torch.no_grad()
def p_sample(model, x, t, t_index):
    # 提取时间步t对应的数据
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    # Xt戴Xt-1公式
    model_mean = sqrt_recip_alphas_t * (
            x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    # 是否是最终图像
    if t_index == 0:
        return model_mean
    else:
        # 得到这一步的后验方差,这里也就是给预测得到的Xt-1图像加点噪声,也就是那个z
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    # img是数据张量形状
    img = torch.randn(shape, device=device)
    imgs = []
    # tqdm 可以显示一个进度条,实时更新以反映当前循环的进度,包括已处理的元素数量、总元素数量、循环的估计剩余时间和当前迭代的速度
    # desc 参数用于设置进度条前的描述性文本。它为进度条提供了一个简短的说明或标题
    # total 参数用于指定迭代的总次数,即进度条的总长度。当 tqdm 知道总迭代次数时,它可以更准确地计算并显示进度百分比和剩余时间
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        # torch.full((b,), i, device=device, dtype=torch.long) 生成一个与批次大小 b 相同长度的张量,其中每个元素都是当前时间步 i
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        # 如果一个张量存储在 GPU 上(即 CUDA 张量),不能直接将其转换为 NumPy 数组,因为 NumPy 不支持 CUDA 操作
        imgs.append(img.cpu().numpy())
    return imgs
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))


# 训练模型
from pathlib import Path
# 函数的目的是将一个整数 num 分割成若干个组,每组的大小不超过 divisor
def num_to_groups(num, divisor):
    # num:需要分割的整数     divisor:每个组的最大大小
    # 计算 num 可以被 divisor 整除多少次,得到的 groups 是分组的数量
    groups = num // divisor
    # 这是最后一个组可能包含的元素数量
    remainder = num % divisor
    # 创建一个列表,其中包含 groups 个元素,每个元素的值为 divisor
    arr = [divisor] * groups
    #
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results/original_result")
results_folder.mkdir(exist_ok=True)
# 保存和采样间隔
save_and_sample_every = 1000

from torch.optim import Adam
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
# 加载之前训练的模型
model_save_path = "./model_save/original_model_save/ddpm_model.pth"
model.load_state_dict(torch.load(model_save_path))
model.to(device)
optimizer = Adam(model.parameters(), lr=1e-3)

from torchvision.utils import save_image

epochs = 1
for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()

        batch_size = batch["pixel_values"].shape[0]
        batch = batch["pixel_values"].to(device)

        # Algorithm 1 line 3: sample t uniformally for every example in the batch
        # 创造一个和batch_size规模一样的t
        t = torch.randint(0, timesteps, (batch_size,), device=device).long()
        # 做损失
        loss = p_losses(model, batch, t, loss_type="huber")

        if step % 100 == 0:
            print("Loss:", loss.item())
        # 梯度传递
        loss.backward()
        # 更新参数
        optimizer.step()
        # if step == 100:
        #     samples = sample(model, image_size=image_size, batch_size=64, channels=channels)
        #
        #     # show a random one
        #     random_index = 5
        #     sampled_image = samples[-1][random_index].reshape(image_size, image_size, channels)
        #     plt.imshow(sampled_image, cmap="gray")
        #     plt.savefig("./results/original_result/1-5.png")
        #     plt.show()
        # save generated images
        # if step != 0 and step % save_and_sample_every == 0:
        #     milestone = step // save_and_sample_every
        #     batches = num_to_groups(4, batch_size)
        #     all_images_list = list(map(lambda n: sample(model,image_size=image_size, batch_size=n, channels=channels), batches))
        #     all_images = torch.cat(all_images_list, dim=0)
        #     all_images = (all_images + 1) * 0.5
        #     save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow=6)

# 保存模型
# 模型保存路径
model_save_path = "./model_save/original_model_save/ddpm_model.pth"
# 保存模型状态字典
torch.save(model.state_dict(), model_save_path)


# 采样
# sample 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)
# show a random one
random_index = 5
# samples[-1]代表最取300个阶段的最后一个阶段,也就是最好的图像,random_index是在一个batch_size=64
# 也就是一次采样64个图片的情况下选择展示第几个图片
sampled_image = samples[-1][random_index].reshape(image_size, image_size, channels)
plt.imshow(sampled_image, cmap="gray")
plt.savefig("./results/original_result/6-5.png")
plt.show()
for i in range(64):
    if i%5==0:
        sampled_image = samples[-1][i].reshape(image_size, image_size, channels)
        plt.imshow(sampled_image, cmap="gray")
        path = str(f"{i}")
        plt.savefig(f"./results/original_result/6-{path}.png")
        plt.show()
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值