DiT代码详细解析——Scalable Diffusion Models with Transformers

Paperhttps://arxiv.org/abs/2212.09748 Codehttps://github.com/facebookresearch/DiT

目录

方法:

实验设置:

结论:

DiT代码:

Sampling

Training DiT

DiT

DiTBlock

FinalLayer

TimestepEmbedder

LabelEmbedder


用transformer替换处理latent patches的U-Net backbone。本文作者发现,通过增加Transformer深度/宽度或增加输入令牌数量,具有较高Gflops的DiT始终具有较低的FID。除了具有良好的可扩展性外,最大的DiT-XL/2模型在条件类ImageNet 512×512和256×256基准测试中的表现优于所有先前的扩散模型,在后者上实现了2.27的最先进FID。

方法:

本文介绍了一种新的扩散模型结构--扩散变换器(DiT)。我们的目标是尽可能忠实于标准Transformer架构,以保持其扩展特性。由于我们的重点是训练图像的DDPM(具体而言,图像的空间表示),因此DiT基于VIT。

Patchify. DIT的输入是空间表示z(对于256 × 256 × 3图像,z的形状为32 × 32 × 4)。它通过在输入中线性地嵌入每个patch,将空间输入转换为T个tokens的序列,每个token的维数为d。与patchify一起,我们把标准VIT位置嵌入应用于所有tokens。

如图4,将p减半将使T增加四倍,因此至少使总Transformer Gflops增加四倍。尽管它对Gflops有显著影响,但请注意,更改p对下游参数计数没有任何有意义的影响。我们将p = 2、4、8添加到DiT设计空间。

DiT block design. 在patchify之后,输入tokens由一系列Transformer块处理。除了噪声图像输入,扩散模型有时处理额外的条件信息,如噪声时间步长t,类标签c,自然语言等。我们探讨了Transformer块的四种变体处理不同条件输入。这些设计对标准ViT块设计进行了微小但重要的修改。所有试验块的设计如图3所示。

(1)In-context conditioning. 我们简单地将t和c的向量嵌入作为两个额外的标记添加到输入序列中,将它们与图像标记没有区别地对待。这类似于ViT中的cls令牌,它允许我们使用标准的ViT块而无需修改。在最后一个块之后,我们从序列中移除条件标记。该方法向模型引入了可忽略不计的新Gflops。

(2)Cross-attention block. 我们将t和c的嵌入连接成一个长度为2的序列,与图像标记序列分开。修改了Transformer块,以在多头自注意块之后包括额外的多头交叉注意层,类似于Vaswani等人的原始设计,也类似于LDM用于调节类标签的层。交叉注意为模型增加了最多的Gflops,大约为15%的开销。

(3)Adaptive layer norm (adaLN) block. 在GAN和具有UNet骨干的扩散模型中广泛使用自适应归一化层之后,我们探索用自适应层范数(adaLN)替换Transformer块中的标准层范数层。我们不是直接学习维度尺度和移位参数γ和β,而是从t和c的嵌入向量之和回归它们。在我们研究的三种块设计中,adaLN添加的Gflops最少,因此计算效率最高。它也是唯一一种限制为对所有令牌应用相同功能的条件机制。

(4)adaLN-Zero block. 先前关于ResNets的工作已经发现,将每个残差块初始化为单位函数是有益的。扩散U-Net模型使用类似的初始化策略,在任何残差连接之前,对每个块中的最终卷积层进行零初始化。我们探索了一种改良的adaLN DiT阻断剂,其具有相同的作用。除了回归γ和β之外,我们还回归了在DiT块内的任何残差连接之前立即应用的维度标度参数α。

Model size. 我们应用N个DiT块的序列,每个DiT块以隐藏维度大小d操作。在ViT之后,我们使用标准的Transformer转换器来联合缩放N、d和注意力头。具体而言,我们使用四种配置:DiT-S、DiT-B、DiT-L和DiT-XL。

Transformer decoder. 在最后的DiT块之后,我们需要将我们的图像tokens序列解码为输出噪声预测和输出对角协方差预测。这两个输出都具有与原始空间输入相等的形状。我们使用标准的线性解码器来完成此操作。我们应用最终层范数(如果使用adaLN,则为自适应),并将每个令牌线性解码为p×p×2C张量,其中C是DiT的空间输入中的通道数。最后,我们将解码后的记号重新排列成它们原来的空间布局,以得到预测的噪声和协方差。

实验设置:

我们在ImageNet数据集上以256 × 256和512 × 512图像分辨率训练类条件潜在DiT模型。我们用零初始化最后的线性层,其他使用来自ViT的标准权重初始化。我们使用AdamW训练所有模型。我们使用1e−4的恒定学习率,没有权重衰减,批量大小为256。我们唯一使用的数据增强是水平翻转。

我们使用了来自SD的现成的预训练变分自动编码器(VAE)模型。VAE编码器的下采样系数为8——给定RGB图像x的形状为256 × 256 × 3,z = E(x)的形状为32 × 32 × 4。在本节的所有实验中,我们的扩散模型都是在这个Z空间中进行的。在从我们的扩散模型中采样新的潜像之后,我们使用VAE解码器x = D(z)将其解码为像素。

评估指标。我们使用FID来衡量缩放性能,FID是评估图像生成模型的标准度量。使用250步DDPM采样,计算FID-50K的结果,没用特殊说明时未采用classifier-free guiance。此外还增加了Inception Score、sFID、Precision/Recall等指标。

结论:

本文引入了扩散变换器(DiTs),这是一个简单的基于变换器的扩散模型主干,它优于先前的U-Net模型,并继承了Transformer模型类的出色缩放特性。鉴于本文中的缩放结果,未来的工作应继续将DiT缩放到更大的模型和令牌计数。DiT也可以作为文本到图像模型(如DALL·E 2和Stable Diffusion)的嵌入式主干进行探索。

DiT代码:

Sampling
python sample.py --image-size 512 --seed 1
Training DiT

To launch DiT-XL/2 (256x256) training with N GPUs on one node:

python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt

训练和采样均使用以下代码设置DIT_models,

model = DiT_models[args.model](
    input_size=latent_size,
    num_classes=args.num_classes
)
DiT

下面进入DiT_models,看看怎么运行的。先来看Class DiT的初始化。可以看出,对inputs(图片的latent特征表示)、timesteps、class_lables的编码分别使用PatchEmbed、TimestepEmbedder、LabelEmbedder。

class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

再来看class DiT的forward函数,输入的条件c为timesteps和class_label编码后相加,x为编码后的图片和位置编码相加,将x和c送入到DiTBlock块中处理,再通过FinalLayer和unpatchify得到输出结果。

def forward(self, x, t, y):
    """
    Forward pass of DiT.
    x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
    t: (N,) tensor of diffusion timesteps
    y: (N,) tensor of class labels
    """
    x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
    t = self.t_embedder(t)                   # (N, D)
    y = self.y_embedder(y, self.training)    # (N, D)
    c = t + y                                # (N, D)
    for block in self.blocks:
        x = block(x, c)                      # (N, T, D)
    x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
    x = self.unpatchify(x)                   # (N, out_channels, H, W)
    return x

DiT的forward_with_cfg会在采样使用cfg时用到,具体是使用了 Classifier-Free Guidance (CFG) 技术。其内部会调用self.forward得到模型输出,然后分离出eps 和 rest,eps是噪声预测结果,rest是方差,再将eps分为cond_eps, uncond_eps,分为有条件和无条件引导,这是由samply.py文件中模型输入方式决定的:

sample.py

# 这里的标号对应的是ImageNet数据集标签
class_labels = [207, 360, 387, 974, 88, 979, 417, 279]

# Create sampling noise:
n = len(class_labels)
z = torch.randn(n, 4, latent_size, latent_size, device=device)
y = torch.tensor(class_labels, device=device)

# Setup classifier-free guidance:
z = torch.cat([z, z], 0)
y_null = torch.tensor([1000] * n, device=device)
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=args.cfg_scale)

# Sample images:
samples = diffusion.p_sample_loop(
    model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
)
def forward_with_cfg(self, x, t, y, cfg_scale):
    """
    Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
    """
    # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
    half = x[: len(x) // 2]
    combined = torch.cat([half, half], dim=0)
    model_out = self.forward(combined, t, y)
    # For exact reproducibility reasons, we apply classifier-free guidance on only
    # three channels by default. The standard approach to cfg applies it to all channels.
    # This can be done by uncommenting the following line and commenting-out the line following that.
    # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)
DiTBlock

下面看DiTBlock,在init初始化时,norm1和norm2初始化为elementwise_affine=False,不进行仿射变换,LayerNorm层不含有可学习参数;在forward函数中,条件c经过self.adaLN_modulation处理,得到6个参数。对input tokens x进行Layer Norm、Scale、Shift后进行attn多头自注意力,然后乘α系数与x进行残差连接得到新的输出x,再对x进行Layer Norm、Scale、Shift后使用mlp(Pointwise Feedforward),再次进行残差连接,得到最终的输出x:

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        #使用自适应归一化替换标准归一化层
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        # 残差连接
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
FinalLayer

对应方法中Transformer deocder,将图像tokens序列解码为输出噪声预测和输出对角协方差预测。

class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

再来看Class DiT中的最后一步unpatchify,将最后一层得到的x恢复为原始图片,将每个patch重新拼回图像。

def unpatchify(self, x):
    """
    x: (N, T, patch_size**2 * C)
    imgs: (N, H, W, C)
    """
    c = self.out_channels
    p = self.x_embedder.patch_size[0]
    h = w = int(x.shape[1] ** 0.5)
    assert h * w == x.shape[1]

    x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
    return imgs

下面附上TimestepEmbedder和LabelEmbedder:

TimestepEmbedder
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
LabelEmbedder
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值