AIGC入门(二)从零开始搭建Diffusion!(上)

前言

图像生成一直是CV界经久不衰的话题。自从2020年的DDPM(Denoising Diffusion Probabilistic Models)横空出世之后,几乎所有的图像生成领域的细分方向都被横扫一空。此之后,Latent-Diffusion、DiT等模型横空出世,也支撑起了包括Stable Diffusion系列、DALL-E系列、Sora系列的工程开发,并展示出了不少震撼人心的效果。

万丈高楼平地起,正因为如此,狠狠地了解最原始的DDPM是很有必要的!笔者在空闲时间,决定狠狠拆解相关源码,并从零开始重新写一个DDPM!

DDPM看起来数学推导十分复杂,其实并没有那么可怕。在编程方面所使用到的数学知识没有超过大学公共课的概率论的内容。详情可以参考这篇文章(这位作者写得非常好,让笔者这种数学稀碎的人都能看懂并理解其中的思想)。

笔者将这篇文章的叙述重点放在了代码上,故很多的介绍重点都在代码的注释里。为了更好的阅读体验,请最好用电脑阅读!

最后完成的代码组成结构如下。我们将逐个文件开始依次入手编写。

dataset是一个小型数据集,后续会有介绍。results包含的是相关采样结果。

话不多说,现在立马开始!

(注意:我们的这个只是一个学习级别的Demo,因此数据集、模型参数、训练批次都不会很大。主要是从代码入手,跑通整个DDPM的架构。因此最终的效果请不要介意。)

  • 这篇文章旨在从代码级别简略介绍DDPM的结构特征。中间会穿插一些知识点方便记忆。受文章字数限制,本文分上下两篇。

  • 本篇文章代码参考了这篇代码,同时也参考了他人解析。笔者在此基础上进行了较大的修改,使得代码阅读更加适合新手。更多详细的原理细节,我希望读者能够去阅读详细的阅读链接中的文档,并且一定要去阅读原论文

  • 这期文章面向的是想要入门DDPM的新手,希望这篇文章对你有所帮助。大佬可以指出其中的错误,感谢不尽!另外,用电脑阅读的效果比用手机阅读的效果要好上不少。

一、基本过程理解

让笔者来做个简介,简单介绍一下DDPM的流程,不然的话写起来代码都不知道是干啥的。

DDPM不像Transformer一样是一个纯Architecture的工作。我们都知道它包含了这样的几个过程:

  1. 一个前向加噪过程。我们需要不断的往一张图片里添加噪声,直到它最终最终变成高斯噪声。

  2. 一个需要学习的反向去噪过程。我们训练好的模型将不断的给一张纯噪声去噪,最终变成我们需要的图案。

让我们稍微看一下!

1、前向加噪过程

前向加噪过程是一个马尔可夫过程,我们的每一步x_t都只依赖于上一步x_{t-1}

假设x_0是我们的真实图像,我执行 𝑡 步加噪后的带噪声的图像为x_t,那么我的前向扩散过程可以用这样的公式来描述:

x_t=\sqrt{1-\beta_t}x_{t-1}+\sqrt{\beta_t}\epsilon_{t-1}

用条件概率的方向去看,我们其实就是正在计算q(x_t|x_{t-1})

注意,我们的\beta_t是一个负责进行调节的参数。他满足0<\beta_{t}<1,负责为我们的加噪过程添加一点的扰动。而\epsilon_{t-1}是一个从标准高斯分布中采样的噪声。注意:我们每一次加噪的时候,都会进行一次采样,获得不同的 𝜖 !

我们在实操中的\beta_t比较小,还是让x_t的信息主要落在x_{t-1}

在前向加噪的过程中,我们倾向于让t变得很大很大,这样x_t就倾向于高斯分布啦!这整个其实就是前向加噪过程。

2、前向加噪过程的简化

但是,在训练时,一步步扩散的方式真的太过于麻烦了!这样子将大大增加训练的复杂度。想象一下,假设我么你的前向加噪过程有1000步( 𝑇=1000 ),那我们要训练反向去噪的第2步时,必须先前向加噪999步!多麻烦啊!

因此,原论文作者就想,如果q(x_t|x_{t-1})过于麻烦,那可不可以直接求得q(x_t|x_{0})?这样,当我们有了一张真实图片时,就可以直接获得某一个t对应的加噪图像了。

受于篇幅限制,笔者将直接给出q(x_t|x_{0})的推导结果。推导后的表达式如下:

x_t=\sqrt{\bar{\alpha}_t}x_{0}+\sqrt{1-\bar{\alpha}_t}\bar{\epsilon}_{t}

其中,\alpha_t = 1-\beta_t, \bar{\alpha}_t=\prod_{i=1}^{t}\alpha_i 。而\bar{\epsilon}_{t}就是我每t步都采样高斯分布中的噪声的均值,也可以视作单独采样得到的一个高斯噪声。这也是我们要在代码中实现的真正公式。

3、反向去噪过程

反向过程,实际上就是从一张高斯分布的图像出发,反向的得到我们的真实图像。也就是说,现在,我们已经有了一张类似于高斯噪声的x_t,我们要将他一步一步去噪,并生成我们的最终真实图像x_0。一步一步推导的话,我们其实就是要求q(x_{t-1}|x_{t})

但实际上,这个求不出来。假设我们直接在最原始的前向加噪的式子中倒来倒去,我们其实可以得到这样一个式子:

x_{t-1}=\frac{x_t}{\sqrt{1-\beta_t}}-\frac{\sqrt{\beta_t}}{\sqrt{1-\beta_t}}\epsilon_{t-1}

但我们发现,这个\epsilon_{t-1}是得不到的。我们所有采样的\epsilon_t在反向过程中都是得不到的,因为在实际操作的过程中,这个东西总是会受到x_{t-1}的影响的。

因此,作者想到的最好的一个方法,就是用一个网络 \epsilon_{\theta}(x_t,t),去拟合我们的这个噪声。那么上面这个式子就可以写成这样:

x_{t-1}=\frac{x_t}{\sqrt{1-\beta_t}}-\frac{\sqrt{\beta_t}}{\sqrt{1-\beta_t}}\epsilon_{\theta}(x_t,t)

这其实就是我们的反向去噪过程。我们的网络本质上就是要学习一个去噪器,它将在这个过程中按照不同的时间步去除噪声,从而将我们的图像从噪声中“还原”出来。

4、反向去噪过程的简化

注意:如果按照我们上述的正常的反向去噪过程,那么,我们将面临简化前的问题:即训练太麻烦了!如果我们要执行去噪的第1000步时,我们必须要让我们的去噪器先去噪999步!那我们的训练要什么时候才是个头?

参考我们之前正向加噪过程中得到的那个一步到位的式子,即q(x_t|x_{0}) 。那么,在反向去噪过程中,我们能否将这个q(x_t|x_{0})结合起来,得到一个“一步加噪,然后降噪”的训练过程?

当然可以!故论文作者就找到了一个平替:q(x_{t-1}|x_{t},x_0)。即训练这个去噪器的时候,我的目标图像x_0是已知的,并且加噪t步得到的x_t也是已知的,我们便可以直接由这两者出发,去直接找到怎么生成第t-1步的图像啦!

同样,由于篇幅关系,笔者也不推导了。通过一系列计算,有公式如下:

x_{t-1}=\frac{x_t}{\sqrt{\alpha_t}}-\frac{\beta_t}{\sqrt{\alpha_t(1-\bar{\alpha}_t)}}\bar{\epsilon}_t+\frac{(1-\bar{\alpha}_{t-1})\beta_t}{1-\bar{\alpha}_{t}}z_t

其中z_t是一个在高斯噪声中采样得到的噪声,其含义同加噪过程中的\epsilon_t是一样的,一个已知量。

注意,这个表达式中的x_0已经使用q(x_t|x_{0})直接替换过啦,因此表达式中只含有x_t

这里的\bar{\epsilon}_t仍然是需要我们的去噪器去拟合的。因此,这个反向去噪过程我们可以改写成下面这样:

x_{t-1}=\frac{x_t}{\sqrt{\alpha_t}}-\frac{\beta_t}{\sqrt{\alpha_t(1-\bar{\alpha}_t)}}\bar{\epsilon}_{\theta}(x_t,t)+\frac{(1-\bar{\alpha}_{t-1})\beta_t}{1-\bar{\alpha}_{t}}z_t

其中,\bar{\epsilon}_{\theta}(x_t,t)就是我们去噪器要干的活了。

接下来,就该让我们真正的去撰写相关代码了!出发!

二、去噪器的组件设计(Model.py)

既然我们知道反向去噪过程就是让我们的去噪器去拟合噪声并完成降噪,那么,先来设计我们的去噪器就显得很有必要。

先导入我们的包:

import math
import torch
from torch import nn, einsum
import torch.nn.functional as F
from inspect import isfunction
from functools import partial
from einops.layers.torch import Rearrange

在论文中的去噪器的设计使用的是U-Net结构。这个医学图像分割领域的小鼻祖,相信大家都很熟悉了。我将其相关的原论文中的图放在下面:

有上下采样两个部分。每一采样层之间都有直连通道。

然而,作者在这里进行了很多方面的修改。其中包括添加了残差网络结构,注意力机制等。笔者将按照组件的形式分开来讲。可能按照组件的方式很难看懂,但是到了最后的Unet的组装这一步,就会茅塞顿开了!

1、上下采样部分

首先,我们需要一个辅助函数,来保证我的输入一定有值:

# 辅助函数;
def default(val, d):
    if val is not None:
        return val # 如果val存在,就返回val;
    
    # 如果val不存在,就检查d是否是一个函数;
    # 如果是函数就调用返回,不是函数就直接返回值;
    return d() if isfunction(d) else d

然后,我们介绍上采样。这里的上采样其实就是一个上采样层+一个卷积层。长这样:

# 上采样部分;
def Upsample(dim, dim_out=None):
    return nn.ModuleList([
        # 上采样层,缩放因子扩大两倍;
        nn.Upsample(scale_factor=2, mode="nearest"),
        # 卷积;
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1)])

很简单嘛!再来介绍一下我们的下采样函数:

# 下采样部分;
def Downsample(dim, dim_out=None):
    return nn.ModuleList([
        # 表示将批次大小为b,通道数为c,高度为h,宽度为w的特征图,
        # 按p1 和p2的值(这里都是2)来重排。
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        # 卷积;
        nn.Conv2d(dim * 4, default(dim_out, dim), 1)])

这里有一个很老牌的函数Rearrange(),其实我们可以把它看作是一个类似于view()的函数,起到一个形状整理的作用。 dim为输入维度,dim_out为输出维度。

2、时间编码嵌入

这里其实就是借鉴了Transformer中的位置编码的思想,将一个带数值的信息通过正弦编码的方式,来捕获更高维的信息特征。相应的代码如下:

# 对于时间的位置编码;
class Time_Positional_Encoding(nn.Module):
    def __init__(self,dim):
        super(Time_Positional_Encoding,self).__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device # 同步设备是GPU还是CPU;
        half_dim = self.dim // 2 # 将时间位置编码的维度除上2;

        # 同transformer一样的位置编码的计算公式和方法;
        TPE = math.log(10000) / (half_dim - 1)
        TPE = torch.exp(torch.arange(half_dim, device=device) * -TPE)
        TPE = time[:, None] * TPE[None, :]
        TPE = torch.cat((TPE.sin(), TPE.cos()), dim=-1)
        return TPE

然后,我们就要去组装以下我们在单层通道中的ResNet块了。

3、ResNet块设计

在设计ResNet块前,这里先使用了一种叫做权重标准化的初始化方式,去初始化我们的二维卷积层。实际上,也就是将其权重初始化为均值为0,方差为1的一个分布状态。相应的代码如下:

# 获得标准化的权重初始化的卷积层;
# 相当于对每个输出通道的初始权重做归一化处理。
class WeightStandardizedConv2d(nn.Conv2d):
    def forward(self, x):
        # eps为防止方差为0的“保险”;
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = weight.mean(dim=[1, 2, 3], keepdim=True)
        var = weight.var(dim=[1, 2, 3], keepdim=True, unbiased=False)
        normalized_weight = (weight - mean) / torch.sqrt(var + eps)
        
        # 返回权重初始化后的二维卷积层;
        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

当我们拥有了一个初始化的权重之后,我们可以先来构造一个block块。这个block块将作为我们的基础单元,参与到整个的ResNet块的构成中:

# 一个函数块;
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super(Block,self).__init__()
        # 标准化卷积层;
        self.StdConv2d = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        # 归一化层;
        self.norm = nn.GroupNorm(groups, dim_out)
        # 激活层;
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.StdConv2d(x)
        x = self.norm(x)
        # 将时间作为调整信息嵌入到模块中来;
        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.act(x)
        return x

我们很容易就注意到,这里存在一个scale_shift。这是干啥的?实际上,这个缩放和偏移量是由编码后的时间步t所带来的。它的缩放和偏移代表了时间步t所携带的信息,并对我们整体的去噪器产生影响。

接下来,构建我们整体的ResNet块:

# 一个残差网络块;
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super(ResnetBlock,self).__init__()
        # 初始化 self.mlp;
        if time_emb_dim is not None:
            # 如果 time_emb_dim 存在,创建一个包含 SiLU 激活和线性变换的序列;
            self.mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, dim_out * 2))
        else:
            # 如果 time_emb_dim 不存在,self.mlp 为 None;
            self.mlp = None

        # 两个Block块;
        self.block1 = Block(dim, dim_out, groups=groups)

        self.block2 = Block(dim_out, dim_out, groups=groups)
        # 一个卷积层;
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    # 前向通道;
    def forward(self, x, time_emb=None):
        scale_shift = None
        if (self.mlp is not None) and (time_emb is not None):
            time_emb = self.mlp(time_emb)
            # 重塑成4维,方便进行卷积;
            time_emb = time_emb.unsqueeze(-1).unsqueeze(-1)
            # print("time_emb shape:",np.shape(time_emb))
            # 使用chunk方法,将其在channels维度上将其分割为两个维度;
            scale_shift = time_emb.chunk(2, dim=1) 

        h = self.block1(x, scale_shift=scale_shift)
        # print("h in ResBlock1:",np.shape(h))
        h = self.block2(h)
        # print("h in ResBlock2:",np.shape(h))
        return h + self.res_conv(x)

我们可以很容易的就看到,在这里面的时间将再经过一个激活层和线性层以增加可学习参数,并在之后通过前述的缩放偏移作为时间信息添加到我们的整个去噪器中。

同样,在单层,作者也广泛吸收了Transformer的美好思想,在构建单层信息传递时,添加了注意力模块。下面将进行介绍。

4、注意力模块

下面是自注意力模块的相关代码:

# 添加自注意力模块;
class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(Attention,self).__init__()
        # 缩放因子,用于查询张量的缩放;
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads# 隐藏层的维度;
        # 卷积层,用于生成查询、键和值张量;
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        # 输出用的卷积层;
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape# 获取输入的维度信息;
        # 通过卷积层生成查询、键和值张量,并调整形状;
        qkv = self.to_qkv(x).view(b, self.heads, -1, 3, h * w)
        q, k, v = qkv.unbind(dim=3)
        q = q * self.scale

        sim = torch.matmul(q.transpose(-2, -1), k)# 计算查询和键张量之间的相似度;
        sim = sim - torch.max(sim, dim=-1, keepdim=True)[0]
        attn = torch.softmax(sim, dim=-1)

        # 根据注意力权重和值张量计算输出;
        out = torch.matmul(attn, v.transpose(-2, -1))
        out = out.transpose(-2, -1).contiguous().view(b, -1, h, w)
        return self.to_out(out)

我们可以看见,这里的自注意力的Q、K、V是通过对一个卷积层输出的结果来做的。

作者也添加了一个线性注意力机制在其中。相应代码如下:

# 添加线性注意力层;
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention,self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads # 计算隐藏层维度;
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        # 输出层卷积和归一化;
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1), 
            nn.GroupNorm(1, dim)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

之后,我们需要一个归一化层,来将我们的信息进行归一化处理。

5、GroupNorm层和Residual结构

这个层的构造比较简单,因此不需要进行介绍。稍微看下就好:

# Group normalization;
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

Residual结构也是老生常谈了:

# 残差结构;
class Residual(nn.Module):
    def __init__(self,fn,dropout=0.1):
        super(Residual,self).__init__()
        self.fn = fn
        self.dropout = nn.Dropout(dropout)

    # *args用来传递任意数量的值,**kwargs用来传递任意数量的键值;
    def forward(self,x,*args, **kwargs):
        return x + self.dropout(self.fn(x,*args, **kwargs))

到这里,所有的组件就全部构造完成!接下来,就到我们的整一个去噪器的网络组装啦!

三、去噪器U-Net结构的组装(Model.py)

构建我们的网络结构,笔者打算将他的全部拆开来讲,也会让读者理解的更清晰一些。我们首先看看U-Net的一个类的大致架构:

class Unet(nn.Module):
    def __init__(
        self,
        dim, # 特征的维度;
        init_dim=None, # 初始化的特征维度;
        out_dim=None, # 输出结果的特征维度;
        dim_mults=(1, 2, 4, 8), # 每一个下采样步骤中的特征维度的倍数;
        channels=3, # 输入图像的通道数,默认为3(RGB);
        self_condition=False,# 是否自我条件化,用于控制输入通道数;
        resnet_block_groups=4, # ResnetBlock的组数;
    ):
        ...
    def forward(self, x, time, x_self_cond=None):
        ...

可以看到,我们能够很简单的将其按照init和forward来介绍。首先,先来从init吧开始剖析吧!

1、模型初始化(__init__)

首先,让我们看看模型的一些参数是怎么赋值的:

super(Unet,self).__init__()

self.channels = channels
self.self_condition = self_condition
time_dim = dim * 4 # 时间嵌入的维度;
input_channels = channels * (2 if self_condition else 1)# 根据条件化标志计算输入通道数;
init_dim = default(init_dim, dim)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]# 计算每个下采样步骤的特征维度;
in_out = list(zip(dims[:-1], dims[1:]))# 创建输入输出维度对(每一个采样层的dim_in 与 dim_out);

self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)# 初始卷积层;
block_klass = partial(ResnetBlock, groups=resnet_block_groups)# 创建ResNet块,即有这么多的块组装的网络层;
self.time_mlp = nn.Sequential( # 时间嵌入层;
    Time_Positional_Encoding(dim),
    nn.Linear(dim, time_dim),
    nn.GELU(),
    nn.Linear(time_dim, time_dim),
)

我们注意到,时间嵌入维度实际上是我们规定的正常输入图片维度的4倍。

我们还设计了一个时间步的整体编码。一个t进入后,将通过TPE先编码到正常维度,再通过线性层扩充我们的可学习参数,并且中间激活了一下。

之后,就要开始构建我们的上下采样层啦!

self.downs = nn.ModuleList([]) # 下采样部分;
self.ups = nn.ModuleList([]) # 上采样部分;
num_resolutions = len(in_out) # 上下采样层数;

首先,构建我们的下采样层:

########################### 开始构建下采样层:###########################
for ind, (dim_in, dim_out) in enumerate(in_out):
    is_last = ind >= (num_resolutions - 1) # 判断是否是最后一个下采样层;

    self.downs.append(
    nn.ModuleList(
        [   # 如果不是最后一层的话;
        block_klass(dim_in, dim_in, time_emb_dim=time_dim), # ResNet块;
        block_klass(dim_in, dim_in, time_emb_dim=time_dim), # ResNet块;
        Residual(PreNorm(dim_in, LinearAttention(dim_in))), # 带有线性注意力机制的残差块;
        Downsample(dim_in, dim_out) # 下采样一层;
        if not is_last
        # 如果是最后一层的话,一个简单的卷积层就可以了;
        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
        ]
    )
    )

下采样层增加了若干个ResNet块组成的整体网络,还通过一个残差网络连接一个线性注意力机制层的输入输出。

底部是一个中间层:

############################# 开始构建中间层:###########################
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) # ResNet块;
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))  # 带有注意力机制的残差块;
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) # ResNet块;

然后,构建我们的上采样层:

########################### 开始构建上采样层:###########################
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
    is_last = ind == (len(in_out) - 1)

    self.ups.append(
    nn.ModuleList(
        [
        # 如果不是最后一层的话;
        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),# ResNet块;
        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),# ResNet块;
        Residual(PreNorm(dim_out, LinearAttention(dim_out))),# 带有线性注意力机制的残差块;
        Upsample(dim_out, dim_in) # 上采样一层;
        if not is_last
        # 如果是最后一层的话,一个简单的卷积层就可以了;
        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
        ]
    )
    )

结构和下采样层是完全对称的!

最后,我们的输出需要再通过一个ResNet块和一个卷积层:

self.out_dim = default(out_dim, channels)# 获得最终的输出维度;
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)# 通过一个ResNet块;
self.final_conv = nn.Conv2d(dim, self.out_dim, 1) # 使用1x1的卷积核获得最终输出;

到此,我们的整个U-Net网络结构就大功告成啦!接下来就是看看前向是怎么做的了:

2、前向传播(forward)

首先,我们需要做一些准备。

因为我们输入的是一整个的层,以及我们的时间步,首先,我们需要对他们处理一下。

if self.self_condition: # 自我条件化的话,将对应张量合并。
    x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
    x = torch.cat((x_self_cond, x), dim=1)

    x = self.init_conv(x) # 通过初始卷积层处理输入x;
    r = x.clone()
    # print("time shape in init:",np.shape(time))
    # print("x shape in init:",np.shape(x))
    t = self.time_mlp(time)  # 通过时间MLP处理时间嵌入;
    # print("t shape in init:",np.shape(t))
    h = [] # 初始化一个列表来存储中间特征;

这里的h就记录了我们下采样过程中的特征图信息。他们将和后续的上采样过程中的特征图进行合并,也就是那个前向通道。

接下来是下采样过程:

###################### 开始下采样过程;######################
for block1, block2, attn, downsample in self.downs:
    x = block1(x, t) # 应用第一个ResNet块;
    # print("x shape in block1:",np.shape(x))
    h.append(x) # 将特征添加到h列表;
    x = block2(x, t)  # 应用第二个ResNet块;
    # print("x shape in block2:",np.shape(x))
    x = attn(x)  # 应用注意力机制;
    # print("x shape in down attn:",np.shape(x))
    h.append(x)  # 再次将特征添加到h列表;
    x = downsample(x) # 下采样;
    # print("x shape in down layer:",np.shape(x))

我们在下采样过程中加入了Attention模块来添加信息。

然后是中间层的处理:

###################### 开始中间层处理;######################
# print("x shape in mid layer:",np.shape(x))
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)

然后是我们的上采样过程:

###################### 开始上采样过程;######################
for block1, block2, attn, upsample in self.ups:
    x = torch.cat((x, h.pop()), dim=1) # 将特征与h列表中最后一个特征合并;
    x = block1(x, t)  # 应用第一个ResNet块;

    x = torch.cat((x, h.pop()), dim=1)  # 再次将特征与h列表中最后一个特征合并;
    x = block2(x, t)  # 应用第二个ResNet块;
    x = attn(x) # 应用注意力机制;

    x = upsample(x) # 上采样;

可以看到,在上采样过程中,我们将下采样中的记录的,存放在h中的特征图进行了拼接,最终得到了结果。

之后,简单的收尾即可:

x = torch.cat((x, r), dim=1) # 将特征与初始复制的特征r合并;
x = self.final_res_block(x, t)  # 应用最终的ResNet块;
x = self.final_conv(x) # 通过最终的卷积层处理并返回结果。
return x

到此,我们的整个去噪器就组装完成!可喜可贺,可喜可贺!

下篇。icon-default.png?t=N7T8https://blog.csdn.net/alxws/article/details/140059294?spm=1001.2014.3001.5502

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值