There are many ways to understand diffusion models. This tutorial is based on the denoising diffusion probabilistic model (DDPM), which has achieved remarkable results in (un)conditional image, audio, and video generation. Popular examples include GLIDE and DALL-E 2 by OpenAI, Latent Diffusion by the University of Heidelberg, and Imagen by Google Brain.
In fact, the idea of diffusion for generative modeling was already introduced in (Sohl-Dickstein et al., 2015). However, it was not until (Song et al., 2019) (at Stanford University) and (Ho et al., 2020) (at Google Brain) independently improved the approach that it took off.
This tutorial ports Phil Wang's PyTorch implementation (which itself is based on the original TensorFlow implementation) to the MindSpore AI framework.
Importing Dependencies
import math
from functools import partial
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from multiprocessing import cpu_count
from download import download
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.dataset.vision import Resize, Inter, CenterCrop, ToTensor, RandomHorizontalFlip, ToPIL
from mindspore.common.initializer import initializer
from mindspore.amp import DynamicLossScaler
ms.set_seed(0)
Model Overview
What is a Diffusion Model?
A diffusion model is not that complicated when compared with other generative models such as Normalizing Flows, GANs, or VAEs: like them, it transforms noise from a simple distribution into a data sample. A diffusion model starts from pure noise and uses a neural network to learn to denoise it step by step, until an actual image is obtained. Diffusion treats images with the following two processes:
- a fixed (or predefined) forward diffusion process 𝑞: it gradually adds Gaussian noise to an image, until we end up with pure noise
- a learned reverse denoising diffusion process 𝑝𝜃: a neural network is trained to gradually denoise an image starting from pure noise, until an actual image is obtained
Both the forward and reverse processes, indexed by 𝑡, take place over some finite number of time steps 𝑇 (the DDPM authors use 𝑇=1000). We start at 𝑡=0 by sampling a real image 𝐱0 from the data distribution (this tutorial uses a cat image from ImageNet to illustrate the forward noising process). At each time step 𝑡, the forward process samples some noise from a Gaussian distribution and adds it to the image from the previous step. Given a sufficiently large 𝑇 and a well-behaved schedule for adding noise at each step, this gradual process ends up at 𝑡=𝑇 with what is called an isotropic Gaussian distribution.
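For reference, a single forward step can be written as 𝑞(𝐱𝑡|𝐱𝑡−1) = 𝒩(𝐱𝑡; √(1−𝛽𝑡) 𝐱𝑡−1, 𝛽𝑡𝐈), where 𝛽𝑡 is the scheduled variance added at step 𝑡. This is the standard DDPM formulation and is stated here only as background.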
Building the Diffusion Model
Below, we build the diffusion model step by step.
First, we define some helper functions and classes that will be used when implementing the neural network.
def rearrange(head, inputs):
b, hc, x, y = inputs.shape
c = hc // head
return inputs.reshape((b, head, c, x * y))
def rsqrt(x):
res = ops.sqrt(x)
return ops.inv(res)
def randn_like(x, dtype=None):
if dtype is None:
dtype = x.dtype
res = ops.standard_normal(x.shape).astype(dtype)
return res
def randn(shape, dtype=None):
if dtype is None:
dtype = ms.float32
res = ops.standard_normal(shape).astype(dtype)
return res
def randint(low, high, size, dtype=ms.int32):
res = ops.uniform(size, Tensor(low, dtype), Tensor(high, dtype), dtype=dtype)
return res
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def _check_dtype(d1, d2):
if ms.float32 in (d1, d2):
return ms.float32
if d1 == d2:
return d1
raise ValueError('dtype is not supported.')
class Residual(nn.Cell):
def __init__(self, fn):
super().__init__()
self.fn = fn
def construct(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
def Upsample(dim):
return nn.Conv2dTranspose(dim, dim, 4, 2, pad_mode="pad", padding=1)
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, pad_mode="pad", padding=1)
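As a quick illustration (a minimal sketch, not part of the original tutorial), Downsample halves and Upsample doubles the spatial resolution while keeping the channel count:
x = randn((1, 8, 16, 16))
print(Downsample(8)(x).shape)  # (1, 8, 8, 8)
print(Upsample(8)(x).shape)    # (1, 8, 32, 32)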
Position embeddings
Since the parameters of the neural network are shared across time (noise level), the authors employ sinusoidal position embeddings to encode 𝑡, inspired by the Transformer (Vaswani et al., 2017). This makes the neural network "know" at which particular time step (noise level) it is operating, for every image in a batch.
The SinusoidalPositionEmbeddings module takes a tensor of shape (batch_size, 1) as input (i.e. the noise levels of several noisy images in a batch) and turns it into a tensor of shape (batch_size, dim), where dim is the dimensionality of the position embeddings. This is then added to each residual block.
class SinusoidalPositionEmbeddings(nn.Cell):
def __init__(self, dim):
super().__init__()
self.dim = dim
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = np.exp(np.arange(half_dim) * - emb)
self.emb = Tensor(emb, ms.float32)
def construct(self, x):
emb = x[:, None] * self.emb[None, :]
emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
return emb
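As a quick sanity check (a minimal sketch, assuming only the imports above), embedding three time steps with dim=32 yields a (3, 32) tensor:
pos_emb = SinusoidalPositionEmbeddings(32)
t = Tensor(np.array([1., 10., 100.]), ms.float32)
print(pos_emb(t).shape)  # (3, 32)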
ResNet/ConvNeXT blocks
Next, we define the core building block of the U-Net model. The DDPM authors employed a Wide ResNet block (Zagoruyko et al., 2016), but Phil Wang decided to also support a ConvNeXT block (Liu et al., 2022) as a replacement, given the latter's great success in the image domain.
One can choose either block in the final U-Net architecture; this tutorial uses the ConvNeXT block to build the U-Net.
class Block(nn.Cell):
def __init__(self, dim, dim_out, groups=1):
super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def construct(self, x, scale_shift=None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ConvNextBlock(nn.Cell):
def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
super().__init__()
self.mlp = (
nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
if exists(time_emb_dim)
else None
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
self.net = nn.SequentialCell(
nn.GroupNorm(1, dim) if norm else nn.Identity(),
nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
nn.GELU(),
nn.GroupNorm(1, dim_out * mult),
nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def construct(self, x, time_emb=None):
h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
condition = self.mlp(time_emb)
condition = condition.expand_dims(-1).expand_dims(-1)
h = h + condition
h = self.net(h)
return h + self.res_conv(x)
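A quick shape check for the block (a minimal sketch; the dimensions here are arbitrary): with a time embedding of size 24, a (2, 8, 16, 16) input is mapped to (2, 16, 16, 16):
block = ConvNextBlock(8, 16, time_emb_dim=24)
x = randn((2, 8, 16, 16))
t_emb = randn((2, 24))
print(block(x, t_emb).shape)  # (2, 16, 16, 16)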
Attention module
Next, we define the attention module, which the DDPM authors added in between the convolutional blocks. Attention is the building block of the famous Transformer architecture (Vaswani et al., 2017), which has shown great success in various domains of AI, from NLP to protein folding. Phil Wang employs two variants of attention: one is regular multi-head self-attention (as used in the Transformer), the other is linear attention (Shen et al., 2018), whose time and memory requirements scale linearly in the sequence length, as opposed to quadratically for regular attention. For an in-depth explanation of the attention mechanism, please refer to Jay Alammar's wonderful blog post.
class Attention(nn.Cell):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
self.map = ops.Map()
self.partial = ops.Partial()
def construct(self, x):
b, _, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, 1)
q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
q = q * self.scale
# 'b h d i, b h d j -> b h i j'
sim = ops.bmm(q.swapaxes(2, 3), k)
attn = ops.softmax(sim, axis=-1)
# 'b h i j, b h d j -> b h i d'
out = ops.bmm(attn, v.swapaxes(2, 3))
out = out.swapaxes(-1, -2).reshape((b, -1, h, w))
return self.to_out(out)
class LayerNorm(nn.Cell):
def __init__(self, dim):
super().__init__()
self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')
def construct(self, x):
eps = 1e-5
var = x.var(1, keepdims=True)
mean = x.mean(1, keep_dims=True)
return (x - mean) * rsqrt((var + eps)) * self.g
class LinearAttention(nn.Cell):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
self.to_out = nn.SequentialCell(
nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
LayerNorm(dim)
)
self.map = ops.Map()
self.partial = ops.Partial()
def construct(self, x):
b, _, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, 1)
q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
q = ops.softmax(q, -2)
k = ops.softmax(k, -1)
q = q * self.scale
v = v / (h * w)
# 'b h d n, b h e n -> b h d e'
context = ops.bmm(k, v.swapaxes(2, 3))
# 'b h d e, b h d n -> b h e n'
out = ops.bmm(context.swapaxes(2, 3), q)
out = out.reshape((b, -1, h, w))
return self.to_out(out)
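Both attention variants preserve the input shape, so they can be dropped in between convolutional blocks (a minimal sketch with arbitrary dimensions):
x = randn((1, 32, 8, 8))
print(Attention(32)(x).shape)        # (1, 32, 8, 8)
print(LinearAttention(32)(x).shape)  # (1, 32, 8, 8)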
Group normalization
The DDPM authors interleave the convolutional/attention layers of the U-Net with group normalization (Wu et al., 2018). Below, we define a PreNorm class, which will be used to apply groupnorm before the attention layer.
class PreNorm(nn.Cell):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.GroupNorm(1, dim)
def construct(self, x):
x = self.norm(x)
return self.fn(x)
Conditional U-Net
Now that we have defined all building blocks (position embeddings, ResNet/ConvNeXT blocks, attention, and group normalization), it is time to define the entire neural network. Recall that the job of the network 𝜖𝜃(𝐱𝑡, 𝑡) is to take in a batch of noisy images and their noise levels, and to output the noise that was added to the input.
More concretely: the network takes a batch of noisy images of shape (batch_size, num_channels, height, width) and a batch of noise levels of shape (batch_size, 1) as input, and returns a tensor of shape (batch_size, num_channels, height, width).
The network is built up as follows:
- first, a convolutional layer is applied to the batch of noisy images, and position embeddings are computed for the noise levels
- next, a sequence of downsampling stages is applied; each downsampling stage consists of 2 ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + a downsample operation
- at the middle of the network, ResNet or ConvNeXT blocks are applied again, interleaved with attention
- next, a sequence of upsampling stages is applied; each upsampling stage consists of 2 ResNet/ConvNeXT blocks + groupnorm + attention + residual connection + an upsample operation
- finally, a ResNet/ConvNeXT block is applied, followed by a convolutional layer
Ultimately, the neural network stacks up layers as if they were Lego blocks (but it is important to understand how they work).
class Unet(nn.Cell):
def __init__(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
with_time_emb=True,
convnext_mult=2,
):
super().__init__()
self.channels = channels
init_dim = default(init_dim, dim // 3 * 2)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ConvNextBlock, mult=convnext_mult)
if with_time_emb:
time_dim = dim * 4
self.time_mlp = nn.SequentialCell(
SinusoidalPositionEmbeddings(dim),
nn.Dense(dim, time_dim),
nn.GELU(),
nn.Dense(time_dim, time_dim),
)
else:
time_dim = None
self.time_mlp = None
self.downs = nn.CellList([])
self.ups = nn.CellList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
nn.CellList(
[
block_klass(dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else nn.Identity(),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(
nn.CellList(
[
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Upsample(dim_in) if not is_last else nn.Identity(),
]
)
)
out_dim = default(out_dim, channels)
self.final_conv = nn.SequentialCell(
block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
)
def construct(self, x, time):
x = self.init_conv(x)
t = self.time_mlp(time) if exists(self.time_mlp) else None
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
len_h = len(h) - 1
for block1, block2, attn, upsample in self.ups:
x = ops.concat((x, h[len_h]), 1)
len_h -= 1
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)
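As a sanity check (a minimal sketch, mirroring the Fashion-MNIST settings used later in this tutorial), a forward pass maps a batch of noisy images and their time steps back to an image-shaped tensor:
model = Unet(dim=28, channels=1, dim_mults=(1, 2, 4))
x = randn((2, 1, 28, 28))
t = randint(0, 200, (2,))
print(model(x, t).shape)  # (2, 1, 28, 28)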
Forward diffusion
We already know that the forward diffusion process gradually adds noise to an image from the real distribution over a number of time steps 𝑇, according to a variance schedule. The original DDPM authors employed a linear schedule:
- the forward process variances are set to constants increasing linearly from 𝛽1 = 10⁻⁴ to 𝛽𝑇 = 0.02.
- however, it was shown in (Nichol et al., 2021) that better results can be obtained when employing a cosine schedule (see the hedged sketch after the linear schedule below).
Below, we define the schedule for 𝑇 time steps.
def linear_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)
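For reference, a cosine schedule in the spirit of (Nichol et al., 2021) could look as follows. This is a hedged sketch following the commonly used open-source formulation; it is not used in the rest of this tutorial:
def cosine_beta_schedule(timesteps, s=0.008):
    # compute the cumulative product of alphas with a squared-cosine curve,
    # then derive the betas from consecutive ratios and clip them
    steps = timesteps + 1
    x = np.linspace(0, timesteps, steps)
    alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0.0001, 0.9999).astype(np.float32)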
# use 200 diffusion time steps
timesteps = 200
# define the beta schedule
betas = linear_beta_schedule(timesteps=timesteps)
# define alphas
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)
sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))
# variance of the posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# P2 loss weights (Choi et al., 2022); with the exponent -0. (i.e. gamma = 0),
# all weights equal 1, so no reweighting is applied here
p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
p2_loss_weight = Tensor(p2_loss_weight)
def extract(a, t, x_shape):
b = t.shape[0]
out = Tensor(a).gather(t, -1)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
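extract gathers the schedule coefficient for each sample's time step and reshapes it for broadcasting against a batch of images. A minimal usage sketch:
coeffs = extract(sqrt_alphas_cumprod, Tensor([0, 199], ms.int32), (2, 1, 28, 28))
print(coeffs.shape)  # (2, 1, 1, 1)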
We will use a cat image to illustrate how noise is added at each time step of the diffusion process.
# download the cat image
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip'
path = download(url, './', kind="zip", replace=True)
from PIL import Image
image = Image.open('./image_cat/jpg/000000039769.jpg')
base_width = 160
image = image.resize((base_width, int(float(image.size[1]) * float(base_width / float(image.size[0])))))
image.show()
from mindspore.dataset import ImageFolderDataset
image_size = 128
transforms = [
Resize(image_size, Inter.BILINEAR),
CenterCrop(image_size),
ToTensor(),
lambda t: (t * 2) - 1
]
path = './image_cat'
dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(),
extensions=['.jpg', '.jpeg', '.png', '.tiff'],
num_shards=1, shard_id=0, shuffle=False, decode=True)
dataset = dataset.project('image')
transforms.insert(1, RandomHorizontalFlip())
dataset_1 = dataset.map(transforms, 'image')
dataset_2 = dataset_1.batch(1, drop_remainder=True)
x_start = next(dataset_2.create_tuple_iterator())[0]
print(x_start.shape)
reverse_transform = [
lambda t: (t + 1) / 2,
lambda t: ops.permute(t, (1, 2, 0)), # CHW to HWC
lambda t: t * 255.,
lambda t: t.asnumpy().astype(np.uint8),
ToPIL()
]
def compose(transform, x):
for d in transform:
x = d(x)
return x
reverse_image = compose(reverse_transform, x_start[0])
reverse_image.show()
def q_sample(x_start, t, noise=None):
if noise is None:
noise = randn_like(x_start)
return (extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
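q_sample implements the closed-form property of the forward process, 𝐱𝑡 = √(𝛼̄𝑡) 𝐱0 + √(1−𝛼̄𝑡) 𝜖 with 𝜖 ∼ 𝒩(0, 𝐈), where 𝛼̄𝑡 is the cumulative product of 𝛼𝑡 = 1−𝛽𝑡 (alphas_cumprod above). This lets us noise an image to an arbitrary time step directly, instead of iterating through all previous steps.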
def get_noisy_image(x_start, t):
    # add noise
x_noisy = q_sample(x_start, t=t)
    # convert to a PIL image
noisy_image = compose(reverse_transform, x_noisy[0])
return noisy_image
# set the time step
t = Tensor([40])
noisy_image = get_noisy_image(x_start, t)
print(noisy_image)
noisy_image.show()
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list):
imgs = [imgs]
num_rows = len(imgs)
num_cols = len(imgs[0]) + with_orig
_, axs = plt.subplots(figsize=(200, 200), nrows=num_rows, ncols=num_cols, squeeze=False)
for row_idx, row in enumerate(imgs):
row = [image] + row if with_orig else row
for col_idx, img in enumerate(row):
ax = axs[row_idx, col_idx]
ax.imshow(np.asarray(img), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if with_orig:
axs[0, 0].set(title='Original image')
axs[0, 0].title.set_size(8)
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
plot([get_noisy_image(x_start, Tensor([t])) for t in [0, 50, 100, 150, 199]])
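Next we define the loss: the network is trained to predict the noise that was added to an image. We noise an image with q_sample, let the U-Net predict that noise, and compare the two with a smooth L1 loss, weighted by the p2_loss_weight terms defined earlier.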
def p_losses(unet_model, x_start, t, noise=None):
if noise is None:
noise = randn_like(x_start)
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
predicted_noise = unet_model(x_noisy, t)
    loss = nn.SmoothL1Loss()(noise, predicted_noise)
loss = loss.reshape(loss.shape[0], -1)
loss = loss * extract(p2_loss_weight, t, loss.shape)
return loss.mean()
Data Preparation and Processing
Here we define a regular dataset. The dataset can consist of images from a simple real dataset, such as Fashion-MNIST, CIFAR-10, or ImageNet, scaled linearly to [−1, 1].
Each image is resized to the same size. Interestingly, images are also randomly flipped horizontally. As stated in the paper: we used random horizontal flips during training for CIFAR10; we tried training both with and without flips, and found flips to improve sample quality slightly.
For this experiment we use the Fashion-MNIST dataset, which we download and extract to a given path using download. This dataset consists of images that already have the same resolution, namely 28x28.
# download the Fashion-MNIST dataset
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip'
path = download(url, './', kind="zip", replace=True)
from mindspore.dataset import FashionMnistDataset
image_size = 28
channels = 1
batch_size = 16
fashion_mnist_dataset_dir = "./dataset"
dataset = FashionMnistDataset(dataset_dir=fashion_mnist_dataset_dir, usage="train", num_parallel_workers=cpu_count(), shuffle=True, num_shards=1, shard_id=0)
transforms = [
RandomHorizontalFlip(),
ToTensor(),
lambda t: (t * 2) - 1
]
dataset = dataset.project('image')
dataset = dataset.shuffle(64)
dataset = dataset.map(transforms, 'image')
dataset = dataset.batch(batch_size, drop_remainder=True)
x = next(dataset.create_dict_iterator())
print(x.keys())
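Sampling
Since we will sample from the model during and after training (to track progress), we define the sampling code below. Sampling reverses the diffusion process: we start from pure noise 𝐱𝑇 and use the network to gradually denoise it, step by step, until we reach 𝑡 = 0.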
def p_sample(model, x, t, t_index):
betas_t = extract(betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
sqrt_one_minus_alphas_cumprod, t, x.shape
)
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
if t_index == 0:
return model_mean
posterior_variance_t = extract(posterior_variance, t, x.shape)
noise = randn_like(x)
return model_mean + ops.sqrt(posterior_variance_t) * noise
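p_sample implements one reverse step of DDPM sampling (Algorithm 2 in Ho et al., 2020): the model mean is 1/√(𝛼𝑡) (𝐱𝑡 − 𝛽𝑡/√(1−𝛼̄𝑡) 𝜖𝜃(𝐱𝑡, 𝑡)), and, for 𝑡 > 0, Gaussian noise scaled by √(posterior_variance) is added back.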
def p_sample_loop(model, shape):
b = shape[0]
    # start from pure noise
img = randn(shape, dtype=None)
imgs = []
for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
img = p_sample(model, img, ms.numpy.full((b,), i, dtype=mstype.int32), i)
imgs.append(img.asnumpy())
return imgs
def sample(model, image_size, batch_size=16, channels=3):
return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
# define a dynamic learning rate
lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)
# define the U-Net model
unet_model = Unet(
dim=image_size,
channels=channels,
dim_mults=(1, 2, 4,)
)
# assign each trainable parameter its fully-qualified name from parameters_and_names
name_list = []
for (name, par) in list(unet_model.parameters_and_names()):
name_list.append(name)
i = 0
for item in list(unet_model.trainable_params()):
item.name = name_list[i]
i += 1
# define the optimizer
optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
# dynamic loss scaler for mixed precision (not used in the simplified train_step below)
loss_scaler = DynamicLossScaler(65536, 2, 1000)
# define the forward (loss) computation
def forward_fn(data, t, noise=None):
loss = p_losses(unet_model, data, t, noise)
return loss
# compute gradients
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)
# update parameters with the gradients
def train_step(data, t, noise):
loss, grads = grad_fn(data, t, noise)
optimizer(grads)
return loss
Model Training
import time
# to save time, epochs is set to 1; adjust it as needed
epochs = 1
for epoch in range(epochs):
begin_time = time.time()
for step, batch in enumerate(dataset.create_tuple_iterator()):
unet_model.set_train()
batch_size = batch[0].shape[0]
t = randint(0, timesteps, (batch_size,), dtype=ms.int32)
noise = randn_like(batch[0])
loss = train_step(batch[0], t, noise)
if step % 500 == 0:
print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)
end_time = time.time()
times = end_time - begin_time
print("training time:", times, "s")
# display random sampling results
unet_model.set_train(False)
samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
print("Training Success!")
Summary
The material above comes from the MindSpore (昇思) check-in training camp. Note that the DDPM paper showed that diffusion models are a promising direction for (un)conditional image generation. Diffusion has since been (vastly) improved, most notably for text-conditional image generation. Below, we list some important (but far from exhaustive) follow-up works:
- Improved Denoising Diffusion Probabilistic Models (Nichol et al., 2021): finds that learning the variance of the conditional distribution (in addition to the mean) helps improve performance
- Cascaded Diffusion Models for High Fidelity Image Generation (Ho et al., 2021): introduces cascaded diffusion, a pipeline of multiple diffusion models that generate images of increasing resolution for high-fidelity image synthesis
- Diffusion Models Beat GANs on Image Synthesis (Dhariwal et al., 2021): shows that diffusion models can achieve image sample quality superior to the current state-of-the-art generative models by improving the U-Net architecture and introducing classifier guidance
- Classifier-Free Diffusion Guidance (Ho et al., 2021): shows that a classifier is not needed to guide a diffusion model, by jointly training a conditional and an unconditional diffusion model with a single neural network
- Hierarchical Text-Conditional Image Generation with CLIP Latents (DALL-E 2) (Ramesh et al., 2022): uses a prior to turn a text caption into a CLIP image embedding, after which a diffusion model decodes it into an image
- Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding (Imagen) (Saharia et al., 2022): shows that combining a large pre-trained language model (e.g. T5) with cascaded diffusion works well for text-to-image synthesis
Note that this list only includes important works up to the time of writing, June 7, 2022.
For now, the main (and perhaps only) drawback of diffusion models is that they require multiple forward passes to generate an image (which is not the case for generative models such as GANs). However, ongoing research suggests that only 10 denoising steps may be needed for high-fidelity generation.