昇思 25 天学习打卡营第 25 天 | MindSpore Diffusion 扩散模型

1. 背景:

使用 MindSpore 学习神经网络,打卡第 25 天;主要内容也依据 mindspore 的学习记录。

2. Diffusion 介绍:

Diffusion也是从纯噪声开始通过一个神经网络学习逐步去噪,最终得到一个实际图像。

  • 具体的 Diffusion 扩散模型的原理,可以参考:https://zhuanlan.zhihu.com/p/650394311?utm_id=0

  • MindSpore 的 docs 中有详细的说明;
    https://gitee.com/mindspore/docs/blob/r2.3/tutorials/application/source_zh_cn/generative/diffusion.ipynb

  • PyTorch 版本地址:
    https://github.com/lucidrains/denoising-diffusion-pytorch

  • 过程:
    Diffusion对于图像的处理包括以下两个过程:
    a. 我们选择的固定(或预定义)正向扩散过程 q,逐渐将高斯噪声添加到图像中,直到最终得到纯噪声
    b. 一个学习的反向去噪的扩散过程 pθ 通过训练神经网络从纯噪声开始逐渐对图像去噪,直到最终得到一个实际的图像
    在这里插入图片描述

3. 具体实现:

mindspore 实验中使用离散时间;

3.1 数据下载与处理:

import math
from functools import partial
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from multiprocessing import cpu_count
from download import download

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.dataset.vision import Resize, Inter, CenterCrop, ToTensor, RandomHorizontalFlip, ToPIL
from mindspore.common.initializer import initializer
from mindspore.amp import DynamicLossScaler

# Seed MindSpore's random number generation.
ms.set_seed(0)

# Download the cat image archive and unpack it into the current directory
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip'
path = download(url, './', kind="zip", replace=True)

from mindspore.dataset import ImageFolderDataset

# Convert PIL images to MindSpore tensors and linearly rescale them to [-1, 1]
image_size = 128
transforms = [
    Resize(image_size, Inter.BILINEAR),
    CenterCrop(image_size),
    ToTensor(),
    lambda t: (t * 2) - 1
]


# NOTE: `path` is rebound here; the value returned by download() above is not used.
path = './image_cat'
dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(),
                             extensions=['.jpg', '.jpeg', '.png', '.tiff'],
                             num_shards=1, shard_id=0, shuffle=False,  decode=True)
# Keep only the 'image' column.
dataset = dataset.project('image')
# Insert a random horizontal flip right after the resize (index 1 of the pipeline).
transforms.insert(1, RandomHorizontalFlip())
dataset_1 = dataset.map(transforms, 'image')
dataset_2 = dataset_1.batch(1, drop_remainder=True)
# Take the first batch as the reference image used by the forward-diffusion demos.
x_start = next(dataset_2.create_tuple_iterator())[0]
print(x_start.shape)

import numpy as np

# Inverse transform: takes a CHW tensor with values in [-1, 1] and turns it
# back into a PIL image. The steps are applied in list order (see compose()).
reverse_transform = [
    lambda img: (img + 1) / 2,                   # [-1, 1] -> [0, 1]
    lambda img: ops.permute(img, (1, 2, 0)),     # CHW to HWC
    lambda img: img * 255.,                      # [0, 1] -> [0, 255]
    lambda img: img.asnumpy().astype(np.uint8),  # tensor -> uint8 ndarray
    ToPIL(),
]

def compose(transform, x):
    """Apply every callable in ``transform`` to ``x``, in order.

    Args:
        transform: Iterable of one-argument callables.
        x: Initial value fed to the first callable.

    Returns:
        The value produced by the last callable, or ``x`` unchanged when
        ``transform`` is empty.
    """
    result = x
    for step in transform:
        result = step(result)
    return result

# Run the first image of the batch through the inverse transform and display it.
reverse_image = compose(reverse_transform, x_start[0])
reverse_image.show()

3.2 构造 Diffusion 网络

  • 定义模型的辅助方法:
def rearrange(head, inputs):
    """Split the channel axis into heads and flatten the two spatial axes.

    Args:
        head: Number of heads the channel axis is divided into.
        inputs: 4-D array/tensor of shape ``(b, head * c, x, y)``.

    Returns:
        ``inputs`` reshaped to ``(b, head, c, x * y)`` where
        ``c = channels // head``.
    """
    batch, channels, height, width = inputs.shape
    per_head = channels // head
    target_shape = (batch, head, per_head, height * width)
    return inputs.reshape(target_shape)

def rsqrt(x):
    """Return the element-wise reciprocal of the square root of ``x``."""
    return ops.inv(ops.sqrt(x))

def randn_like(x, dtype=None):
    """Draw standard-normal noise with the same shape as ``x``.

    Args:
        x: Tensor whose shape (and, by default, dtype) is mirrored.
        dtype: Optional dtype of the result; defaults to ``x.dtype``.

    Returns:
        A tensor of shape ``x.shape`` cast to the chosen dtype.
    """
    target_dtype = x.dtype if dtype is None else dtype
    return ops.standard_normal(x.shape).astype(target_dtype)

def randn(shape, dtype=None):
    """Draw standard-normal noise of the given ``shape``.

    Args:
        shape: Shape of the tensor to generate.
        dtype: Optional dtype of the result; defaults to ``ms.float32``.

    Returns:
        A tensor of the requested shape cast to the chosen dtype.
    """
    target_dtype = ms.float32 if dtype is None else dtype
    return ops.standard_normal(shape).astype(target_dtype)

def randint(low, high, size, dtype=ms.int32):
    """Draw random integers via ``ops.uniform`` with bounds ``low`` and ``high``.

    Args:
        low: Lower bound, wrapped in a tensor of ``dtype``.
        high: Upper bound, wrapped in a tensor of ``dtype``.
        size: Shape of the result.
        dtype: Dtype used for both bounds and the result.

    Returns:
        The tensor returned by ``ops.uniform``.
    """
    lower = Tensor(low, dtype)
    upper = Tensor(high, dtype)
    return ops.uniform(size, lower, upper, dtype=dtype)

def exists(x):
    """Return ``True`` unless ``x`` is ``None``."""
    return not (x is None)

def default(val, d):
    """Return ``val`` when it is not ``None``, otherwise a fallback.

    Args:
        val: Preferred value.
        d: Fallback value, or a zero-argument callable producing it.

    Returns:
        ``val`` if it exists; else ``d()`` when ``d`` is callable, else ``d``.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

def _check_dtype(d1, d2):
    """Pick the common dtype of ``d1`` and ``d2``.

    Returns ``ms.float32`` when either argument is ``ms.float32``, otherwise
    the shared dtype when both are equal.

    Raises:
        ValueError: If the dtypes differ and neither is ``ms.float32``.
    """
    if ms.float32 == d1 or ms.float32 == d2:
        return ms.float32
    if d1 != d2:
        raise ValueError('dtype is not supported.')
    return d1

class Residual(nn.Cell):
    """Wrap a callable ``fn`` so that its output is added to its first input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def construct(self, x, *args, **kwargs):
        """Return ``fn(x, *args, **kwargs) + x``."""
        branch = self.fn(x, *args, **kwargs)
        return branch + x

def Upsample(dim):
    """Build a transposed convolution (kernel 4, stride 2, padding 1) keeping ``dim`` channels."""
    return nn.Conv2dTranspose(
        in_channels=dim,
        out_channels=dim,
        kernel_size=4,
        stride=2,
        pad_mode="pad",
        padding=1,
    )

def Downsample(dim):
    """Build a strided convolution (kernel 4, stride 2, padding 1) keeping ``dim`` channels."""
    return nn.Conv2d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=4,
        stride=2,
        pad_mode="pad",
        padding=1,
    )
  • 位置向量
  • 由于神经网络的参数在时间(噪声水平)上共享,作者使用正弦位置嵌入来编码 t,灵感来自 Transformer(Vaswani et al., 2017)。
    这样,对于批处理中的每一张图像,神经网络都"知道"它在哪个特定时间步长(噪声水平)上运行。
class SinusoidalPositionEmbeddings(nn.Cell):
    """Encode a 1-D batch of time steps as sine/cosine features.

    ``dim // 2`` frequencies are precomputed in ``__init__``; ``construct``
    returns the sines followed by the cosines along the last axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        # Geometric progression of frequencies from 1 down to 1/10000.
        log_step = math.log(10000) / (half_dim - 1)
        frequencies = np.exp(np.arange(half_dim) * - log_step)
        self.emb = Tensor(frequencies, ms.float32)

    def construct(self, x):
        # Outer product: one row per time step, one column per frequency.
        angles = x[:, None] * self.emb[None, :]
        return ops.concat((ops.sin(angles), ops.cos(angles)), axis=-1)
  • ResNet/ConvNeXT块

原作者使用 Wide ResNet 块,Phil Wang决定添加 ConvNeXT 替换 ResNet block.

class Block(nn.Cell):
    """Convolution -> GroupNorm -> optional scale/shift -> SiLU.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        groups: Number of groups passed to ``nn.GroupNorm``.
    """

    def __init__(self, dim, dim_out, groups=1):
        super().__init__()
        # A second assignment `self.proj = c(...)` used to follow this line; `c` is
        # not defined anywhere, so instantiating the block raised NameError. The
        # removed call had the same arguments as this convolution.
        self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def construct(self, x, scale_shift=None):
        """Apply the block to ``x``.

        Args:
            x: Input tensor.
            scale_shift: Optional ``(scale, shift)`` pair; when given, the
                normalised activations become ``x * (scale + 1) + shift``.

        Returns:
            The activated output tensor.
        """
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ConvNextBlock(nn.Cell):
    """ConvNeXT-style residual block with optional time-embedding conditioning.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels.
        time_emb_dim: Size of the time embedding; when ``None`` no conditioning
            MLP is created.
        mult: Channel expansion factor of the hidden convolution.
        norm: Whether the first layer of ``net`` is a GroupNorm (else Identity).
    """
    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        # Projects the time embedding to `dim` channels so it can be added to the features.
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        # Depthwise 7x7 convolution (group == dim).
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )

        # 1x1 convolution on the skip path only when the channel count changes.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        """Return ``net(ds_conv(x) [+ time condition]) + res_conv(x)``."""
        h = self.ds_conv(x)
        # Conditioning is applied only when both the MLP and an embedding are present;
        # otherwise it is silently skipped.
        if exists(self.mlp) and exists(time_emb):
            # NOTE(review): this assert can never fail, the enclosing `if` already
            # guarantees that time_emb exists.
            assert exists(time_emb), "time embedding must be passed in"
            condition = self.mlp(time_emb)
            # Append two trailing singleton axes so the condition broadcasts over h.
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition

        h = self.net(h)
        return h + self.res_conv(x)

  • Attention 模块:
    DDPM作者将其添加到卷积块之间。Attention是著名的Transformer架构(Vaswani et al., 2017)的核心,在人工智能的各个领域都取得了巨大的成功,从NLP到蛋白质折叠。Phil Wang使用了两种注意力变体:一种是常规的multi-head self-attention(如Transformer中使用的),另一种是LinearAttention(Shen et al., 2018),其时间和内存要求随序列长度线性缩放,而不是像常规注意力那样呈二次方缩放。要想对Attention机制进行深入的了解,可以参考Transformer原论文(Vaswani et al., 2017)。
class Attention(nn.Cell):
    """Multi-head self-attention over the flattened spatial positions of a feature map.

    Args:
        dim: Number of input (and output) channels.
        heads: Number of attention heads.
        dim_head: Channels per head.
    """
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # A single 1x1 convolution produces q, k and v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        """Apply attention to ``x`` of shape ``(b, dim, h, w)``; the result is reshaped back to ``h`` x ``w``."""
        b, _, h, w = x.shape
        # Split the channel axis into three equal chunks: q, k, v.
        qkv = self.to_qkv(x).chunk(3, 1)
        # rearrange() turns each chunk into (b, heads, c, h * w).
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        q = q * self.scale

        # 'b h d i, b h d j -> b h i j'
        sim = ops.bmm(q.swapaxes(2, 3), k)
        attn = ops.softmax(sim, axis=-1)
        # 'b h i j, b h d j -> b h i d'
        out = ops.bmm(attn, v.swapaxes(2, 3))
        # Back to channel-first layout, merging heads into the channel axis.
        out = out.swapaxes(-1, -2).reshape((b, -1, h, w))

        return self.to_out(out)


class LayerNorm(nn.Cell):
    """Normalise over axis 1 with a learnable per-channel gain ``g`` (no bias)."""

    def __init__(self, dim):
        super().__init__()
        self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')

    def construct(self, x):
        """Return ``(x - mean) / sqrt(var + eps) * g`` with statistics taken over axis 1."""
        eps = 1e-5
        variance = x.var(1, keepdims=True)
        centered = x - x.mean(1, keep_dims=True)
        return centered * rsqrt((variance + eps)) * self.g


class LinearAttention(nn.Cell):
    """Linear-attention variant: a per-head (c x c) context is built from k and v before q is applied.

    Args:
        dim: Number of input (and output) channels.
        heads: Number of attention heads.
        dim_head: Channels per head.
    """
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # A single 1x1 convolution produces q, k and v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)

        # Output projection followed by the channel-wise LayerNorm defined in this file.
        self.to_out = nn.SequentialCell(
            nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
            LayerNorm(dim)
        )

        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        """Apply linear attention to ``x`` of shape ``(b, dim, h, w)``."""
        b, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, 1)
        # rearrange() turns each chunk into (b, heads, c, h * w).
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        # q is normalised over the per-head channel axis, k over the position axis.
        q = ops.softmax(q, -2)
        k = ops.softmax(k, -1)

        q = q * self.scale
        v = v / (h * w)

        # 'b h d n, b h e n -> b h d e'
        context = ops.bmm(k, v.swapaxes(2, 3))
        # 'b h d e, b h d n -> b h e n'
        out = ops.bmm(context.swapaxes(2, 3), q)

        # Merge heads back into the channel axis and restore the spatial layout.
        out = out.reshape((b, -1, h, w))
        return self.to_out(out)
  • 组归一化
    DDPM 作者将 U-Net 的卷积/注意层与群归一化(Wu et al., 2018)交织在一起。下面,我们定义一个 PreNorm 类,用于在注意层之前应用 groupnorm。
class PreNorm(nn.Cell):
    """Apply a single-group ``nn.GroupNorm`` to the input before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def construct(self, x):
        """Return ``fn(norm(x))``."""
        normalized = self.norm(x)
        return self.fn(normalized)
  • 条件 U-Net:

网络构建过程如下:
a. 首先,将卷积层应用于一批噪声图像,并计算噪声水平(时间步)的位置嵌入
b. 接下来,应用一系列下采样级。每个下采样阶段由2个ResNet/ConvNeXT块 + groupnorm + attention + 残差连接 + 一个下采样操作组成
c. 在网络的中间,再次应用ResNet或ConvNeXT块,并与attention交织
d. 应用一系列上采样级。每个上采样级由2个ResNet/ConvNeXT块+ groupnorm + attention + 残差连接 + 一个上采样操作组成
e. 应用ResNet/ConvNeXT块,然后应用卷积层

class Unet(nn.Cell):
    """Conditional U-Net built from ConvNextBlock stages with attention and skip connections.

    Args:
        dim: Base channel width; stage widths are ``dim * m`` for ``m`` in ``dim_mults``.
        init_dim: Channels produced by the first convolution; defaults to ``dim // 3 * 2``.
        out_dim: Channels of the final output; defaults to ``channels``.
        dim_mults: Width multipliers, one per resolution stage.
        channels: Number of channels of the input image.
        with_time_emb: Whether to build the time-embedding MLP.
        convnext_mult: ``mult`` argument forwarded to every ConvNextBlock.
    """
    def __init__(
            self,
            dim,
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8),
            channels=3,
            with_time_emb=True,
            convnext_mult=2,
    ):
        super().__init__()

        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)

        # Channel widths per stage and the (in, out) pairs between consecutive stages.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ConvNextBlock, mult=convnext_mult)

        if with_time_emb:
            time_dim = dim * 4
            # Sinusoidal encoding of the time step followed by a two-layer MLP.
            self.time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)

        # Encoder: two blocks + residual linear attention per stage; every stage
        # except the last one ends with a Downsample.
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.CellList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        # Bottleneck: block -> full attention -> block at the widest channel count.
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        # Decoder: iterates over num_resolutions - 1 pairs, so `ind` never reaches
        # num_resolutions - 1 and `is_last` is always False here; every decoder
        # stage therefore ends with an Upsample.
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.CellList(
                    [
                        # Input is doubled because the skip connection is concatenated on the channel axis.
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        # The final block is built without time_emb_dim, so it is not time-conditioned.
        self.final_conv = nn.SequentialCell(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def construct(self, x, time):
        """Run the network on the image batch ``x`` for the time steps ``time``.

        Returns the output of ``final_conv``, which has ``out_dim`` channels.
        """
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # Encoder activations kept for the skip connections.
        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Consume the stored activations from the most recent one backwards.
        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[len_h]), 1)
            len_h -= 1
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)
        return self.final_conv(x)

3.3 正向扩散

正向扩散

def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return a linearly spaced beta (noise variance) schedule.

    Args:
        timesteps: Number of diffusion steps, i.e. the length of the schedule.
        beta_start: First value of the schedule.
        beta_end: Last value of the schedule (inclusive).

    Returns:
        A 1-D ``np.float32`` array of length ``timesteps`` running from
        ``beta_start`` to ``beta_end``.
    """
    return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)

# Use 200 diffusion steps
timesteps = 200

# Define the beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# Define the alphas and their cumulative products
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
# Cumulative product shifted right by one step, with 1 prepended for the first step.
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)

sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))

# Variance of the posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# The exponent is -0., so every weight evaluates to 1 (the loss is effectively unweighted).
p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
p2_loss_weight = Tensor(p2_loss_weight)

def extract(a, t, x_shape):
    """Look up per-sample schedule values and make them broadcastable.

    Args:
        a: 1-D array of schedule values; it is wrapped in a ``Tensor``.
        t: Tensor of indices, one per batch element.
        x_shape: Shape of the tensor the result will be combined with.

    Returns:
        The gathered values reshaped to ``(b, 1, ..., 1)`` with
        ``len(x_shape)`` dimensions, where ``b = t.shape[0]``.
    """
    batch = t.shape[0]
    gathered = Tensor(a).gather(t, -1)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing_ones)
  • 正向扩散的过程,并在特定时间步长上测试它
def q_sample(x_start, t, noise=None):
    """Sample a noised image from the forward process q(x_t | x_0) in closed form.

    This function was called by ``get_noisy_image`` and ``p_losses`` but was
    missing from the file. It computes
    ``sqrt(alphas_cumprod_t) * x_start + sqrt(1 - alphas_cumprod_t) * noise``.

    Args:
        x_start: Batch of clean images.
        t: Tensor of time-step indices, one per batch element.
        noise: Optional noise tensor shaped like ``x_start``; standard-normal
            noise is drawn when omitted.

    Returns:
        The noised batch, with the same shape as ``x_start``.
    """
    if noise is None:
        noise = randn_like(x_start)
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise


def get_noisy_image(x_start, t):
    """Noise ``x_start`` at time step ``t`` and return the first image of the batch as a PIL image."""
    # Add noise
    x_noisy = q_sample(x_start, t=t)

    # Convert to a PIL image
    noisy_image = compose(reverse_transform, x_noisy[0])

    return noisy_image


# Pick the time step to visualise
t = Tensor([40])
noisy_image = get_noisy_image(x_start, t)
print(noisy_image)
noisy_image.show()
  • 可视化
import matplotlib.pyplot as plt

def plot(imgs, with_orig=False, row_title=None, orig_img=None, **imshow_kwargs):
    """Show a grid of images with matplotlib.

    Args:
        imgs: A list of images (one row) or a list of lists (one list per row).
        with_orig: When true, ``orig_img`` is prepended to every row and the
            first column is titled 'Original image'.
        row_title: Optional sequence of y-axis labels, one per row.
        orig_img: The original image shown in the first column; required when
            ``with_orig`` is true. (The previous code read an undefined global
            named ``image`` here.)
        **imshow_kwargs: Forwarded to ``Axes.imshow``.

    Raises:
        ValueError: If ``with_orig`` is true but ``orig_img`` is not given.
    """
    if with_orig and orig_img is None:
        raise ValueError("with_orig=True requires the original image to be passed as orig_img")

    # Normalise a flat list of images into a single-row grid.
    if not isinstance(imgs[0], list):
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    # NOTE(review): figsize is in inches, so (200, 200) is a very large canvas; confirm this is intended.
    _, axs = plt.subplots(figsize=(200, 200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
 
# Visualise the forward process at five time steps, from the first (0) to the last (199).
plot([get_noisy_image(x_start, Tensor([t])) for t in [0, 50, 100, 150, 199]])
  • 定义模型的损失函数
def p_losses(unet_model, x_start, t, noise=None):
    """Compute the noise-prediction training loss for one batch.

    Args:
        unet_model: Callable ``(x_noisy, t) -> predicted_noise``.
        x_start: Batch of clean images.
        t: Tensor of time-step indices, one per batch element.
        noise: Optional noise to add; standard-normal noise is drawn when omitted.

    Returns:
        The mean of the per-element Smooth-L1 loss weighted by ``p2_loss_weight``.
    """
    if noise is None:
        noise = randn_like(x_start)
    noisy = q_sample(x_start=x_start, t=t, noise=noise)
    prediction = unet_model(noisy, t)

    # NOTE(review): the reshape below relies on SmoothL1Loss returning an
    # unreduced, per-element loss; confirm against the default `reduction`.
    criterion = nn.SmoothL1Loss()
    per_element = criterion(noise, prediction)
    per_sample = per_element.reshape(per_element.shape[0], -1)
    weighted = per_sample * extract(p2_loss_weight, t, per_sample.shape)
    return weighted.mean()

3.4 训练

  • 下载数据并转换
# Download the dataset archive (read below with FashionMnistDataset)
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip'
path = download(url, './', kind="zip", replace=True)

from mindspore.dataset import FashionMnistDataset

image_size = 28
channels = 1
batch_size = 16

fashion_mnist_dataset_dir = "./dataset"
dataset = FashionMnistDataset(dataset_dir=fashion_mnist_dataset_dir, usage="train", num_parallel_workers=cpu_count(), shuffle=True, num_shards=1, shard_id=0)

# Random flip, conversion to tensor, then linear rescale to [-1, 1]
transforms = [
    RandomHorizontalFlip(),
    ToTensor(),
    lambda t: (t * 2) - 1
]


# Keep only the 'image' column; labels are not used.
dataset = dataset.project('image')
dataset = dataset.shuffle(64)
dataset = dataset.map(transforms, 'image')
# Use the batch_size defined above instead of repeating the literal 16.
dataset = dataset.batch(batch_size, drop_remainder=True)

# Peek at one batch to confirm the available columns.
x = next(dataset.create_dict_iterator())
print(x.keys())
  • 采样:
    训练期间从模型中采样(以便跟踪进度)
def p_sample(model, x, t, t_index):
    """Perform one reverse-diffusion step.

    Args:
        model: Callable ``(x, t) -> predicted_noise``.
        x: Current noisy batch.
        t: Tensor of time-step indices, one per batch element.
        t_index: Python integer time step; when it is 0 no noise is added.

    Returns:
        The model mean when ``t_index == 0``; otherwise the mean plus
        ``sqrt(posterior_variance_t)`` times fresh standard-normal noise.
    """
    beta_t = extract(betas, t, x.shape)
    sqrt_one_minus_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alpha_t = extract(sqrt_recip_alphas, t, x.shape)

    predicted_noise = model(x, t)
    mean = sqrt_recip_alpha_t * (x - beta_t * predicted_noise / sqrt_one_minus_cumprod_t)

    if t_index != 0:
        variance_t = extract(posterior_variance, t, x.shape)
        fresh_noise = randn_like(x)
        return mean + ops.sqrt(variance_t) * fresh_noise
    return mean

def p_sample_loop(model, shape):
    """Run the full reverse process from ``timesteps - 1`` down to 0.

    Args:
        model: Callable ``(x, t) -> predicted_noise``.
        shape: Shape of the batch to generate; ``shape[0]`` is the batch size.

    Returns:
        A list with one NumPy array per time step, in the order they were
        produced (the last entry is the final sample).
    """
    batch = shape[0]
    # Start from pure noise
    current = randn(shape, dtype=None)
    history = []

    steps = reversed(range(0, timesteps))
    for step in tqdm(steps, desc='sampling loop time step', total=timesteps):
        step_tensor = ms.numpy.full((batch,), step, dtype=mstype.int32)
        current = p_sample(model, current, step_tensor, step)
        history.append(current.asnumpy())
    return history

def sample(model, image_size, batch_size=16, channels=3):
    """Generate ``batch_size`` square images of side ``image_size`` via ``p_sample_loop``."""
    target_shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=target_shape)

  • 训练
# Define the dynamic (cosine-decay) learning rate
# NOTE(review): 3750 steps per epoch presumably equals 60000 training images / batch size 16,
# and 10 matches `epochs` in the training loop; confirm if either value changes.
lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)

# Define the U-Net model
unet_model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)

# Collect the names reported by parameters_and_names() and assign them, by position,
# to the trainable parameters.
name_list = []
for (name, par) in list(unet_model.parameters_and_names()):
    name_list.append(name)
i = 0
for item in list(unet_model.trainable_params()):
    item.name = name_list[i]
    i += 1

# Define the optimizer
optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
# NOTE(review): loss_scaler is created but not used anywhere in the visible code.
loss_scaler = DynamicLossScaler(65536, 2, 1000)

# Define the forward pass
def forward_fn(data, t, noise=None):
    """Return the diffusion loss of ``unet_model`` on ``data`` at time steps ``t``."""
    return p_losses(unet_model, data, t, noise)

# Build a function that returns the loss and its gradients w.r.t. the optimizer's parameters
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

# Gradient update
def train_step(data, t, noise):
    """Run one optimisation step and return the loss value."""
    loss_value, gradients = grad_fn(data, t, noise)
    optimizer(gradients)
    return loss_value

import time

epochs = 10
for epoch in range(epochs):
    begin_time = time.time()
    for step, batch in enumerate(dataset.create_tuple_iterator()):
        unet_model.set_train()
        # NOTE: this rebinds the module-level batch_size to the size of the current batch.
        batch_size = batch[0].shape[0]
        # Draw one random time step and one noise tensor per sample in the batch.
        t = randint(0, timesteps, (batch_size,), dtype=ms.int32)
        noise = randn_like(batch[0])
        loss = train_step(batch[0], t, noise)

        # Log the loss every 500 steps.
        if step % 500 == 0:
            print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)
    end_time = time.time()
    times = end_time - begin_time
    print("training time:", times, "s")
    # Show a generated sample after each epoch to track progress
    unet_model.set_train(False)
    samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
    # samples[-1] is the final denoising step; index 5 picks one image of the batch.
    plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
print("Training Success!")

3.5 推理

要从模型中采样,我们可以只使用上面定义的采样函数

# Sample 64 images
unet_model.set_train(False)
samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)

# Show one of the generated images (final denoising step)
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

# Build an animation of the denoising process
import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
# One frame per time step, in the order the samples were produced.
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=100)
animate.save('diffusion.gif')
plt.show()

4. 相关链接:

  • 18
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值