MindSpore 25-Day Learning Camp, Day 24 | Generative Models: Diffusion

Check-in

Contents

  • Check-in
  • Understanding Diffusion Models
  • Environment Setup
  • Introduction to Diffusion Models
  • How Diffusion Models Work
  • The Forward Diffusion Process
  • The Reverse Diffusion Process
  • Training Algorithm Summary
  • Predicting Noise with a U-Net
  • Building the Diffusion Model
  • Helper Functions and Classes
  • Position Embeddings
  • ResNet/ConvNeXT Blocks
  • Attention Module
  • Group Normalization
  • Conditional U-Net
  • Forward Diffusion (core)
  • Data Preparation and Processing
  • Defining Forward Diffusion
  • Visualizing Forward Diffusion over Time Steps
  • Defining the Loss Function
  • Hands-on Example
  • Data Preparation and Processing
  • Sampling from the Model
  • Training
  • Inference (Sampling from the Model)
  • References


This walkthrough is based on the Hugging Face post The Annotated Diffusion Model, translated and ported to MindSpore. For a deeper theoretical treatment of diffusion models and their refinements, consult the additional reference links.

Even so, nearly all of the core code is covered here.

Understanding Diffusion Models

This introduction is based on the denoising diffusion probabilistic model (DDPM). Applications of DDPM include GLIDE and DALL-E 2 (led by OpenAI), Latent Diffusion (led by Heidelberg University), and ImageGen (led by Google Brain).

This post takes Phil Wang's PyTorch-based re-implementation and ports it to the MindSpore AI framework.

Environment Setup

Before starting the experiment, make sure the required libraries are installed and imported (it is assumed you have already installed MindSpore, download, dataset, matplotlib, and tqdm).

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install download
pip install numpy tqdm matplotlib
import math
from functools import partial
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from multiprocessing import cpu_count
from download import download

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.dataset.vision import Resize, Inter, CenterCrop, ToTensor, RandomHorizontalFlip, ToPIL
from mindspore.common.initializer import initializer
from mindspore.amp import DynamicLossScaler


# Set the global seed. The global seed is used by numpy.random,
# mindspore.common.Initializer, and mindspore.nn.probability.distribution;
# if it is not set, each of those packages falls back to its own seed.
# A seed is the starting point of a random number generator and determines the
# random sequence: the same seed always reproduces the same sequence, which makes
# experiments reproducible, whereas a different seed on every run can make
# debugging and model training hard to reproduce.
ms.set_seed(0)

Introduction to Diffusion Models

A diffusion model starts from pure noise and learns, via a neural network, to denoise step by step until an actual image remains.

Diffusion models process images via the following two passes:

  • a chosen fixed (or predefined) forward diffusion process $q$: it gradually adds Gaussian noise to an image until pure noise remains.

  • a learned reverse denoising diffusion process $ p_\theta $: a neural network is trained to gradually denoise an image, starting from pure noise, until an actual image remains.

In the accompanying figure (not reproduced here), the noising forward diffusion process runs across time steps 0 --> T, and the generative reverse diffusion process runs back from T --> 0.

Given a sufficiently large $T$ and a well-behaved schedule for adding noise at each time step, this progressive process ends at $t = T$ in what is called an isotropic Gaussian distribution.

  

How Diffusion Models Work

The Forward Diffusion Process

The forward process adds noise to an image $x_0$; it is an essential step both for understanding diffusion models and for constructing training samples.

>> Definition of the forward process:

  • Each step $t$ of the forward process $ q(\mathbf{x}_t | \mathbf{x}_{t-1}) $ depends only on step $t-1$, so it can be viewed as a Markov process: $ q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}) $.
  • The formula above is a parameterized Gaussian: at each time step $t$ we add Gaussian noise according to a known, increasing variance schedule $ 0 < \beta_1 < \beta_2 < \cdots < \beta_T < 1 $.
  • Note that in this parameterized Gaussian, the semicolon separates the random variable from the parameters, while the comma separates the parameters themselves (mean and variance).
  • Here $ q(x_0) $ is the real data distribution; sampling from it yields an image $x_0$.
  • $ q(\mathbf{x}_t | \mathbf{x}_{t-1}) $ is a conditional Gaussian: to find the distribution of $x_t$ given $x_{t-1}$ (also a normal distribution), we only need its conditional mean and variance. By this formula, the conditional mean is $ \sqrt{1 - \beta_t} \mathbf{x}_{t-1} $ and the conditional variance is $ \beta_t \mathbf{I} $.

>> Adding Gaussian noise:

  1. First, sample standard Gaussian noise $ \epsilon \sim \mathcal{N}(0, \mathbf{I}) $.
  2. In general, a normal (Gaussian) distribution is defined by two parameters, a mean $\mu$ and a variance $ \sigma^2 \geqslant 0 $; essentially, each new (slightly noisier) image at time step $t$ is drawn from a conditional Gaussian. // Mathematically, a conditional Gaussian is a special case of the multivariate Gaussian that describes the distribution of one variable given the values of the others.
  3. Then compute $ \mathbf{x}_t = \sqrt{1 - \beta_t} \mathbf{x}_{t-1} + \sqrt{\beta_t} \mathbf{\epsilon} $, where $ \sqrt{1 - \beta_t} $ and $ \sqrt{\beta_t} $ weight the image and the noise respectively. Since $ \beta_t $ grows with the time step, the weight on the standard Gaussian noise increases step by step (see the sketch right after this list).
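
To make step 3 concrete, here is a minimal NumPy sketch of a single forward step (an illustration added here, not code from the original tutorial; x_prev, beta_t, and eps are placeholder names):

import numpy as np

rng = np.random.default_rng(0)
x_prev = rng.standard_normal((3, 128, 128)).astype(np.float32)  # stand-in for x_{t-1}
beta_t = 0.02                                                   # noise variance at step t
eps = rng.standard_normal(x_prev.shape).astype(np.float32)      # eps ~ N(0, I)

# x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps
x_t = np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * eps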

>> Defining the variance schedule:

  • The per-step variance $ \beta_t $ can follow a linear, quadratic, cosine, or similar schedule; a sketch of two such schedules follows below.
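
For illustration, a sketch of two schedules (the cosine variant follows Nichol et al., 2021; this snippet is an addition, and the linear version reappears in the tutorial code later):

import numpy as np

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    # variances increase linearly from beta_start to beta_end
    return np.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    # betas derived from alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi / 2)
    steps = np.arange(timesteps + 1)
    alphas_bar = np.cos(((steps / timesteps) + s) / (1 + s) * np.pi / 2) ** 2
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
    return np.clip(betas, 0, 0.999)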

>> Evolution of the forward process:

  • With a suitable schedule, starting from $x_0$ we eventually obtain $x_1, ..., x_t, ..., x_T$: as $t$ grows, $x_t$ gets closer and closer to pure noise, and $x_T$ is pure Gaussian noise.
  • If we knew the conditional distribution $ p(x_{t-1} | x_t) $, we could run the process in reverse: sample random Gaussian noise $x_T$ and gradually denoise it until we obtain a sample from the real distribution $x_0$. But we do not know $ p(x_{t-1} | x_t) $: it is intractable, because computing this conditional probability would require knowing the distribution of all possible images.

The Reverse Diffusion Process

To work around this, we use a neural network to approximate (learn) the conditional distribution $ p_\theta (x_{t-1} | x_t) $, where $\theta$ are the network's parameters.

The reverse process is diffusion's denoising inference; learning a neural network that represents $ p_\theta (x_{t-1} | x_t) $ is the core of reverse denoising.

Assuming this reverse process is also Gaussian, the distribution is defined by two parameters:

  • a mean parameterized by $ \mu_\theta $
  • a variance parameterized by $ \Sigma_\theta $

Putting this together, the reverse process can be formulated as $ p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1};\mu_\theta(\mathbf{x}_{t},t), \Sigma_\theta (\mathbf{x}_{t},t)) $

where the mean and variance also depend on the noise level $t$, and the neural network must learn to represent them.

  • Note that the DDPM authors chose to keep the variance fixed and let the network learn (represent) only the mean $ \mu_\theta $ of this conditional distribution.

The authors observe that the combination of the forward process $ q $ and the reverse process $ p_\theta $ can be viewed as a variational autoencoder (VAE). The variational lower bound (also called the ELBO) can therefore be used to minimize the negative log-likelihood of the ground-truth sample $x_0$. The ELBO for this process is a sum of per-time-step losses $L = L_0 + L_1 + ... + L_T$, where each term $L_t$ (except $L_0$) is a KL divergence between two Gaussians and can be written explicitly as an L2 loss with respect to the means!

  • After a series of derivations, it turns out we do not need to apply $ q $ repeatedly to sample $ x_t $: we can sample Gaussian noise once, scale it appropriately, and add it to $ x_0 $ to obtain $ x_t $ directly.
  • This gives the formula $ q(\mathbf{x}_t | \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{ \bar{\alpha}_t } \mathbf{x}_0, (1 - \bar{\alpha}_t) \mathbf{I}) $, where $ \alpha_t := 1 - \beta_t $ and $ \bar{\alpha}_t := \prod_{s=1}^{t} \alpha_s $.
  • Another advantage of this property is that the mean can be reparameterized so that the neural network learns (predicts) the added noise appearing in the KL terms of the loss. Our network thus becomes a noise predictor rather than a (direct) mean predictor; the mean can then be computed as $ \mathbf{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1- \bar{\alpha}_t}} \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \right) $

The network is optimized with a simple mean squared error (MSE) between the true noise and the predicted Gaussian noise. The final objective $ L_t $ (for a random step $t$, with $ \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) $) is: $ \| \mathbf{\epsilon} - \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \|^2 = \| \mathbf{\epsilon} - \mathbf{\epsilon}_\theta( \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{(1- \bar{\alpha}_t) } \mathbf{\epsilon}, t) \|^2 $

where $x_0$ is the initial (real, uncorrupted) image, $ \epsilon $ is the pure noise sampled at time step $t$, and $ \epsilon_\theta (x_t, t) $ is our neural network.

Training Algorithm Summary

Based on the forward and reverse processes described above, the training algorithm is summarized below.
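
Since the algorithm figure is not reproduced here, the following Python-style pseudocode restates DDPM training (Algorithm 1 in Ho et al., 2020); the names are illustrative:

# repeat until converged:
#     x_0  ~ q(x_0)                    # sample a clean image from the data
#     t    ~ Uniform({1, ..., T})      # sample a random time step
#     eps  ~ N(0, I)                   # sample Gaussian noise
#     x_t  = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
#     take a gradient step on || eps - eps_theta(x_t, t) ||^2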

Predicting Noise with a U-Net

The neural network takes a noisy image at a particular time step and returns the predicted noise. Note that the predicted noise is a tensor with the same size/resolution as the input image.

For the architecture, the DDPM authors chose a U-Net. Like any autoencoder, this network has a bottleneck in the middle that forces it to learn only the most important information; importantly, it also introduces residual connections between the encoder and decoder, which greatly improves gradient flow.

As the original figure illustrates, the U-Net first downsamples the input (i.e., makes it spatially smaller) and then upsamples it.

Building the Diffusion Model

Helper Functions and Classes

Below we define some helpers: simple compute ops, random-number generators, checks, and aliases for the up- and downsampling operations.

def rearrange(head, inputs):
    b, hc, x, y = inputs.shape
    c = hc // head
    # rearrange the input tensor: (b, heads*c, x, y) -> (b, heads, c, x*y)
    return inputs.reshape((b, head, c, x * y))

def rsqrt(x):
    # element-wise: y_i = 1 / sqrt(x_i)
    res = ops.sqrt(x)  # element-wise square root: out_i = sqrt(x_i)
    return ops.inv(res)  # element-wise reciprocal: out_i = 1 / x_i

def randn_like(x, dtype=None):
    if dtype is None:
        dtype = x.dtype
    # draw samples from the standard normal distribution (mean 0, standard
    # deviation 1), returning a tensor with the same shape as x
    res = ops.standard_normal(x.shape).astype(dtype)
    return res

def randn(shape, dtype=None):
    if dtype is None:
        dtype = ms.float32
    res = ops.standard_normal(shape).astype(dtype)
    return res

def randint(low, high, size, dtype=ms.int32):
    # draw integers uniformly from [low, high)
    # ops.uniform(shape, minval, maxval, dtype)
    res = ops.uniform(size, Tensor(low, dtype), Tensor(high, dtype), dtype=dtype)
    return res

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def _check_dtype(d1, d2):
    if ms.float32 in (d1, d2):
        return ms.float32
    if d1 == d2:
        return d1
    raise ValueError('dtype is not supported.')

class Residual(nn.Cell):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def construct(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim):
    # 2-D transposed convolution (often called deconvolution, though it is not a
    # true inverse convolution); it can be viewed as Conv2d's gradient with
    # respect to its input.
    # Input shape is (N, C, H, W): N is the batch size, C the number of channels,
    # H/W the feature height and width. kernel_size=4 with stride=2 doubles H and W.
    return nn.Conv2dTranspose(dim, dim, 4, 2, pad_mode="pad", padding=1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, pad_mode="pad", padding=1)
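
As a quick shape check (a hypothetical snippet, assuming the imports above): with kernel_size=4, stride=2, and padding=1, Downsample halves the spatial size and Upsample doubles it:

x = ops.zeros((1, 8, 32, 32), ms.float32)
print(Downsample(8)(x).shape)  # (1, 8, 16, 16)
print(Upsample(8)(x).shape)    # (1, 8, 64, 64)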

Position Embeddings

Following the Transformer (Vaswani et al., 2017), the authors use sinusoidal position embeddings to encode $t$.

class SinusoidalPositionEmbeddings(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim) * - emb)
        self.emb = Tensor(emb, ms.float32)

    def construct(self, x):
        emb = x[:, None] * self.emb[None, :]
        # concatenate the tensors along the given axis; axis must lie in [-R, R), default 0
        emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
        return emb
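
As a quick check (a hypothetical snippet added for clarity), embedding a batch of three time steps with dim=32 yields one 32-dimensional vector per step:

t = Tensor(np.array([1., 50., 199.]), ms.float32)
emb = SinusoidalPositionEmbeddings(32)(t)
print(emb.shape)  # (3, 32)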

ResNet/ConvNeXT Blocks

For the core building block of the U-Net, the DDPM authors used a Wide ResNet block (Zagoruyko et al., 2016), but Phil Wang chose to replace the ResNet block with a ConvNeXT block (Liu et al., 2022) because of the latter's great success in the image domain. This tutorial builds the U-Net with ConvNeXT blocks.

class Block(nn.Cell):
    def __init__(self, dim, dim_out, groups=1):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1) 
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def construct(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ConvNextBlock(nn.Cell):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition

        h = self.net(h)
        return h + self.res_conv(x)

Attention Module

Next we define the attention module (from the Transformer architecture, Vaswani et al., 2017), which the DDPM authors interleave between the convolutional blocks. Phil Wang uses two attention variants: regular multi-head self-attention (as used in the Transformer) and LinearAttention (Shen et al., 2018), whose time and memory requirements scale linearly in the sequence length, rather than quadratically as in regular attention.

For a deeper look at the attention mechanism, see Jay Alammar's excellent blog post.

class Attention(nn.Cell):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        b, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, 1)
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        q = q * self.scale

        # 'b h d i, b h d j -> b h i j'
        sim = ops.bmm(q.swapaxes(2, 3), k)
        attn = ops.softmax(sim, axis=-1)
        # 'b h i j, b h d j -> b h i d'
        out = ops.bmm(attn, v.swapaxes(2, 3))
        out = out.swapaxes(-1, -2).reshape((b, -1, h, w))

        return self.to_out(out)


class LayerNorm(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')

    def construct(self, x):
        eps = 1e-5
        var = x.var(1, keepdims=True)
        mean = x.mean(1, keep_dims=True)
        return (x - mean) * rsqrt((var + eps)) * self.g


class LinearAttention(nn.Cell):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)

        self.to_out = nn.SequentialCell(
            nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
            LayerNorm(dim)
        )

        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        b, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, 1)
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        q = ops.softmax(q, -2)
        k = ops.softmax(k, -1)

        q = q * self.scale
        v = v / (h * w)

        # 'b h d n, b h e n -> b h d e'
        context = ops.bmm(k, v.swapaxes(2, 3))
        # 'b h d e, b h d n -> b h e n'
        out = ops.bmm(context.swapaxes(2, 3), q)

        out = out.reshape((b, -1, h, w))
        return self.to_out(out)

Group Normalization

The U-Net interleaves its convolution/attention layers with group normalization (Wu et al., 2018). The PreNorm class below applies groupnorm before the attention layer.

class PreNorm(nn.Cell):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def construct(self, x):
        x = self.norm(x)
        return self.fn(x)

Conditional U-Net

Recall that the job of the network $ \epsilon_\theta (x_t, t) $ is to take a batch of noisy images and their noise levels, and output the noise that was added to the input.

More concretely: the network takes a batch of noisy images of shape (batch_size, num_channels, height, width) and a batch of noise levels of shape (batch_size, 1) as input, and returns a tensor of shape (batch_size, num_channels, height, width).

The network is built as follows:

  • First, a convolutional layer is applied to the batch of noisy images, and position embeddings are computed for the noise levels.

  • Next comes a sequence of downsampling stages. Each one consists of 2 ResNet/ConvNeXT blocks + groupnorm + attention + a residual connection + a downsampling operation.

  • In the middle of the network, ResNet or ConvNeXT blocks are applied again, interleaved with attention.

  • Then comes a sequence of upsampling stages. Each one consists of 2 ResNet/ConvNeXT blocks + groupnorm + attention + a residual connection + an upsampling operation.

  • Finally, a ResNet/ConvNeXT block is applied, followed by a convolutional layer.

In the end, the network is these layers stacked together.

class Unet(nn.Cell):
    def __init__(
            self,
            dim,
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8),
            channels=3,
            with_time_emb=True,
            convnext_mult=2,
    ):
        super().__init__()

        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ConvNextBlock, mult=convnext_mult)

        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.CellList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.CellList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.SequentialCell(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def construct(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[len_h]), 1)
            len_h -= 1
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)
        return self.final_conv(x)
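
A quick smoke test (a hypothetical snippet, mirroring the Fashion-MNIST configuration used later) confirms that the output has the same shape as the input:

net = Unet(dim=28, channels=1, dim_mults=(1, 2, 4))
x = randn((16, 1, 28, 28))
t = randint(0, 200, (16,))
print(net(x, t).shape)  # (16, 1, 28, 28)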

Forward Diffusion (core)

The forward diffusion process gradually adds noise to an image from the real distribution over T time steps, according to a variance schedule. The original DDPM authors used a linear schedule:

  • We set the forward process variances to constants increasing linearly from $ \beta_1 = 10^{-4} $ to $ \beta_T = 0.02 $.

  • (Nichol et al., 2021) later showed that better results can be obtained with a cosine schedule.

Below, we define the linear schedule for T time steps.

We use a linear schedule with $T = 200$ time steps and define the various quantities derived from $ \beta_t $, such as the cumulative product $ \bar{\alpha}_t $. Each variable below is just a 1-D tensor storing the values for $t = 1$ to $T$. Importantly, we also define an extract function that lets us pull out the entries for a batch of $t$ indices.

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)


# 200 diffusion time steps
timesteps = 200

# define the beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define the alphas
alphas = 1. - betas
# cumulative product of the alphas along axis 0
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)

sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))

# quantities for the posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# note: the exponent here is -0., so every weight equals 1.0 (P2-style loss
# reweighting is effectively disabled)
p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
p2_loss_weight = Tensor(p2_loss_weight)

def extract(a, t, x_shape):
    b = t.shape[0]
    out = Tensor(a).gather(t, -1)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
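
As an illustration (a snippet added here, not in the original), extract gathers one schedule entry per batch element and reshapes it to broadcast over an image batch:

t = Tensor(np.array([0, 99, 199]), ms.int32)
coeff = extract(sqrt_alphas_cumprod, t, (3, 1, 28, 28))
print(coeff.shape)  # (3, 1, 1, 1) -- broadcasts over (N, C, H, W)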

Data Preparation and Processing

We use a cat image to illustrate how noise is added at each time step of the diffusion process.

Noise is added to MindSpore tensors, not to Pillow images.

So we first define image transforms that take us from a PIL image to a MindSpore tensor (to which we can add noise), and vice versa.

from PIL import Image
from mindspore.dataset import ImageFolderDataset
import numpy as np

# download the cat image
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip'
path = download(url, './', kind="zip", replace=True)

# show the cat image
image = Image.open('./image_cat/jpg/000000039769.jpg')
base_width = 160
image = image.resize((base_width, int(float(image.size[1]) * float(base_width / float(image.size[0])))))
image.show()

# transform the cat image
image_size = 128
transforms = [
    # resize the input image to the given size with the specified
    # mindspore.dataset.vision.Inter interpolation
    Resize(image_size, Inter.BILINEAR),
    # center-crop; if the input is smaller than the output size, the borders are
    # zero-padded before cropping
    CenterCrop(image_size),
    ToTensor(),
    lambda t: (t * 2) - 1
]

path = './image_cat'
dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(),
                             extensions=['.jpg', '.jpeg', '.png', '.tiff'],
                             num_shards=1, shard_id=0, shuffle=False, decode=True)
# keep only the selected columns from the dataset, in the given order; unspecified columns are dropped
dataset = dataset.project('image')
transforms.insert(1, RandomHorizontalFlip())
dataset_1 = dataset.map(transforms, 'image')
dataset_2 = dataset_1.batch(1, drop_remainder=True)
x_start = next(dataset_2.create_tuple_iterator())[0]
print(x_start.shape)  # (1, 3, 128, 128)


# reverse transform: takes a tensor with values in [-1, 1] and converts it back to a PIL image
reverse_transform = [
    lambda t: (t + 1) / 2,
    lambda t: ops.permute(t, (1, 2, 0)), # CHW to HWC
    lambda t: t * 255.,
    lambda t: t.asnumpy().astype(np.uint8),
    ToPIL()
]

def compose(transform, x):
    for d in transform:
        x = d(x)
    return x

# x_start[0] has shape (3, 128, 128)
reverse_image = compose(reverse_transform, x_start[0])
reverse_image.show()

Defining Forward Diffusion

def q_sample(x_start, t, noise=None):
    if noise is None:
        # generate standard normal noise with the same shape as x_start
    return (extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)


def get_noisy_image(x_start, t):
    # add noise
    x_noisy = q_sample(x_start, t=t)

    # convert back to a PIL image
    noisy_image = compose(reverse_transform, x_noisy[0])

    return noisy_image


# set the time step
t = Tensor([40])
noisy_image = get_noisy_image(x_start, t)
print(noisy_image)
noisy_image.show()

Visualizing Forward Diffusion over Time Steps

As in Forward Diffusion (core), there are 200 time steps in total; below we visualize 5 of them.

import matplotlib.pyplot as plt

def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    _, axs = plt.subplots(figsize=(200, 200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()


plot([get_noisy_image(x_start, Tensor([t])) for t in [0, 50, 100, 150, 199]])

Defining the Loss Function

1) At time step t, add a noise sample noise to the original image, producing the noisy image x_noisy;

2) feed the noisy image x_noisy into the network as input to predict the noise predicted_noise;

3) compute the loss between the actual and the predicted noise with nn.SmoothL1Loss.

def p_losses(unet_model, x_start, t, noise=None):
    if noise is None:
        noise = randn_like(x_start)

    # the noisy image after t steps of forward diffusion
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    # feed the noisy image to the U-Net as a training sample to predict the noise
    predicted_noise = unet_model(x_noisy, t)
    
    # noise is the true noise and predicted_noise the model's estimate;
    # we minimize the gap between them
    loss = nn.SmoothL1Loss()(noise, predicted_noise)
    loss = loss.reshape(loss.shape[0], -1)
    loss = loss * extract(p2_loss_weight, t, loss.shape)
    return loss.mean()

Hands-on Example

Data Preparation and Processing

from mindspore.dataset import FashionMnistDataset

# download the Fashion-MNIST dataset
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip'
path = download(url, './', kind="zip", replace=True)


# data loading
image_size = 28
channels = 1
batch_size = 16

fashion_mnist_dataset_dir = "./dataset"
dataset = FashionMnistDataset(
                        dataset_dir=fashion_mnist_dataset_dir, 
                        usage="train", num_parallel_workers=cpu_count(), 
                        shuffle=True, num_shards=1, shard_id=0)


transforms = [
    RandomHorizontalFlip(),
    ToTensor(),
    lambda t: (t * 2) - 1
]


dataset = dataset.project('image')
dataset = dataset.shuffle(64)
dataset = dataset.map(transforms, 'image')
dataset = dataset.batch(16, drop_remainder=True)

x = next(dataset.create_dict_iterator())
print(x.keys())  # dict_keys(['image'])

Sampling from the Model

During training we sample from the model to track progress.

Generating new images from a diffusion model means reversing the diffusion process: we start at T, sample pure noise from a Gaussian, and then use the neural network to denoise it gradually (using the conditional probability it has learned) until we reach time step t = 0. Using the reparameterization of the mean with our noise predictor, each step derives a slightly less noisy image $ x_{t-1} $. Note that the variance is known ahead of time.

Ideally, we end up with an image that looks as if it came from the real data distribution; the per-step update is written out below.
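
Restating Algorithm 2 of Ho et al., 2020 (added here since the algorithm figure is not reproduced), each sampling step computes

$ \mathbf{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{\beta_t}{\sqrt{1- \bar{\alpha}_t}} \mathbf{\epsilon}_\theta(\mathbf{x}_t, t) \right) + \sqrt{\tilde{\beta}_t} \mathbf{z} $

where $ \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) $ for $t > 0$ and $ \mathbf{z} = \mathbf{0} $ at the final step, and $ \tilde{\beta}_t $ is the posterior_variance computed earlier.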

The code follows; it is a simplified version of the original implementation.

def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)

    if t_index == 0:
        return model_mean
    posterior_variance_t = extract(posterior_variance, t, x.shape)
    noise = randn_like(x)
    return model_mean + ops.sqrt(posterior_variance_t) * noise

def p_sample_loop(model, shape):
    b = shape[0]
    # start from pure noise
    img = randn(shape, dtype=None)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, ms.numpy.full((b,), i, dtype=mstype.int32), i)
        imgs.append(img.asnumpy())
    return imgs

def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

Training

Below is the training procedure for 1 epoch.

# define a dynamic learning rate
lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)

# define the Unet model
unet_model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)

# align trainable parameter names with the model's registered parameter names
name_list = []
for (name, par) in list(unet_model.parameters_and_names()):
    name_list.append(name)
i = 0
for item in list(unet_model.trainable_params()):
    item.name = name_list[i]
    i += 1

# define the optimizer
optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
loss_scaler = DynamicLossScaler(65536, 2, 1000)

# define the forward (loss) function
def forward_fn(data, t, noise=None):
    loss = p_losses(unet_model, data, t, noise)
    return loss

# compute loss and gradients
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

# gradient update
def train_step(data, t, noise):
    loss, grads = grad_fn(data, t, noise)
    optimizer(grads)
    return loss


import time

# epochs is set to 1 to keep the run short; adjust as needed
epochs = 1

for epoch in range(epochs):
    begin_time = time.time()
    for step, batch in enumerate(dataset.create_tuple_iterator()):
        unet_model.set_train()
        batch_size = batch[0].shape[0]
        t = randint(0, timesteps, (batch_size,), dtype=ms.int32)
        noise = randn_like(batch[0])
        loss = train_step(batch[0], t, noise)

        if step % 500 == 0:
            print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)
    end_time = time.time()
    times = end_time - begin_time
    print("training time:", times, "s")
    # show random sampling results
    unet_model.set_train(False)
    samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
    plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
print("Training Success!")

After this short training run the results are mediocre; training for a few more epochs is recommended. The model structure is printed below:

print(unet_model)

Unet<
  (init_conv): Conv2d<input_channels=1, output_channels=18, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd21290f40>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffd21290a90>, format=NCHW>
  (time_mlp): SequentialCell<
    (0): SinusoidalPositionEmbeddings<>
    (1): Dense<input_channels=28, output_channels=112, has_bias=True>
    (2): GELU<>
    (3): Dense<input_channels=112, output_channels=112, has_bias=True>
    >
  (downs): CellList<
    (0): CellList<
      (0): ConvNextBlock<
        (mlp): SequentialCell<
          (0): GELU<>
          (1): Dense<input_channels=112, output_channels=18, has_bias=True>
          >
        (ds_conv): Conv2d<input_channels=18, output_channels=18, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=18, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd21247100>, bias_init=None, format=NCHW>
        (net): SequentialCell<
          (0): GroupNorm<num_groups=1, num_channels=18>
          (1): Conv2d<input_channels=18, output_channels=56, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd21247f40>, bias_init=None, format=NCHW>
          (2): GELU<>
          (3): GroupNorm<num_groups=1, num_channels=56>
          (4): Conv2d<input_channels=56, output_channels=28, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd21240cd0>, bias_init=None, format=NCHW>
          >
        (res_conv): Conv2d<input_channels=18, output_channels=28, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd21247fd0>, bias_init=None, format=NCHW>
        >
      (1): ConvNextBlock<
        (mlp): SequentialCell<
          (0): GELU<>
          (1): Dense<input_channels=112, output_channels=28, has_bias=True>
          >
        (ds_conv): Conv2d<input_channels=28, output_channels=28, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=28, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd21247d90>, bias_init=None, format=NCHW>
        (net): SequentialCell<
          (0): GroupNorm<num_groups=1, num_channels=28>
          (1): Conv2d<input_channels=28, output_channels=56, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd21247af0>, bias_init=None, format=NCHW>
          (2): GELU<>
          (3): GroupNorm<num_groups=1, num_channels=56>
          (4): Conv2d<input_channels=56, output_channels=28, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd212477c0>, bias_init=None, format=NCHW>
          >
        (res_conv): Identity<>
        >
      (2): Residual<
        (fn): PreNorm<
          (fn): LinearAttention<
            (to_qkv): Conv2d<input_channels=28, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffde0903f40>, bias_init=None, format=NCHW>
            (to_out): SequentialCell<
              (0): Conv2d<input_channels=128, output_channels=28, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd21247850>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffd21247b80>, format=NCHW>
              (1): LayerNorm<>
              >
            >
          (norm): GroupNorm<num_groups=1, num_channels=28>
          >
        >
      (3): Conv2d<input_channels=28, output_channels=28, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd20114220>, bias_init=None, format=NCHW>
      >
    (1): CellList<
      (0): ConvNextBlock<
        (mlp): SequentialCell<
          (0): GELU<>
          (1): Dense<input_channels=112, output_channels=28, has_bias=True>
          >
        (ds_conv): Conv2d<input_channels=28, output_channels=28, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=28, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd201144f0>, bias_init=None, format=NCHW>
        (net): SequentialCell<
          (0): GroupNorm<num_groups=1, num_channels=28>
          (1): Conv2d<input_channels=28, output_channels=112, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd21247ee0>, bias_init=None, format=NCHW>
          (2): GELU<>
          (3): GroupNorm<num_groups=1, num_channels=112>
          (4): Conv2d<input_channels=112, output_channels=56, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd201144c0>, bias_init=None, format=NCHW>
          >
        (res_conv): Conv2d<input_channels=28, output_channels=56, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd201145b0>, bias_init=None, format=NCHW>
        >
      (1): ConvNextBlock<
        (mlp): SequentialCell<
          (0): GELU<>
          (1): Dense<input_channels=112, output_channels=56, has_bias=True>
          >
        (ds_conv): Conv2d<input_channels=56, output_channels=56, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=56, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd201148e0>, bias_init=None, format=NCHW>
        (net): SequentialCell<
          (0): GroupNorm<num_groups=1, num_channels=56>
          (1): Conv2d<input_channels=56, output_channels=112, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd201149d0>, bias_init=None, format=NCHW>
          (2): GELU<>
          (3): GroupNorm<num_groups=1, num_channels=112>
          (4): Conv2d<input_channels=112, output_channels=56, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd20114940>, bias_init=None, format=NCHW>
          >
        (res_conv): Identity<>
        >
      (2): Residual<
        (fn): PreNorm<
          (fn): LinearAttention<
            (to_qkv): Conv2d<input_channels=56, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd20114970>, bias_init=None, format=NCHW>
            (to_out): SequentialCell<
              (0): Conv2d<input_channels=128, output_channels=56, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd20114580>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffd20114ac0>, format=NCHW>
              (1): LayerNorm<>
              >
            >
          (norm): GroupNorm<num_groups=1, num_channels=56>
          >
        >
      (3): Conv2d<input_channels=56, output_channels=56, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd20114d90>, bias_init=None, format=NCHW>
      >
    (2): CellList<
      (0): ConvNextBlock<
        (mlp): SequentialCell<
          (0): GELU<>
          (1): Dense<input_channels=112, output_channels=56, has_bias=True>
          >
        (ds_conv): Conv2d<input_channels=56, output_channels=56, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=56, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd20114ca0>, bias_init=None, format=NCHW>
        (net): SequentialCell<
          (0): GroupNorm<num_groups=1, num_channels=56>
          (1): Conv2d<input_channels=56, output_channels=224, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd20114be0>, bias_init=None, format=NCHW>
          (2): GELU<>
          (3): GroupNorm<num_groups=1, num_channels=224>
          (4): Conv2d<input_channels=224, output_channels=112, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200da100>, bias_init=None, format=NCHW>
          >
        (res_conv): Conv2d<input_channels=56, output_channels=112, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200da250>, bias_init=None, format=NCHW>
        >
      (1): ConvNextBlock<
        (mlp): SequentialCell<
          (0): GELU<>
          (1): Dense<input_channels=112, output_channels=112, has_bias=True>
          >
        (ds_conv): Conv2d<input_channels=112, output_channels=112, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=112, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200da520>, bias_init=None, format=NCHW>
        (net): SequentialCell<
          (0): GroupNorm<num_groups=1, num_channels=112>
          (1): Conv2d<input_channels=112, output_channels=224, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200da550>, bias_init=None, format=NCHW>
          (2): GELU<>
          (3): GroupNorm<num_groups=1, num_channels=224>
          (4): Conv2d<input_channels=224, output_channels=112, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200da4f0>, bias_init=None, format=NCHW>
          >
        (res_conv): Identity<>
        >
      (2): Residual<
        (fn): PreNorm<
          (fn): LinearAttention<
            (to_qkv): Conv2d<input_channels=112, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200da040>, bias_init=None, format=NCHW>
            (to_out): SequentialCell<
              (0): Conv2d<input_channels=128, output_channels=112, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200da730>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffd200da850>, format=NCHW>
              (1): LayerNorm<>
              >
            >
          (norm): GroupNorm<num_groups=1, num_channels=112>
          >
        >
      (3): Identity<>
      >
    >
  (ups): CellList<
    (0): CellList<
      (0): ConvNextBlock<
        (mlp): SequentialCell<
          (0): GELU<>
          (1): Dense<input_channels=112, output_channels=224, has_bias=True>
          >
        (ds_conv): Conv2d<input_channels=224, output_channels=224, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=224, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b60a0>, bias_init=None, format=NCHW>
        (net): SequentialCell<
          (0): GroupNorm<num_groups=1, num_channels=224>
          (1): Conv2d<input_channels=224, output_channels=112, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b65e0>, bias_init=None, format=NCHW>
          (2): GELU<>
          (3): GroupNorm<num_groups=1, num_channels=112>
          (4): Conv2d<input_channels=112, output_channels=56, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6640>, bias_init=None, format=NCHW>
          >
        (res_conv): Conv2d<input_channels=224, output_channels=56, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b67f0>, bias_init=None, format=NCHW>
        >
      (1): ConvNextBlock<
        (mlp): SequentialCell<
          (0): GELU<>
          (1): Dense<input_channels=112, output_channels=56, has_bias=True>
          >
        (ds_conv): Conv2d<input_channels=56, output_channels=56, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=56, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6940>, bias_init=None, format=NCHW>
        (net): SequentialCell<
          (0): GroupNorm<num_groups=1, num_channels=56>
          (1): Conv2d<input_channels=56, output_channels=112, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6190>, bias_init=None, format=NCHW>
          (2): GELU<>
          (3): GroupNorm<num_groups=1, num_channels=112>
          (4): Conv2d<input_channels=112, output_channels=56, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6490>, bias_init=None, format=NCHW>
          >
        (res_conv): Identity<>
        >
      (2): Residual<
        (fn): PreNorm<
          (fn): LinearAttention<
            (to_qkv): Conv2d<input_channels=56, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6be0>, bias_init=None, format=NCHW>
            (to_out): SequentialCell<
              (0): Conv2d<input_channels=128, output_channels=56, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6d00>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffd200b6e80>, format=NCHW>
              (1): LayerNorm<>
              >
            >
          (norm): GroupNorm<num_groups=1, num_channels=56>
          >
        >
      (3): Conv2dTranspose<input_channels=56, output_channels=56, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6bb0>, bias_init=None, format=NCHW>
      >
    (1): CellList<
      (0): ConvNextBlock<
        (mlp): SequentialCell<
          (0): GELU<>
          (1): Dense<input_channels=112, output_channels=112, has_bias=True>
          >
        (ds_conv): Conv2d<input_channels=112, output_channels=112, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=112, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d6160>, bias_init=None, format=NCHW>
        (net): SequentialCell<
          (0): GroupNorm<num_groups=1, num_channels=112>
          (1): Conv2d<input_channels=112, output_channels=56, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d6280>, bias_init=None, format=NCHW>
          (2): GELU<>
          (3): GroupNorm<num_groups=1, num_channels=56>
          (4): Conv2d<input_channels=56, output_channels=28, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d62e0>, bias_init=None, format=NCHW>
          >
        (res_conv): Conv2d<input_channels=112, output_channels=28, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6d60>, bias_init=None, format=NCHW>
        >
      (1): ConvNextBlock<
        (mlp): SequentialCell<
          (0): GELU<>
          (1): Dense<input_channels=112, output_channels=28, has_bias=True>
          >
        (ds_conv): Conv2d<input_channels=28, output_channels=28, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=28, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d6580>, bias_init=None, format=NCHW>
        (net): SequentialCell<
          (0): GroupNorm<num_groups=1, num_channels=28>
          (1): Conv2d<input_channels=28, output_channels=56, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d6490>, bias_init=None, format=NCHW>
          (2): GELU<>
          (3): GroupNorm<num_groups=1, num_channels=56>
          (4): Conv2d<input_channels=56, output_channels=28, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d6730>, bias_init=None, format=NCHW>
          >
        (res_conv): Identity<>
        >
      (2): Residual<
        (fn): PreNorm<
          (fn): LinearAttention<
            (to_qkv): Conv2d<input_channels=28, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d68b0>, bias_init=None, format=NCHW>
            (to_out): SequentialCell<
              (0): Conv2d<input_channels=128, output_channels=28, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6fa0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffd007d6a30>, format=NCHW>
              (1): LayerNorm<>
              >
            >
          (norm): GroupNorm<num_groups=1, num_channels=28>
          >
        >
      (3): Conv2dTranspose<input_channels=28, output_channels=28, kernel_size=(4, 4), stride=(2, 2), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d66a0>, bias_init=None, format=NCHW>
      >
    >
  (mid_block1): ConvNextBlock<
    (mlp): SequentialCell<
      (0): GELU<>
      (1): Dense<input_channels=112, output_channels=112, has_bias=True>
      >
    (ds_conv): Conv2d<input_channels=112, output_channels=112, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=112, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200da880>, bias_init=None, format=NCHW>
    (net): SequentialCell<
      (0): GroupNorm<num_groups=1, num_channels=112>
      (1): Conv2d<input_channels=112, output_channels=224, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200dab80>, bias_init=None, format=NCHW>
      (2): GELU<>
      (3): GroupNorm<num_groups=1, num_channels=224>
      (4): Conv2d<input_channels=224, output_channels=112, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200dabe0>, bias_init=None, format=NCHW>
      >
    (res_conv): Identity<>
    >
  (mid_attn): Residual<
    (fn): PreNorm<
      (fn): Attention<
        (to_qkv): Conv2d<input_channels=112, output_channels=384, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200dad90>, bias_init=None, format=NCHW>
        (to_out): Conv2d<input_channels=128, output_channels=112, kernel_size=(1, 1), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200dafd0>, bias_init=<mindspore.common.initializer.Uniform object at 0xfffd20114ee0>, format=NCHW>
        >
      (norm): GroupNorm<num_groups=1, num_channels=112>
      >
    >
  (mid_block2): ConvNextBlock<
    (mlp): SequentialCell<
      (0): GELU<>
      (1): Dense<input_channels=112, output_channels=112, has_bias=True>
      >
    (ds_conv): Conv2d<input_channels=112, output_channels=112, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=112, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6160>, bias_init=None, format=NCHW>
    (net): SequentialCell<
      (0): GroupNorm<num_groups=1, num_channels=112>
      (1): Conv2d<input_channels=112, output_channels=224, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b6280>, bias_init=None, format=NCHW>
      (2): GELU<>
      (3): GroupNorm<num_groups=1, num_channels=224>
      (4): Conv2d<input_channels=224, output_channels=112, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd200b62e0>, bias_init=None, format=NCHW>
      >
    (res_conv): Identity<>
    >
  (final_conv): SequentialCell<
    (0): ConvNextBlock<
      (ds_conv): Conv2d<input_channels=28, output_channels=28, kernel_size=(7, 7), stride=(1, 1), pad_mode=pad, padding=3, dilation=(1, 1), group=28, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d6520>, bias_init=None, format=NCHW>
      (net): SequentialCell<
        (0): GroupNorm<num_groups=1, num_channels=28>
        (1): Conv2d<input_channels=28, output_channels=56, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d6d90>, bias_init=None, format=NCHW>
        (2): GELU<>
        (3): GroupNorm<num_groups=1, num_channels=56>
        (4): Conv2d<input_channels=56, output_channels=28, kernel_size=(3, 3), stride=(1, 1), pad_mode=pad, padding=1, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d6df0>, bias_init=None, format=NCHW>
        >
      (res_conv): Identity<>
      >
    (1): Conv2d<input_channels=28, output_channels=1, kernel_size=(1, 1), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=<mindspore.common.initializer.HeUniform object at 0xfffd007d6d30>, bias_init=None, format=NCHW>
    >
  >

Inference (Sampling from the Model)

import matplotlib.animation as animation


# sample 64 images
unet_model.set_train(False)
samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)


# show one random result
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")


random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=100)
animate.save('diffusion.gif')
plt.show()

 

References

[1] https://huggingface.co/blog/annotated-diffusion

[2] https://zhuanlan.zhihu.com/p/525106459

[3] Yang Song, Generative Modeling by Estimating Gradients of the Data Distribution: https://twitter.com/sedielem/status/1530894256168222722?s=20&t=mfv4afx1GcNQU5fZklpACw

[4] U-Net: Ronneberger et al., 2015

[5] Gradient flow: He et al., 2015

[6] VAE:https://arxiv.org/abs/1312.6114

[7] A Recipe for Training Neural Networks

[8] The Illustrated Transformer – Jay Alammar – Visualizing machine learning one concept at a time.
