代码
:https://github.com/facebookresearch/DiT
论文:可扩展的基于transformer的扩散模型23.03.Scalable Diffusion Models with Transformers
威廉·皮布尔斯(William Peebles)加州大学伯克利分校(UC Berkeley)
谢赛宁(Saining Xie)纽约大学(New York University
简介 (就是将原来扩散生成模型Unet网络用ViT进行替换 + 验证Transformer的扩展性)
在潜在扩散模型(Latent Diffusion Models, LDMs)框架下构建和基准测试DiT设计空间,我们可以成功地用Transformer 替代U-Net骨干。我们进一步表明,DiTs是可扩展的扩散模型架构:网络复杂度(用Gflops衡量)与样本质量(用FID衡量)之间存在强相关性。通过简单地扩大DiT并训练具有高容量骨干(118.6 Gflops)的LDM,我们能够在类别条件的256×256 ImageNet
生成基准测试上实现2.27 FID的最先进结果。
研究了Transformer相对于网络复杂度与样本质量的扩展行为
图1 实现了最先进的图像质量。
我们展示了从我们在ImageNet上训练的两个类条件DiT-XL/2模型
中选择的样本,分别为512×512和256×256分辨率。
模型结构
不同规格的模型
我们遵循 Small (S)、Base (B) 和 Large (L) 变体的 ViT [10] 模型配置;我们还引入了 XLarge (XL) 配置作为我们的最大模型。
核心代码
https://github.com/facebookresearch/DiT/blob/main/models.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
#################################################################################
# Core DiT Model #
#################################################################################
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class DiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
num_patches = self.x_embedder.num_patches
# Will use fixed sin-cos embedding:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
self.blocks = nn.ModuleList([
DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
])
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
self.initialize_weights()
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def unpatchify(self, x):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
return imgs
def forward(self, x, t, y):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(t) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
c = t + y # (N, D)
for block in self.blocks:
x = block(x, c) # (N, T, D)
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_cfg(self, x, t, y, cfg_scale):
"""
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.forward(combined, t, y)
# For exact reproducibility reasons, we apply classifier-free guidance on only
# three channels by default. The standard approach to cfg applies it to all channels.
# This can be done by uncommenting the following line and commenting-out the line following that.
# eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
#################################################################################
# DiT Configs #
#################################################################################
def DiT_XL_2(**kwargs):
return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
def DiT_XL_4(**kwargs):
return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
def DiT_XL_8(**kwargs):
return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
def DiT_L_2(**kwargs):
return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
def DiT_L_4(**kwargs):
return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
def DiT_L_8(**kwargs):
return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
def DiT_B_2(**kwargs):
return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
def DiT_B_4(**kwargs):
return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
def DiT_B_8(**kwargs):
return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
def DiT_S_2(**kwargs):
return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
def DiT_S_4(**kwargs):
return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
def DiT_S_8(**kwargs):
return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
DiT_models = {
'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
}
部分原文翻译
摘要
我们探索了一类基于Transformer架构(architecture)的新扩散模型(diffusion models)。我们训练图像的潜在扩散模型(latent diffusion models),用在潜在块(latent patches)上操作的Transformer替代
常用的U-Net
骨干。我们通过前向传递复杂度(forward pass complexity)来分析我们的扩散Transformer(DiTs)的可扩展性
,并用Gflops来衡量。我们发现,通过增加Transformer深度/宽度或增加输入标记数量,具有更高Gflops的DiTs在FID上表现出更低的数值。除了具备良好的可扩展性,我们最大的DiT-XL/2模型在类别条件ImageNet 512×512和256×256基准测试中,均超越了之前的所有扩散模型,在后者上实现了2.27的FID。
1.引言
机器学习正经历一场由Transformer推动的复兴。在过去五年中,自然语言处理(natural language processing)、视觉(vision)及其他多个领域的神经架构(neural architectures)大多被Transformer取代[60]。然而,许多图像级生成模型(image-level generative models)仍未完全跟上这一趋势——虽然Transformer在自回归模型(autoregressive models中得到广泛应用,但在其他生成建模框架中应用较少。例如,扩散模型(diffusion models)在图像级生成模型的最新进展中处于前沿;然而,它们都采用卷积U-Net架构作为默认的骨干选择。
霍等人(Ho et al.)的开创性工作(DDPM
)首次引入了U-Net骨干用于扩散模型。最初在像素级自回归模型(pixel-level autoregressive models)和条件GANs中取得成功后,U-Net从PixelCNN++中继承而来,进行了少量更改。该模型是卷积的,主要由ResNet[15]块组成。与标准U-Net[49]不同,额外的空间自注意块
(spatial self-attention blocks)——Transformer中的基本组件——在较低分辨率下插入。Dhariwal和Nichol[9]对U-Net的几个架构选择进行了消融实验,如使用自适应归一化层(adaptive normalization layers)[40]来注入条件信息和卷积层的通道数。然而,从Ho等人的U-Net高层设计基本保持不变。
通过这项工作,我们旨在揭示扩散模型中的架构选择的重要性,并为未来的生成建模研究提供经验基准。我们展示了U-Net的归纳偏置
(inductive bias)对扩散模型的性能并不关键,它们可以很容易地被标准设计(如Transformer)替代。因此,扩散模型有望从近期的架构统一趋势中受益——例如,通过继承其他领域的最佳实践和训练方案,以及保留可扩展性、鲁棒性和效率等有利属性。标准化的架构还可以为跨领域研究开辟新可能。
在本文中,我们重点介绍了一类基于Transformer的扩散模型。我们称之为 Difffusion Transformer,简称DiTs
。DiTs遵循Vision Transformers(ViTs)
的最佳实践,后者已被证明在视觉识别上比传统卷积网络(如ResNet[15])更有效地扩展。
具体而言,我们研究了Transformer相对于网络复杂度与样本质量的扩展行为。通过在潜在扩散模型(Latent Diffusion Models, LDMs)[48]框架下构建和基准测试DiT设计空间,我们可以成功地用Transformer替代U-Net骨干。我们进一步表明,DiTs是可扩展的扩散模型架构:网络复杂度(用Gflops衡量)与样本质量(用FID衡量)之间存在强相关性。通过简单地扩大DiT并训练具有高容量骨干(118.6 Gflops)的LDM,我们能够在类别条件的256×256 ImageNet生成基准测试上实现2.27 FID的最先进结果。
2 . 相关工作
2.1 Transformer
Transformers 已取代语言、视觉(Vision-Transformer)、强化学习(reinforcement learning) 和元学习(meta-learning)领域的特定领域架构。它们在语言领域展示了在增加模型规模(model size)、训练计算(training compute)和数据方面的显著扩展特性(scaling properties),作为通用自回归模型
(generic autoregressive models) 和 ViTs。
在语言之外,Transformers 已被训练用于自回归预测像素(autoregressively predict pixels)[6, 7, 38]。它们还被训练在离散代码表上(discrete codebooks,VQ-VAE
),作为自回归模型(autoregressive models, 例如VQ-GAN
,DALLE) 和掩码生成模型(masked generative models);前者在参数量达到20B时显示出优秀的扩展特性。
最后,Transformers 已在 DDPMs 中探索,用于生成非空间数据(non-spatial data);例如,在 DALL·E 2
中生成 CLIP 图像嵌入(image embeddings)。在本文中,我们研究 Transformer 作为图像扩散模型(diffusion models of images)骨干(backbone)的扩展特性(scaling properties)
2.2 Denoising diffusion probabilistic models (DDPMs)
扩散模型和基于分数的生成模型(Score-Based Generative Models)在生成图像方面特别成功,在许多情况下优于生成对抗网络(GANs),。过去两年中DDPMs的改进主要来自于改进的采样技术
(DDPM,DDIM),尤其是无分类器指导(Classifier-Free Guidance,CFG),将扩散模型重新表述为预测噪声
而不是像素,并使用级联DDPM管道,其中低分辨率基础扩散模型与上采样器并行训练。
对于上述所有扩散模型,卷积U-Nets是事实上的骨干架构选择。与此同时的工作引入了一种基于注意力的新型高效架构用于DDPMs;
2.3 Architecture complexity.
我们探索纯Transformers架构复杂度(Architecture Complexity)。在图像生成文献中评估架构复杂度时,通常使用参数数量。一般来说,参数数量对于评估图像模型的复杂度可能是一个不好的代理,因为它们不考虑图像分辨率
,而分辨率显著影响性能。相反,本文的大部分模型复杂度分析通过理论Gflops的视角进行。这使我们与架构设计文献保持一致,在该文献中Gflops被广泛用于评估复杂度。在实践中,黄金复杂度指标仍在争论中,因为它经常取决于特定的应用场景。Nichol和Dhariwal改进扩散模型的开创性工作与我们最相关——他们分析了U-Net架构类的可扩展性和Gflop特性。在本文中,我们关注Transformer类。
3. Diffusion Transformer
3.1 预备知识
数学字符的大小写读音
3.1.1 扩散公式(Diffusion formulation)。
在介绍我们的架构之前,我们简要回顾一下理解扩散模型(DDPMs
)的一些基本概念。高斯扩散模型(Gaussian diffusion models)假设一个前向加噪过程(forward noising process),该过程逐步对真实图片
x
0
x_0
x0 应用噪声:
q ( x t ∣ x 0 ) = N ( x t ; α t ˉ x 0 , ( 1 − α t ˉ ) I ) q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha_t}} x_0, (1-\bar{\alpha_t}) I) q(xt∣x0)=N(xt;αtˉx0,(1−αtˉ)I)
如果
X
∼
N
(
μ
,
σ
2
)
X \sim \mathcal{N}(\mu, \sigma^2)
X∼N(μ,σ2),则对任意常数
a
a
a 和
b
b
b,线性变换
Y
=
a
X
+
b
Y = aX + b
Y=aX+b 也服从高斯分布(正态分布线性变换性质
),且:
Y
∼
N
(
a
μ
+
b
,
a
2
σ
2
)
Y \sim \mathcal{N}(a\mu + b, a^2\sigma^2)
Y∼N(aμ+b,a2σ2)
若
ϵ
t
∼
N
(
0
,
1
)
\epsilon_t \sim \mathcal{N}(0, 1)
ϵt∼N(0,1),那么令线性变换
a
ϵ
t
+
b
a\epsilon_t + b
aϵt+b 等于
x
t
x_t
xt,
X
t
∼
N
(
b
,
a
2
)
∼
N
(
α
t
ˉ
x
0
,
(
1
−
α
t
ˉ
)
I
)
X_t \sim \mathcal{N}( b, a^2) \sim \mathcal{N}( \sqrt{\bar{\alpha_t}} x_0, (1-\bar{\alpha_t}) I)
Xt∼N(b,a2)∼N(αtˉx0,(1−αtˉ)I)
其中常数
α
t
ˉ
\bar{\alpha_t}
αtˉ 是超参数(hyperparameters)。通过应用重参数化技巧(reparameterization trick)(等效与用标准正太分布的线性变换去替代其他正太分布
,通过上面的等式反求a,b,
b
=
α
t
ˉ
x
0
,
a
2
=
1
−
α
t
ˉ
)
b= \sqrt{\bar{\alpha_t}}x_0 ,a^2=1-\bar{\alpha_t})
b=αtˉx0,a2=1−αtˉ):
x
t
=
a
ϵ
t
+
b
x_t=a\epsilon_t + b
xt=aϵt+b 等于:
x t = α t ˉ x 0 + 1 − α t ˉ ϵ t x_t = \sqrt{\bar{\alpha_t}} x_0 + \sqrt{1 - \bar{\alpha_t}} \epsilon_t xt=αtˉx0+1−αtˉϵt
其中 ϵ t ∼ N ( 0 , I ) \epsilon_t \sim \mathcal{N}(0, I) ϵt∼N(0,I)。
扩散模型被训练用来学习逆过程
(reverse process),即逆转前向过程的腐蚀:
p θ ( x t − 1 ∣ x t ) = N ( μ θ ( x t ) , Σ θ ( x t ) ) p_\theta(x_{t-1}|x_t) = \mathcal{N} (\mu_\theta(x_t), \Sigma_\theta(x_t)) pθ(xt−1∣xt)=N(μθ(xt),Σθ(xt))
其中神经网络用于预测
p
θ
p_\theta
pθ 的统计量。逆过程模型通过
x
0
x_0
x0 对数似然的变分下界
(variational lower bound)进行训练,该下界减少为
L ( θ ) = − p ( x 0 ∣ x 1 ) + ∑ t D K L ( q ∗ ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) L(\theta) = -p(x_0|x_1) + \sum_t D_{KL}(q^*(x_{t-1}|x_t, x_0) || p_\theta(x_{t-1}|x_t)) L(θ)=−p(x0∣x1)+t∑DKL(q∗(xt−1∣xt,x0)∣∣pθ(xt−1∣xt))
排除一个与训练无关的附加项。由于 q ∗ q^* q∗ 和 p θ p_\theta pθ 都是高斯分布(Gaussian),因此可以使用两个分布的均值和协方差计算 D K L D_{KL} DKL。通过将 μ θ \mu_\theta μθ 重新参数化为噪声预测网络(noise prediction network) ϵ θ \epsilon_\theta ϵθ,模型可以通过预测噪声 ϵ θ ( x t ) \epsilon_\theta(x_t) ϵθ(xt) 与采样的高斯噪声 ϵ t \epsilon_t ϵt 之间的简单均方误差进行训练:
L simple ( θ ) = ∥ ϵ θ ( x t ) − ϵ t ∥ 2 2 L_{\text{simple}}(\theta) = \|\epsilon_\theta(x_t) - \epsilon_t\|_2^2 Lsimple(θ)=∥ϵθ(xt)−ϵt∥22
但是,为了用学习到的逆过程协方差 Σ θ \Sigma_\theta Σθ 训练扩散模型,需要优化完整的 D K L D_{KL} DKL 项。我们遵循Nichol和Dhariwal的方法:
用 L simple L_{\text{simple}} Lsimple 训练 ϵ θ \epsilon_\theta ϵθ,并用完整的 L L L 训练 Σ θ \Sigma_\theta Σθ。一旦 p θ p_\theta pθ 训练完毕,可以通过初始化 x t max ∼ N ( 0 , I ) x_{t_{\text{max}}} \sim \mathcal{N}(0, I) xtmax∼N(0,I) 并通过重参数化技巧采样 x t − 1 ∼ p θ ( x t − 1 ∣ x t ) x_{t-1} \sim p_\theta(x_{t-1}|x_t) xt−1∼pθ(xt−1∣xt) 生成新图像。
3.1.2 无分类器指导
(Classifier-free guidance)
条件扩散模型(conditional diffusion models)将额外信息作为输入,如类标签 c c c。在这种情况下,逆过程变为
p θ ( x t − 1 ∣ x t , c ) = N ( μ θ ( x t , c ) , Σ θ ( x t , c ) ) p_\theta(x_{t-1}|x_t, c) = \mathcal{N} (\mu_\theta(x_t, c), \Sigma_\theta(x_t, c)) pθ(xt−1∣xt,c)=N(μθ(xt,c),Σθ(xt,c))
其中 ϵ θ \epsilon_\theta ϵθ 和 Σ θ \Sigma_\theta Σθ 都以 c c c 为条件。在这种设置下,可以使用无分类器指导(classifier-free guidance)来鼓励采样程序找到 x x x,使得 log p ( c ∣ x ) \log p(c|x) logp(c∣x) 高。通过贝叶斯规则(Bayes Rule),
log p ( c ∣ x ) ∝ log p ( x ∣ c ) − log p ( x ) \log p(c|x) \propto \log p(x|c) - \log p(x) logp(c∣x)∝logp(x∣c)−logp(x)
因此
∇ x log p ( c ∣ x ) ∝ ∇ x log p ( x ∣ c ) − ∇ x log p ( x ) \nabla_x \log p(c|x) \propto \nabla_x \log p(x|c) - \nabla_x \log p(x) ∇xlogp(c∣x)∝∇xlogp(x∣c)−∇xlogp(x)
通过将扩散模型的输出解释为得分函数(score function),可以通过指导
ϵ ^ θ ( x t , c ) = ϵ θ ( x t , ∅ ) + s ⋅ ( ϵ θ ( x t , c ) − ϵ θ ( x t , ∅ ) ) \hat{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t, \emptyset) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \emptyset)) ϵ^θ(xt,c)=ϵθ(xt,∅)+s⋅(ϵθ(xt,c)−ϵθ(xt,∅))
来引导采样程序,以采样具有高 p ( x ∣ c ) p(x|c) p(x∣c) 的 x x x,其中 s > 1 s > 1 s>1 表示指导的规模(注意 s = 1 s = 1 s=1 恢复标准采样)。通过在训练期间随机删除 c c c 并用学习到的“空”嵌入 ∅ \emptyset ∅ 替换来评估扩散模型。无分类器指导众所周知能显著改善样本质量,相对于一般的采样技术,这一趋势在我们的DiT模型中同样适用。
3.1.3 潜在扩散模型(Latent diffusion models)。
直接在高分辨率像素空间(pixel space)训练扩散模型可能在计算上非常昂贵。潜在扩散模型(LDMs)通过两阶段方法解决这个问题:(1)学习一个自编码器(autoencoder),将图像压缩到更小的空间表示(spatial representations),使用学习到的编码器 E E E;(2)在表示 z = E ( x ) z = E(x) z=E(x) 上训练扩散模型,而不是在图像 x x x 上训练扩散模型( E E E 被冻结)。然后,可以通过从扩散模型中采样表示 z z z 并随后用学习到的解码器 D D D 解码为图像 x = D ( z ) x = D(z) x=D(z) 生成新图像。
如图2所示
LDMs 在使用较少Gflops的情况下实现了良好的性能,相对于像ADM这样的像素空间扩散模型。由于我们关注计算效率,这使得它们成为架构探索的一个有吸引力的起点。在本文中,我们将DiTs应用于潜在空间,尽管它们也可以无修改地应用于像素空间。这使得我们的图像生成管道成为一种混合方法;我们使用现成的卷积VAEs和基于Transformer的DDPMs。
3.2(重点
) DiT: diffusion Transformer设计空间
我们引入了扩散Transformers(DiTs),一种新型扩散模型架构。我们尽可能忠实于(faithful)标准Transformer架构,以保留其扩展特性 (scaling properties)。由于我们的重点是训练图像的DDPMs(特别是图像的空间表示),DiT基于操作在块序列上的Vision Transformer(ViT)架构。DiT保留了许多ViTs的最佳实践。图3显示了完整DiT架构的概述。在本节中,我们描述DiT的前向传递(forward pass)以及DiT类设计空间的组件。
图3: 扩散变压器 (DiT) 架构。
左:我们训练条件潜在 DiT 模型。输入潜在特征(latend)被分解为块(patches)并由几个 DiT 块处理。右图:我们的 DiT 块的详细信息。我们尝试了标准变压器块的变体,这些变体通过自适应层范数、交叉注意和额外的输入标记合并条件。自适应层正则化(Adaptive layer norm)效果最好。
3.2.1 分块(Patchify)。
** DiT的输入是空间表示 z z z(对于256×256×3的图像, z z z 的形状为32×32×4)。DiT的第一层是“分块”,将空间输入转换为 T T T个标记的序列,每个标记维度为 d d d,通过线性嵌入输入中的每个块。分块后,我们对所有输入标记应用标准ViT的频率基位置嵌入(sine-cosine版本)。分块创建的标记数 T T T由块大小超参数 p p p决定。如图4所示,将 p p p减半会使 T T T翻四倍,从而至少使总Transformer的Gflops翻四倍。虽然它对Gflops影响显著,但请注意更改 p p p对下游参数数量没有实际影响。
我们将 p = 2 , 4 , 8 p = 2, 4, 8 p=2,4,8添加到DiT设计空间中。
图4 DiT 的输入规范(specifications)。
给定补丁大小 p × p,形状 I × I × C 的空间表示(来自 VAE 的噪声潜在)被“补丁化”成一个长度为
T
=
(
I
/
p
)
2
T = (I/p)^2
T=(I/p)2 的序列,隐藏维度 d。较小的补丁大小 p 会导致更长的序列长度,从而导致更多的 Gflop。
3.2.2 DiT块设计(DiT block design)。
分块后,输入标记由一系列Transformer块处理。除了加噪图像输入外,扩散模型有时还处理额外的条件信息,如噪声时间步长 t t t、类标签 c c c、自然语言等。我们探索了四种不同处理条件输入的Transformer块变体。这些设计对标准ViT块设计进行了小的但重要的修改。所有块的设计如图3所示。
-
上下文条件(In-context conditioning)。 我们简单地将 t t t和 c c c的向量嵌入附加为输入序列中的两个额外标记,对它们的处理与图像标记没有区别。这类似于ViTs中的cls标记,并且允许我们使用未修改的标准ViT块。在最后一个块之后,我们从序列中移除条件标记。这种方法对模型引入的Gflops可忽略不计。
-
交叉注意力块(Cross-attention block)。 我们将 t t t和 c c c的嵌入连接成一个长度为二的序列,独立于图像标记序列。Transformer块被修改为在多头自注意力块之后包含一个额外的多头交叉注意力层,类似于Vaswani等人的原始设计,也类似于LDM用于对类标签进行条件处理的设计。交叉注意力为模型增加了最多的Gflops,大约有15%的开销。
-
自适应层归一化(adaLN)块(Adaptive layer norm (adaLN) block)。 继GANs和带有U-Net骨干的扩散模型中广泛使用自适应归一化层(adaptive normalization layers)之后,我们探索用自适应层归一化(adaLN)替换Transformer块中的标准层归一化层(layer norm layers)。而不是直接学习维度上的缩放和移位参数 γ \gamma γ 和 β \beta β,我们从 t t t和 c c c的嵌入向量的和中回归它们。在我们探索的三个块设计中,adaLN增加的Gflops最少,因此计算效率最高。它也是唯一受限于对所有标记应用相同函数的条件机制。
-
adaLN-Zero块(adaLN-Zero block)。 先前在ResNets上的工作发现将每个残差块初始化为恒等函数是有益的。例如,Goyal等人发现,在监督学习设置中,零初始化每个块中的最后批量归一化缩放因子 γ \gamma γ加速了大规模训练。扩散U-Net模型使用类似的初始化策略,在任何残差连接之前,零初始化每个块中的最后卷积层。我们探索了adaLN DiT块的一个修改版本,它也做同样的事情。除了回归 γ \gamma γ和 β \beta β,我们还回归维度上的缩放参数 α \alpha α,这些参数在DiT块中的任何残差连接之前立即应用。
我们将上下文、交叉注意力、自适应层归一化和adaLN-Zero块包含在DiT设计空间中(DiT design space)。
3.2.3 模型大小(Model size)。
我们应用了一系列 N N N个DiT块,每个块在隐藏维度大小 d d d上操作。遵循ViT,我们使用共同扩展 N N N、 d d d和注意力头(attention heads)的标准Transformer配置。具体来说,我们使用四个配置:DiT-S、DiT-B、DiT-L和DiT-XL。它们涵盖了从0.3到118.6Gflops的广泛模型大小和flop分配,使我们能够评估扩展性能。表1给出了配置的详细信息。
我们将B、S、L和XL配置添加到DiT设计空间中。
表 1. DiT 模型的详细信息。
我们遵循 Small (S)、Base (B) 和 Large (L) 变体的 ViT
模型配置;我们还引入了 XLarge (XL) 配置作为我们的最大模型。
3.2.4 Transformer解码器(Transformer decoder)。
在最后一个DiT块之后,我们需要将图像标记序列解码为输出噪声预测和输出对角协方差预测。两个输出的形状都与原始空间输入相等。我们使用标准线性解码器来实现这一点;我们应用最终的层归一化(如果使用adaLN则是自适应的),并线性解码每个标记为一个 p × p × 2 C p \times p \times 2C p×p×2C张量,其中 C C C是DiT的空间输入通道数。最后,我们将解码的标记重新排列为它们的原始空间布局,以获得预测的噪声和协方差。
我们探索的完整DiT设计空间包括块大小、Transformer块架构和模型大小。
4. 实验设置 (Experimental Setup)
我们探索了DiT设计空间并研究了我们模型类的扩展属性。我们的模型根据其配置 (config) 和潜在 (patch) 大小 p p p 命名;例如,DiT-XL/2表示XLarge配置且 p = 2 p = 2 p=2。
训练 (Training)
我们在ImageNet数据集上训练了类条件潜在的DiT模型,图像分辨率为256×256和512×512,这是一项高度竞争的生成建模基准。我们使用ViT的标准权重初始化技术来初始化最后的线性层为零。我们使用AdamW
训练所有模型【DiT-Small, DiT-Base, DiT-Large, DiT-XLarge】。
我们使用恒定学习率
1
×
1
0
−
4
1 \times 10^{-4}
1×10−4,无权重衰减,批次大小为256。唯一的数据增强 (data augmentation) 方法是水平翻转。与之前使用ViTs的许多工作不同【Vision-Transformer】,我们发现学习率预热 (warmup) 和正则化 (regularization)
对训练高性能DiT来说不是必需的。即使没有这些技术,训练在所有模型配置下都非常稳定,并且我们没有观察到常见于训练transformers时的损失峰值。
在生成建模文献中常见的实践中,我们在训练过程中保持DiT权重的指数移动平均 (EMA),衰减为0.9999。所有结果报告都使用EMA模型。我们在所有DiT模型大小和patch大小之间使用相同的训练超参数
。我们的训练超参数几乎完全保留自ADM【openai论文
:21.05.Diffusion Models Beat GANs on Image Synthesis】。我们没有调整学习率、衰减/预热调度、Adam的
β
1
/
β
2
\beta_1/\beta_2
β1/β2 或权重衰减(weight decays)。
扩散 (Diffusion)
我们使用来自Stable Diffusion的预训练变分自编码器
(VAE) 模型【Variational Autoencoder】。VAE编码器的下采样因子为8——给定一个形状为256×256×3的RGB图像
x
x
x,编码后的
z
=
E
(
x
)
z = E(x)
z=E(x) 形状为32×32×4。在本节中的所有实验中,我们的扩散模型在这个Z空间中操作。采样新的潜在变量后,我们使用VAE解码器解码为像素
x
=
D
(
z
)
x = D(z)
x=D(z)。我们保留ADM的扩散超参数【21.05.Diffusion Models Beat GANs on Image Synthesis】;具体来说,我们使用最大时间步数
t
m
a
x
=
1000
t_{max} = 1000
tmax=1000,线性方差调度从
1
×
1
0
−
4
1 \times 10^{-4}
1×10−4 到
2
×
1
0
−
2
2 \times 10^{-2}
2×10−2,ADM的协方差参数化方法
Σ
θ
\Sigma_\theta
Σθ 以及其嵌入输入时间步和标签的方法。
评估指标 (Evaluation Metrics)
我们使用Frechet Inception Distance (FID)
【Frechet Inception Distance】作为评估图像生成模型的标准指标来测量扩展性能。按照惯例,当与之前的工作比较时,我们使用250个DDPM采样步骤报告FID-50K。已知FID对小的实现细节很敏感【Frechet Inception Distance】;为了确保准确的比较,本文中报告的所有值都是通过导出样本并使用ADM的TensorFlow评估套件获得的。除非另有说明,本节中报告的FID数值不使用分类器自由指导。此外,我们还报告了Inception Score【Inception Score】、sFID【sFID】和Precision/Recall【Precision/Recall】作为次要指标。
计算 (Computation)
我们在JAX中实现所有模型并使用TPU-v3 pods训练它们。DiT-XL/2是我们计算最密集的模型,在TPU v3-256 pod上以大约每秒5.7次迭代的速度训练,批次大小为256。
5. 实验 (Experiments)
DiT块设计
我们训练了四个我们Gflops最高的DiT-XL/2模型,每个使用不同的块设计——in-context (119.4 Gflops),cross-attention (137.6 Gflops),adaptive layer norm (adaLN, 118.6 Gflops)或adaLN-zero (118.6 Gflops)。我们在整个训练过程中测量FID。结果如图5
所示。adaLN-Zero块比cross-attention和in-context条件在所有训练阶段都获得了更低的FID,同时也是计算效率最高的。在40万次训练迭代时,使用adaLN-Zero模型获得的FID几乎是in-context模型的一半,表明条件机制对模型质量至关重要。初始化也很重要——将每个DiT块初始化为恒等函数的adaLN-Zero显著优于vanilla adaLN。在本文的其余部分,所有模型将使用adaLN-Zero DiT块。
图5 比较不同的条件反射策略。
adaLNZero 在训练的所有阶段都优于交叉注意力和上下文条件
模型大小和块大小的扩展 (Scaling model size and patch size)
我们训练了 12 个 DiT 模型,覆盖模型配置(S、B、L、XL)和块大小(8、4、2)。需要注意的是,DiT-L 和 DiT-XL 在相对 Gflops 方面显著接近其他配置。图2(左)概述了每个模型在 400K 训练迭代中的 Gflops 和 FID。在所有情况下,我们发现增加模型大小和减少块大小显著改善了扩散模型。
图6(顶部)展示了在保持块大小不变的情况下,FID 随模型大小增加而变化。
在所有四个配置中,通过使 Transformer 更深和更宽,可以在训练的所有阶段获得显著的 FID 改进。同样,图6(底部)显示了在保持模型大小不变的情况下,FID 随块大小减少而变化。通过简单地扩展 DiT 处理的标记数量,在参数近似不变的情况下,我们再次观察到显著的 FID 改进。
第一行:我们比较保持补丁大小恒定的 FID。
底行:我们比较了 FID 保持模型大小不变。Transfoemer主干在所有模型大小和补丁大小中产生更好的生成模型
表2 在ImageNet 256×256上对类条件图像生成进行基准测试。
DiT-XL/2实现了最先进的FID。
表3 在 ImageNet 512×512 上对类条件图像生成进行基准测试
请注意,先前的工作 使用 1000 个真实样本测量 Precision 和 Recall,分辨率为 512 × 512;为了一致性,我们做同样的事情。
6. 结论
我们引入了扩散 Transformers(DiTs),一种简单的基于 Transformer 的扩散模型骨干,优于先前的 U-Net 模型,并继承了 Transformer 模型类的优秀扩展特性。鉴于本文中的有前途的扩展结果,未来的工作应继续将 DiTs 扩展到更大的模型和标记数量。DiT 还可以作为 DALL·E 2 和 Stable Diffusion 等文本到图像模型的直接骨干进行探索。
附录
图11 其他生成结果。
我们对 512 × 512 模型使用 6.0 的无分类器引导尺度,对 256 × 256 模型使用 4.0。两种模型都使用 ft-EMA VAE 解码器
图14 15
高斯分布性质
标准正态分布:
-
当 μ = 0 \mu = 0 μ=0 且 σ 2 = 1 \sigma^2 = 1 σ2=1 时,高斯分布称为标准正态分布,记作 Z ∼ N ( 0 , 1 ) Z \sim \mathcal{N}(0, 1) Z∼N(0,1)。
-
标准正态分布的概率密度函数为:
f ( z ) = 1 2 π exp ( − z 2 2 ) f(z) = \frac{1}{\sqrt{2\pi}} \exp\left( -\frac{z^2}{2} \right) f(z)=2π1exp(−2z2)
线性变换:
-
如果 X ∼ N ( μ , σ 2 ) X \sim \mathcal{N}(\mu, \sigma^2) X∼N(μ,σ2),则对任意常数 a a a 和 b b b,线性变换 Y = a X + b Y = aX + b Y=aX+b 也服从高斯分布,且:
Y ∼ N ( a μ + b , a 2 σ 2 ) Y \sim \mathcal{N}(a\mu + b, a^2\sigma^2) Y∼N(aμ+b,a2σ2)
线性组合:
-
如果 X ∼ N ( μ X , σ X 2 ) X \sim \mathcal{N}(\mu_X, \sigma_X^2) X∼N(μX,σX2) 和 Y ∼ N ( μ Y , σ Y 2 ) Y \sim \mathcal{N}(\mu_Y, \sigma_Y^2) Y∼N(μY,σY2) 且 X X X 和 Y Y Y 独立,则它们的线性组合 Z = a X + b Y Z = aX + bY Z=aX+bY 仍然是高斯分布,且:
Z ∼ N ( a μ X + b μ Y , a 2 σ X 2 + b 2 σ Y 2 ) Z \sim \mathcal{N}(a\mu_X + b\mu_Y, a^2\sigma_X^2 + b^2\sigma_Y^2) Z∼N(aμX+bμY,a2σX2+b2σY2)
重参数技巧(标准正太分布表示常规正太分布)
扩散模型(Gaussian diffusion models)假设一个前向加噪过程(forward noising process),该过程逐步对真实图片
x
0
x_0
x0 应用噪声:
q ( x t ∣ x 0 ) = N ( x t ; α t ˉ x 0 , ( 1 − α t ˉ ) I ) q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha_t}} x_0, (1-\bar{\alpha_t}) I) q(xt∣x0)=N(xt;αtˉx0,(1−αtˉ)I)
如果
X
∼
N
(
μ
,
σ
2
)
X \sim \mathcal{N}(\mu, \sigma^2)
X∼N(μ,σ2),则对任意常数
a
a
a 和
b
b
b,线性变换
Y
=
a
X
+
b
Y = aX + b
Y=aX+b 也服从高斯分布(正态分布线性变换性质
),且:
Y
∼
N
(
a
μ
+
b
,
a
2
σ
2
)
Y \sim \mathcal{N}(a\mu + b, a^2\sigma^2)
Y∼N(aμ+b,a2σ2)
若
ϵ
t
∼
N
(
0
,
1
)
\epsilon_t \sim \mathcal{N}(0, 1)
ϵt∼N(0,1),那么令线性变换
a
ϵ
t
+
b
a\epsilon_t + b
aϵt+b 等于
x
t
x_t
xt,
X
t
∼
N
(
b
,
a
2
)
∼
N
(
α
t
ˉ
x
0
,
(
1
−
α
t
ˉ
)
I
)
X_t \sim \mathcal{N}( b, a^2) \sim \mathcal{N}( \sqrt{\bar{\alpha_t}} x_0, (1-\bar{\alpha_t}) I)
Xt∼N(b,a2)∼N(αtˉx0,(1−αtˉ)I)
其中常数
α
t
ˉ
\bar{\alpha_t}
αtˉ 是超参数(hyperparameters)。通过应用重参数化技巧(reparameterization trick)(等效与用标准正太分布的线性变换去替代其他正太分布
,通过上面的等式反求a,b,
b
=
α
t
ˉ
x
0
,
a
2
=
1
−
α
t
ˉ
)
b= \sqrt{\bar{\alpha_t}}x_0 ,a^2=1-\bar{\alpha_t})
b=αtˉx0,a2=1−αtˉ):
x
t
=
a
ϵ
t
+
b
x_t=a\epsilon_t + b
xt=aϵt+b 等于:
x t = α t ˉ x 0 + 1 − α t ˉ ϵ t x_t = \sqrt{\bar{\alpha_t}} x_0 + \sqrt{1 - \bar{\alpha_t}} \epsilon_t xt=αtˉx0+1−αtˉϵt