昇思25天学习打卡营第6天 | Diffusion扩散模型

模型简介

什么是Diffusion Model?

如果将Diffusion与其他生成模型(如Normalizing Flows、GAN或VAE)进行比较,它并没有那么复杂,它们都将噪声从一些简单分布转换为数据样本,Diffusion也是从纯噪声开始通过一个神经网络学习逐步去噪,最终得到一个实际图像。 Diffusion对于图像的处理包括以下两个过程:

  • 我们选择的固定(或预定义)正向扩散过程 𝑞

  • :它逐渐将高斯噪声添加到图像中,直到最终得到纯噪声

  • 一个学习的反向去噪的扩散过程 𝑝𝜃

  • :通过训练神经网络从纯噪声开始逐渐对图像去噪,直到最终得到一个实际的图像

Image-2

由 𝑡

索引的正向和反向过程都发生在某些有限时间步长 𝑇(DDPM作者使用 𝑇=1000)内。从𝑡=0开始,在数据分布中采样真实图像 𝐱0(本文使用一张来自ImageNet的猫图像形象的展示了diffusion正向添加噪声的过程),正向过程在每个时间步长 𝑡 都从高斯分布中采样一些噪声,再添加到上一个时刻的图像中。假定给定一个足够大的 𝑇 和一个在每个时间步长添加噪声的良好时间表,您最终会在 𝑡=𝑇

通过渐进的过程得到所谓的各向同性的高斯分布

扩散模型实现原理

Diffusion 前向过程

所谓前向过程,即向图片上加噪声的过程。虽然这个步骤无法做到图片生成,但这是理解diffusion model以及构建训练样本至关重要的一步。 首先我们需要一个可控的损失函数,并运用神经网络对其进行优化。

设 𝑞(𝑥0)

是真实数据分布,由于 𝑥0∼𝑞(𝑥0) ,所以我们可以从这个分布中采样以获得图像 𝑥0 。接下来我们定义前向扩散过程 𝑞(𝑥𝑡|𝑥𝑡−1) ,在前向过程中我们会根据已知的方差 0<𝛽1<𝛽2<...<𝛽𝑇<1

在每个时间步长 t 添加高斯噪声,由于前向过程的每个时刻 t 只与时刻 t-1 有关,所以也可以看做马尔科夫过程:

𝑞(𝐱𝑡|𝐱𝑡−1)=(𝐱𝑡;1−𝛽𝑡⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯√𝐱𝑡−1,𝛽𝑡𝐈)

回想一下,正态分布(也称为高斯分布)由两个参数定义:平均值 𝜇

和方差 𝜎2≥0 。基本上,在每个时间步长 𝑡

处的产生的每个新的(轻微噪声)图像都是从条件高斯分布中绘制的,其中

𝑞(𝜇𝑡)=1−𝛽𝑡⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯√𝐱𝑡−1

我们可以通过采样 𝜖∼(0,𝐈)

然后设置

𝑞(𝐱𝑡)=1−𝛽𝑡⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯√𝐱𝑡−1+𝛽𝑡⎯⎯⎯√𝜖

请注意, 𝛽𝑡

在每个时间步长 𝑡 (因此是下标)不是恒定的:事实上,我们定义了一个所谓的“动态方差”的方法,使得每个时间步长的 𝛽𝑡

可以是线性的、二次的、余弦的等(有点像动态学习率方法)。

因此,如果我们适当设置时间表,从 𝐱0

开始,我们最终得到 𝐱1,...,𝐱𝑡,...,𝐱𝑇,即随着 𝑡 的增大 𝐱𝑡 会越来越接近纯噪声,而 𝐱𝑇

就是纯高斯噪声。

那么,如果我们知道条件概率分布 𝑝(𝐱𝑡−1|𝐱𝑡)

,我们就可以反向运行这个过程:通过采样一些随机高斯噪声 𝐱𝑇,然后逐渐去噪它,最终得到真实分布 𝐱0 中的样本。但是,我们不知道条件概率分布 𝑝(𝐱𝑡−1|𝐱𝑡)

。这很棘手,因为需要知道所有可能图像的分布,才能计算这个条件概率。

Diffusion 逆向过程

为了解决上述问题,我们将利用神经网络来近似(学习)这个条件概率分布 𝑝𝜃(𝐱𝑡−1|𝐱𝑡)

, 其中 𝜃 是神经网络的参数。如果说前向过程(forward)是加噪的过程,那么逆向过程(reverse)就是diffusion的去噪推断过程,而通过神经网络学习并表示 𝑝𝜃(𝐱𝑡−1|𝐱𝑡)

的过程就是Diffusion 逆向去噪的核心。

现在,我们知道了需要一个神经网络来学习逆向过程的(条件)概率分布。我们假设这个反向过程也是高斯的,任何高斯分布都由2个参数定义:

  • 由 𝜇𝜃

  • 参数化的平均值

  • 由 𝜇𝜃

  • 参数化的方差

综上,我们可以将逆向过程公式化为

𝑝𝜃(𝐱𝑡−1|𝐱𝑡)=(𝐱𝑡−1;𝜇𝜃(𝐱𝑡,𝑡),Σ𝜃(𝐱𝑡,𝑡))

其中平均值和方差也取决于噪声水平 𝑡

,神经网络需要通过学习来表示这些均值和方差。

  • 注意,DDPM的作者决定保持方差固定,让神经网络只学习(表示)这个条件概率分布的平均值 𝜇𝜃。

  • 本文我们同样假设神经网络只需要学习(表示)这个条件概率分布的平均值 𝜇𝜃。

为了导出一个目标函数来学习反向过程的平均值,作者观察到 𝑞

和 𝑝𝜃 的组合可以被视为变分自动编码器(VAE)。因此,变分下界(也称为ELBO)可用于最小化真值数据样本 𝐱0 的似然负对数(有关ELBO的详细信息,请参阅VAE论文(Kingma等人,2013年)),该过程的ELBO是每个时间步长的损失之和 𝐿=𝐿0+𝐿1+...+𝐿𝑇 ,其中,每项的损失 𝐿𝑡 (除了 𝐿0

)实际上是2个高斯分布之间的KL发散,可以明确地写为相对于均值的L2-loss!

如Sohl-Dickstein等人所示,构建Diffusion正向过程的直接结果是我们可以在条件是 𝐱0

(因为高斯和也是高斯)的情况下,在任意噪声水平上采样 𝐱𝑡 ,而不需要重复应用 𝑞 去采样 𝐱𝑡

,这非常方便。使用𝛼𝑡:=1−𝛽𝑡𝛼¯𝑡:=Π𝑡𝑠=1𝛼𝑠我们就有𝑞(𝐱𝑡|𝐱0)=(𝐱t;𝛼¯t⎯⎯⎯√𝐱0,(1⎯𝛼¯t)𝐈)

这意味着我们可以采样高斯噪声并适当地缩放它,然后将其添加到 𝐱0中,直接获得 𝐱𝑡。

请注意,𝛼¯𝑡是已知 𝛽𝑡 方差计划的函数,因此也是已知的,可以预先计算。这允许我们在训练期间优化损失函数 𝐿 的随机项。或者换句话说,在训练期间随机采样 𝑡 并优化 𝐿𝑡。

正如Ho等人所展示的那样,这种性质的另一个优点是可以重新参数化平均值,使神经网络学习(预测)构成损失的KL项中噪声的附加噪声。这意味着我们的神经网络变成了噪声预测器,而不是(直接)均值预测器。其中,平均值可以按如下方式计算:𝜇𝜃(𝐱𝑡,𝑡)=1𝛼𝑡⎯⎯⎯√(𝐱𝑡−𝛽𝑡1−𝛼¯𝑡⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯√𝜖𝜃(𝐱𝑡,𝑡))最终的目标函数 𝐿𝑡

如下 (随机步长 t 由 (𝜖∼𝑁(0,𝐈))给定):‖𝜖−𝜖𝜃(𝐱𝑡,𝑡)‖2=‖𝜖−𝜖𝜃(𝛼¯𝑡⎯⎯⎯√𝐱0+(1−𝛼¯𝑡)⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯⎯√𝜖,𝑡)‖2

在这里, 𝐱0是初始(真实,未损坏)图像, 𝜖 是在时间步长 𝑡 采样的纯噪声,𝜖𝜃(𝐱𝑡,𝑡)

是我们的神经网络。神经网络是基于真实噪声和预测高斯噪声之间的简单均方误差(MSE)进行优化的。

训练算法现在如下所示:

Image-3

 目前,扩散模型的主要(也许唯一)缺点是它们需要多次正向传递来生成图像(对于像GAN这样的生成模型来说,情况并非如此)。然而,有正在进行中的研究表明只需要10个去噪步骤就能实现高保真生成。

相较GAN生成算法,扩散模型的基本思路是对图像中的随机噪音去除的思路来还原图像,可以还原图像也可以基于去噪的模式将原本的图像转换成另一种风格的图像,后续很多的生成算法图像生成算法都是基于此种思路来进行处理,其中包含u-net架构已经是图像生成算法的标配。

  • U-Net神经网络预测噪声

    神经网络需要在特定时间步长接收带噪声的图像,并返回预测的噪声。请注意,预测噪声是与输入图像具有相同大小/分辨率的张量。因此,从技术上讲,网络接受并输出相同形状的张量。那么我们可以用什么类型的神经网络来实现呢?

    这里通常使用的是非常相似的自动编码器,您可能还记得典型的"深度学习入门"教程。自动编码器在编码器和解码器之间有一个所谓的"bottleneck"层。编码器首先将图像编码为一个称为"bottleneck"的较小的隐藏表示,然后解码器将该隐藏表示解码回实际图像。这迫使网络只保留bottleneck层中最重要的信息。

    在模型结构方面,DDPM的作者选择了U-Net,出自(Ronneberger et al.,2015)(当时,它在医学图像分割方面取得了最先进的结果)。这个网络就像任何自动编码器一样,在中间由一个bottleneck组成,确保网络只学习最重要的信息。重要的是,它在编码器和解码器之间引入了残差连接,极大地改善了梯度流(灵感来自于(He et al., 2015))。

    Image-4

    可以看出,U-Net模型首先对输入进行下采样(即,在空间分辨率方面使输入更小),之后执行上采样。

    构建Diffusion模型

    下面,我们逐步构建Diffusion模型。

    首先,我们定义了一些帮助函数和类,这些函数和类将在实现神经网络时使用。

    def rearrange(head, inputs):
    
        b, hc, x, y = inputs.shape
    
        c = hc // head
    
        return inputs.reshape((b, head, c, x * y))
    
    ​
    
    def rsqrt(x):
    
        res = ops.sqrt(x)
    
        return ops.inv(res)
    
    ​
    
    def randn_like(x, dtype=None):
    
        if dtype is None:
    
            dtype = x.dtype
    
        res = ops.standard_normal(x.shape).astype(dtype)
    
        return res
    
    ​
    
    def randn(shape, dtype=None):
    
        if dtype is None:
    
            dtype = ms.float32
    
        res = ops.standard_normal(shape).astype(dtype)
    
        return res
    
    ​
    
    def randint(low, high, size, dtype=ms.int32):
    
        res = ops.uniform(size, Tensor(low, dtype), Tensor(high, dtype), dtype=dtype)
    
        return res
    
    ​
    
    def exists(x):
    
        return x is not None
    
    ​
    
    def default(val, d):
    
        if exists(val):
    
            return val
    
        return d() if callable(d) else d
    
    ​
    
    def _check_dtype(d1, d2):
    
        if ms.float32 in (d1, d2):
    
            return ms.float32
    
        if d1 == d2:
    
            return d1
    
        raise ValueError('dtype is not supported.')
    
    ​
    
    class Residual(nn.Cell):
    
        def __init__(self, fn):
    
            super().__init__()
    
            self.fn = fn
    
    ​
    
        def construct(self, x, *args, **kwargs):
    
            return self.fn(x, *args, **kwargs) + x

    我们还定义了上采样和下采样操作的别名。

    def Upsample(dim):
    
        return nn.Conv2dTranspose(dim, dim, 4, 2, pad_mode="pad", padding=1)
    
    ​
    
    def Downsample(dim):
    
        return nn.Conv2d(dim, dim, 4, 2, pad_mode="pad", padding=1)

    位置向量

    由于神经网络的参数在时间(噪声水平)上共享,作者使用正弦位置嵌入来编码𝑡

    ,灵感来自Transformer(Vaswani et al., 2017)。对于批处理中的每一张图像,神经网络"知道"它在哪个特定时间步长(噪声水平)上运行。

    SinusoidalPositionEmbeddings模块采用(batch_size, 1)形状的张量作为输入(即批处理中几个有噪声图像的噪声水平),并将其转换为(batch_size, dim)形状的张量,其中dim是位置嵌入的尺寸。然后,我们将其添加到每个剩余块中。

    class SinusoidalPositionEmbeddings(nn.Cell):
    
        def __init__(self, dim):
    
            super().__init__()
    
            self.dim = dim
    
            half_dim = self.dim // 2
    
            emb = math.log(10000) / (half_dim - 1)
    
            emb = np.exp(np.arange(half_dim) * - emb)
    
            self.emb = Tensor(emb, ms.float32)
    
    ​
    
        def construct(self, x):
    
            emb = x[:, None] * self.emb[None, :]
    
            emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
    
            return emb

    ResNet/ConvNeXT块

    接下来,我们定义U-Net模型的核心构建块。DDPM作者使用了一个Wide ResNet块(Zagoruyko et al., 2016),但Phil Wang决定添加ConvNeXT(Liu et al., 2022)替换ResNet,因为后者在图像领域取得了巨大成功。

    在最终的U-Net架构中,可以选择其中一个或另一个,本文选择ConvNeXT块构建U-Net模型。

    class Block(nn.Cell):
    
        def __init__(self, dim, dim_out, groups=1):
    
            super().__init__()
    
            self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)
    
            self.proj = c(dim, dim_out, 3, padding=1, pad_mode='pad')
    
            self.norm = nn.GroupNorm(groups, dim_out)
    
            self.act = nn.SiLU()
    
    ​
    
        def construct(self, x, scale_shift=None):
    
            x = self.proj(x)
    
            x = self.norm(x)
    
    ​
    
            if exists(scale_shift):
    
                scale, shift = scale_shift
    
                x = x * (scale + 1) + shift
    
    ​
    
            x = self.act(x)
    
            return x
    
    ​
    
    class ConvNextBlock(nn.Cell):
    
        def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
    
            super().__init__()
    
            self.mlp = (
    
                nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
    
                if exists(time_emb_dim)
    
                else None
    
            )
    
    ​
    
            self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
    
            self.net = nn.SequentialCell(
    
                nn.GroupNorm(1, dim) if norm else nn.Identity(),
    
                nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
    
                nn.GELU(),
    
                nn.GroupNorm(1, dim_out * mult),
    
                nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
    
            )
    
    ​
    
            self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
    
    ​
    
        def construct(self, x, time_emb=None):
    
            h = self.ds_conv(x)
    
            if exists(self.mlp) and exists(time_emb):
    
                assert exists(time_emb), "time embedding must be passed in"
    
                condition = self.mlp(time_emb)
    
                condition = condition.expand_dims(-1).expand_dims(-1)
    
                h = h + condition
    
    ​
    
            h = self.net(h)
    
            return h + self.res_conv(x)

    Attention模块

    接下来,我们定义Attention模块,DDPM作者将其添加到卷积块之间。Attention是著名的Transformer架构(Vaswani et al., 2017),在人工智能的各个领域都取得了巨大的成功,从NLP到蛋白质折叠。Phil Wang使用了两种注意力变体:一种是常规的multi-head self-attention(如Transformer中使用的),另一种是LinearAttention(Shen et al., 2018),其时间和内存要求在序列长度上线性缩放,而不是在常规注意力中缩放。 要想对Attention机制进行深入的了解,请参照Jay Allamar的精彩的博文

    class Attention(nn.Cell):
    
        def __init__(self, dim, heads=4, dim_head=32):
    
            super().__init__()
    
            self.scale = dim_head ** -0.5
    
            self.heads = heads
    
            hidden_dim = dim_head * heads
    
    ​
    
            self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
    
            self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
    
            self.map = ops.Map()
    
            self.partial = ops.Partial()
    
    ​
    
        def construct(self, x):
    
            b, _, h, w = x.shape
    
            qkv = self.to_qkv(x).chunk(3, 1)
    
            q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
    
    ​
    
            q = q * self.scale
    
    ​
    
            # 'b h d i, b h d j -> b h i j'
    
            sim = ops.bmm(q.swapaxes(2, 3), k)
    
            attn = ops.softmax(sim, axis=-1)
    
            # 'b h i j, b h d j -> b h i d'
    
            out = ops.bmm(attn, v.swapaxes(2, 3))
    
            out = out.swapaxes(-1, -2).reshape((b, -1, h, w))
    
    ​
    
            return self.to_out(out)
    
    ​
    
    ​
    
    class LayerNorm(nn.Cell):
    
        def __init__(self, dim):
    
            super().__init__()
    
            self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')
    
    ​
    
        def construct(self, x):
    
            eps = 1e-5
    
            var = x.var(1, keepdims=True)
    
            mean = x.mean(1, keep_dims=True)
    
            return (x - mean) * rsqrt((var + eps)) * self.g
    
    ​
    
    ​
    
    class LinearAttention(nn.Cell):
    
        def __init__(self, dim, heads=4, dim_head=32):
    
            super().__init__()
    
            self.scale = dim_head ** -0.5
    
            self.heads = heads
    
            hidden_dim = dim_head * heads
    
            self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
    
    ​
    
            self.to_out = nn.SequentialCell(
    
                nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
    
                LayerNorm(dim)
    
            )
    
    ​
    
            self.map = ops.Map()
    
            self.partial = ops.Partial()
    
    ​
    
        def construct(self, x):
    
            b, _, h, w = x.shape
    
            qkv = self.to_qkv(x).chunk(3, 1)
    
            q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
    
    ​
    
            q = ops.softmax(q, -2)
    
            k = ops.softmax(k, -1)
    
    ​
    
            q = q * self.scale
    
            v = v / (h * w)
    
    ​
    
            # 'b h d n, b h e n -> b h d e'
    
            context = ops.bmm(k, v.swapaxes(2, 3))
    
            # 'b h d e, b h d n -> b h e n'
    
            out = ops.bmm(context.swapaxes(2, 3), q)
    
    ​
    
            out = out.reshape((b, -1, h, w))
    
            return self.to_out(out)

    组归一化

    DDPM作者将U-Net的卷积/注意层与群归一化(Wu et al., 2018)。下面,我们定义一个PreNorm类,将用于在注意层之前应用groupnorm。

    class PreNorm(nn.Cell):
    
        def __init__(self, dim, fn):
    
            super().__init__()
    
            self.fn = fn
    
            self.norm = nn.GroupNorm(1, dim)
    
    ​
    
        def construct(self, x):
    
            x = self.norm(x)
    
            return self.fn(x)

    条件U-Net

    我们已经定义了所有的构建块(位置嵌入、ResNet/ConvNeXT块、Attention和组归一化),现在需要定义整个神经网络了。请记住,网络 𝜖𝜃(𝐱𝑡,𝑡)

    的工作是接收一批噪声图像+噪声水平,并输出添加到输入中的噪声。

    更具体的: 网络获取了一批(batch_size, num_channels, height, width)形状的噪声图像和一批(batch_size, 1)形状的噪音水平作为输入,并返回(batch_size, num_channels, height, width)形状的张量。

    网络构建过程如下:

    • 首先,将卷积层应用于噪声图像批上,并计算噪声水平的位置

    • 接下来,应用一系列下采样级。每个下采样阶段由2个ResNet/ConvNeXT块 + groupnorm + attention + 残差连接 + 一个下采样操作组成

    • 在网络的中间,再次应用ResNet或ConvNeXT块,并与attention交织

    • 接下来,应用一系列上采样级。每个上采样级由2个ResNet/ConvNeXT块+ groupnorm + attention + 残差连接 + 一个上采样操作组成

    • 最后,应用ResNet/ConvNeXT块,然后应用卷积层

    最终,神经网络将层堆叠起来,就像它们是乐高积木一样(但重要的是了解它们是如何工作的)。

    class Unet(nn.Cell):
    
        def __init__(
    
                self,
    
                dim,
    
                init_dim=None,
    
                out_dim=None,
    
                dim_mults=(1, 2, 4, 8),
    
                channels=3,
    
                with_time_emb=True,
    
                convnext_mult=2,
    
        ):
    
            super().__init__()
    
    ​
    
            self.channels = channels
    
    ​
    
            init_dim = default(init_dim, dim // 3 * 2)
    
            self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)
    
    ​
    
            dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
    
            in_out = list(zip(dims[:-1], dims[1:]))
    
    ​
    
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
    
    ​
    
            if with_time_emb:
    
                time_dim = dim * 4
    
                self.time_mlp = nn.SequentialCell(
    
                    SinusoidalPositionEmbeddings(dim),
    
                    nn.Dense(dim, time_dim),
    
                    nn.GELU(),
    
                    nn.Dense(time_dim, time_dim),
    
                )
    
            else:
    
                time_dim = None
    
                self.time_mlp = None
    
    ​
    
            self.downs = nn.CellList([])
    
            self.ups = nn.CellList([])
    
            num_resolutions = len(in_out)
    
    ​
    
            for ind, (dim_in, dim_out) in enumerate(in_out):
    
                is_last = ind >= (num_resolutions - 1)
    
    ​
    
                self.downs.append(
    
                    nn.CellList(
    
                        [
    
                            block_klass(dim_in, dim_out, time_emb_dim=time_dim),
    
                            block_klass(dim_out, dim_out, time_emb_dim=time_dim),
    
                            Residual(PreNorm(dim_out, LinearAttention(dim_out))),
    
                            Downsample(dim_out) if not is_last else nn.Identity(),
    
                        ]
    
                    )
    
                )
    
    ​
    
            mid_dim = dims[-1]
    
            self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
    
            self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
    
            self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
    
    ​
    
            for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
    
                is_last = ind >= (num_resolutions - 1)
    
    ​
    
                self.ups.append(
    
                    nn.CellList(
    
                        [
    
                            block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
    
                            block_klass(dim_in, dim_in, time_emb_dim=time_dim),
    
                            Residual(PreNorm(dim_in, LinearAttention(dim_in))),
    
                            Upsample(dim_in) if not is_last else nn.Identity(),
    
                        ]
    
                    )
    
                )
    
    ​
    
            out_dim = default(out_dim, channels)
    
            self.final_conv = nn.SequentialCell(
    
                block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
    
            )
    
    ​
    
        def construct(self, x, time):
    
            x = self.init_conv(x)
    
    ​
    
            t = self.time_mlp(time) if exists(self.time_mlp) else None
    
    ​
    
            h = []
    
    ​
    
            for block1, block2, attn, downsample in self.downs:
    
                x = block1(x, t)
    
                x = block2(x, t)
    
                x = attn(x)
    
                h.append(x)
    
    ​
    
                x = downsample(x)
    
    ​
    
            x = self.mid_block1(x, t)
    
            x = self.mid_attn(x)
    
            x = self.mid_block2(x, t)
    
    ​
    
            len_h = len(h) - 1
    
            for block1, block2, attn, upsample in self.ups:
    
                x = ops.concat((x, h[len_h]), 1)
    
                len_h -= 1
    
                x = block1(x, t)
    
                x = block2(x, t)
    
                x = attn(x)
    
    ​
    
                x = upsample(x)
    
            return self.final_conv(x)

    正向扩散

    我们已经知道正向扩散过程在多个时间步长𝑇

    中,从实际分布逐渐向图像添加噪声,根据差异计划进行正向扩散。最初的DDPM作者采用了线性时间表:

    线性增加到𝛽𝑇=0.02

    下面,我们定义了𝑇

    时间步的时间表。

    def linear_beta_schedule(timesteps):
    
        beta_start = 0.0001
    
        beta_end = 0.02
    
        return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)

    首先,让我们使用 𝑇=200

    时间步长的线性计划,并定义我们需要的 β𝑡 中的各种变量,例如方差 𝛼¯𝑡 的累积乘积。下面的每个变量都只是一维张量,存储从 𝑡 到 𝑇 的值。重要的是,我们还定义了extract函数,它将允许我们提取一批适当的 𝑡

    索引。

    timesteps = 400
    
    ​
    
    # 定义 beta schedule
    
    betas = linear_beta_schedule(timesteps=timesteps)
    
    ​
    
    # 定义 alphas
    
    alphas = 1. - betas
    
    alphas_cumprod = np.cumprod(alphas, axis=0)
    
    alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)
    
    ​
    
    sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
    
    sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
    
    sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))
    
    ​
    
    # 计算 q(x_{t-1} | x_t, x_0)
    
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    
    ​
    
    p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
    
    p2_loss_weight = Tensor(p2_loss_weight)
    
    ​
    
    def extract(a, t, x_shape):
    
        b = t.shape[0]
    
        out = Tensor(a).gather(t, -1)
    
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))

    我们将用猫图像说明如何在扩散过程的每个时间步骤中添加噪音。

    # 下载猫猫图像
    
    url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip'
    
    path = download(url, './', kind="zip", replace=True)
    
    
    from PIL import Image
    
    ​
    
    image = Image.open('./image_cat/jpg/000000039769.jpg')
    
    base_width = 160
    
    image = image.resize((base_width, int(float(image.size[1]) * float(base_width / float(image.size[0])))))
    
    image.show()

    噪声被添加到mindspore张量中,而不是Pillow图像。我们将首先定义图像转换,允许我们从PIL图像转换到mindspore张量(我们可以在其上添加噪声),反之亦然。

    这些转换相当简单:我们首先通过除以255

    来标准化图像(使它们在 [0,1] 范围内),然后确保它们在 [−1,1]

    范围内。DPPM论文中有介绍到:

    假设图像数据由 {0,1,...,255}

    中的整数组成,线性缩放为 [−1,1] , 这确保了神经网络反向过程在从标准正常先验 𝑝(𝐱𝑇)

    开始的一致缩放输入上运行。

    from mindspore.dataset import ImageFolderDataset
    
    ​
    
    image_size = 128
    
    transforms = [
    
        Resize(image_size, Inter.BILINEAR),
    
        CenterCrop(image_size),
    
        ToTensor(),
    
        lambda t: (t * 2) - 1
    
    ]
    
    ​
    
    ​
    
    path = './image_cat'
    
    dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(),
    
                                 extensions=['.jpg', '.jpeg', '.png', '.tiff'],
    
                                 num_shards=1, shard_id=0, shuffle=False, decode=True)
    
    dataset = dataset.project('image')
    
    transforms.insert(1, RandomHorizontalFlip())
    
    dataset_1 = dataset.map(transforms, 'image')
    
    dataset_2 = dataset_1.batch(1, drop_remainder=True)
    
    x_start = next(dataset_2.create_tuple_iterator())[0]
    
    print(x_start.shape)
     

    我们还定义了反向变换,它接收一个包含 [−1,1]

    中的张量,并将它们转回 PIL 图像:

    import numpy as np
    
    ​
    
    reverse_transform = [
    
        lambda t: (t + 1) / 2,
    
        lambda t: ops.permute(t, (1, 2, 0)), # CHW to HWC
    
        lambda t: t * 255.,
    
        lambda t: t.asnumpy().astype(np.uint8),
    
        ToPIL()
    
    ]
    
    ​
    
    def compose(transform, x):
    
        for d in transform:
    
            x = d(x)
    
        return x
     

    这意味着我们现在可以定义给定模型的损失函数,如下所示:

    def p_losses(unet_model, x_start, t, noise=None):
    
        if noise is None:
    
            noise = randn_like(x_start)
    
        x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    
        predicted_noise = unet_model(x_noisy, t)
    
    ​
    
        loss = nn.SmoothL1Loss()(noise, predicted_noise)# todo
    
        loss = loss.reshape(loss.shape[0], -1)
    
        loss = loss * extract(p2_loss_weight, t, loss.shape)
    
        return loss.mean()

    denoise_model将是我们上面定义的U-Net。我们将在真实噪声和预测噪声之间使用Huber损失。

    数据准备与处理

    在这里我们定义一个正则数据集。数据集可以来自简单的真实数据集的图像组成,如Fashion-MNIST、CIFAR-10或ImageNet,其中线性缩放为 [−1,1]

    每个图像的大小都会调整为相同的大小。有趣的是,图像也是随机水平翻转的。根据论文内容:我们在CIFAR10的训练中使用了随机水平翻转;我们尝试了有翻转和没有翻转的训练,并发现翻转可以稍微提高样本质量。

    本实验我们选用Fashion_MNIST数据集,我们使用download下载并解压Fashion_MNIST数据集到指定路径。此数据集由已经具有相同分辨率的图像组成,即28x28。

    # 下载MNIST数据集
    url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip'
    
    path = download(url, './', kind="zip", replace=True)
    from mindspore.dataset import FashionMnistDataset
    
    ​
    
    image_size = 28
    
    channels = 1
    
    batch_size = 16
    
    ​
    
    fashion_mnist_dataset_dir = "./dataset"
    
    dataset = FashionMnistDataset(dataset_dir=fashion_mnist_dataset_dir, usage="train", num_parallel_workers=cpu_count(), shuffle=True, num_shards=1, shard_id=0)
     

    采样

    由于我们将在训练期间从模型中采样(以便跟踪进度),我们定义了下面的代码。采样在本文中总结为算法2:

    Image-5

    从扩散模型生成新图像是通过反转扩散过程来实现的:我们从𝑇

    开始,我们从高斯分布中采样纯噪声,然后使用我们的神经网络逐渐去噪(使用它所学习的条件概率),直到我们最终在时间步𝑡=0结束。如上图所示,我们可以通过使用我们的噪声预测器插入平均值的重新参数化,导出一个降噪程度较低的图像 𝐱𝑡−1

    。请注意,方差是提前知道的。

    理想情况下,我们最终会得到一个看起来像是来自真实数据分布的图像。

    下面的代码实现了这一点。

    def p_sample(model, x, t, t_index):
    
        betas_t = extract(betas, t, x.shape)
    
        sqrt_one_minus_alphas_cumprod_t = extract(
    
            sqrt_one_minus_alphas_cumprod, t, x.shape
    
        )
    
        sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
        model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    
    ​
    
        if t_index == 0:
    
            return model_mean
    
        posterior_variance_t = extract(posterior_variance, t, x.shape)
    
        noise = randn_like(x)
    
        return model_mean + ops.sqrt(posterior_variance_t) * noise
    
    ​
    
    def p_sample_loop(model, shape):
    
        b = shape[0]
    
        # 从纯噪声开始
    
        img = randn(shape, dtype=None)
    
        imgs = []
    
    ​
    
        for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
    
            img = p_sample(model, img, ms.numpy.full((b,), i, dtype=mstype.int32), i)
    
            imgs.append(img.asnumpy())
    
        return imgs
    
    ​
    
    def sample(model, image_size, batch_size=16, channels=3):
    
        return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
     

    请注意,上面的代码是原始实现的简化版本。

    训练过程

    下面,我们开始训练吧!

    # 定义动态学习率
    
    lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)
    
    ​
    
    # 定义 Unet模型
    
    unet_model = Unet(
    
        dim=image_size,
    
        channels=channels,
    
        dim_mults=(1, 2, 4,)
    
    )
    
    ​
    
    name_list = []
    
    for (name, par) in list(unet_model.parameters_and_names()):
    
        name_list.append(name)
    
    i = 0
    
    for item in list(unet_model.trainable_params()):
    
        item.name = name_list[i]
    
        i += 1
    
    ​
    
    # 定义优化器
    
    optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
    
    loss_scaler = DynamicLossScaler(65536, 2, 1000)
    
    ​
    
    # 定义前向过程
    
    def forward_fn(data, t, noise=None):
    
        loss = p_losses(unet_model, data, t, noise)
    
        return loss
    
    ​
    
    # 计算梯度
    
    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)
    
    ​
    
    # 梯度更新
    
    def train_step(data, t, noise):
    
        loss, grads = grad_fn(data, t, noise)
    
        optimizer(grads)
    
        return loss
    
    
    import time
    
    ​
    
    # 由于时间原因,epochs设置为1,可根据需求进行调整
    
    epochs = 1
    
    ​
    
    for epoch in range(epochs):
    
        begin_time = time.time()
    
        for step, batch in enumerate(dataset.create_tuple_iterator()):
    
            unet_model.set_train()
    
            batch_size = batch[0].shape[0]
    
            t = randint(0, timesteps, (batch_size,), dtype=ms.int32)
    
            noise = randn_like(batch[0])
    
            loss = train_step(batch[0], t, noise)
    
    ​
    
            if step % 500 == 0:
    
                print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)
    
        end_time = time.time()
    
        times = end_time - begin_time
    
        print("training time:", times, "s")
    
        # 展示随机采样效果
    
        unet_model.set_train(False)
    
        samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
    
        plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
    
    print("Training Success!")

  • 总结

  • 无分类器扩散指南([Ho et al., 2021):表明通过使用单个神经网络联合训练条件和无条件扩散模型,不需要分类器来指导扩散模型

  • 具有CLIP Latents (DALL-E 2) 的分层文本条件图像生成 (Ramesh et al., 2022):在将文本标题转换为CLIP图像嵌入之前使用,然后扩散模型将其解码为图像

  • 具有深度语言理解的真实文本到图像扩散模型(ImageGen)(Saharia et al., 2022):表明将大型预训练语言模型(例如T5)与级联扩散结合起来,对于文本到图像的合成很有效

  • 15
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值