【扩散模型(七)】Stable Diffusion 3 diffusers 源码详解2 - DiT 与 MMDiT 相关代码(上)

64 篇文章 2 订阅
31 篇文章 1 订阅

系列文章目录



一、DiT

  • DiT 1 是 SD3 中 MMDiT 的核心基础,而
  • 通过将 Diffusion 中的 Unet 换成了 DiT Block,来实现基于条件的图像生成。
  • 原文中的条件是类别标签,而非文本提示词。
    在这里插入图片描述
  • 原文测试了多种设置,最终采用了 adaLN-Zero 作为 Cross-Attention 的替代。

DiT 整体代码

官方代码仓库为 https://github.com/facebookresearch/DiT,下面代码的具体位置在 /path/to/DiT/models.py

  • 下方代码为上图的左边部分,输入 x 是 Noised Latent,t 是 Timestep,Label 为 y
  • 其中 block 则是上图中的 DiT Block,将 x 和 c 共同作为输入,以 c 为条件来生成 x (对 Noised Latent 进行去噪)。
   def forward(self, x, t, y):
       """
       Forward pass of DiT.
       x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
       t: (N,) tensor of diffusion timesteps
       y: (N,) tensor of class labels
       """
       x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
       t = self.t_embedder(t)                   # (N, D)
       y = self.y_embedder(y, self.training)    # (N, D)
       c = t + y                                # (N, D)
       for block in self.blocks:
           x = block(x, c)                      # (N, T, D)
       x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
       x = self.unpatchify(x)                   # (N, out_channels, H, W)
       return x

DiT Block

与下面代码中 forward 函数内对应的变量在 DiT Block 中的位置。
在这里插入图片描述

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

在这个 DiTBlock 类中,shift_msascale_msagate_msashift_mlpscale_mlpgate_mlp 是从 adaLN_modulation(c) 这一步中得到的,它们在具体功能上是有所区别的,虽然它们是通过同一个输入 c 生成的。

  1. shift_msascale_msa 这两个变量与 Multi-Head Self-Attention (MSA) 模块的自适应层归一化(adaptive LayerNorm, adaLN)有关:

    • shift_msa: 这个变量用于平移 LayerNorm 的输出,也就是在归一化的基础上加上一个偏置。它在调节 MSA 模块的激活输出时用作偏移量。
    • scale_msa: 这个变量用于缩放 LayerNorm 的输出,即对归一化的结果乘以一个比例因子。它控制了 MSA 模块中激活的放大或缩小程度。
  2. gate_msa: 这个变量是作为一个门控(gate)信号,作用于 MSA 模块的输出上。它决定了 MSA 模块输出在累加到 x 之前的权重。如果 gate_msa 很小,那么这个输出会被抑制;如果 gate_msa 接近1,则输出会如常累加。

  3. shift_mlpscale_mlp 这两个变量与 Pointwise Feedforward (MLP) 模块的自适应层归一化(adaLN)有关,类似于 shift_msascale_msa,但它们作用在 MLP 模块上:

    • shift_mlp: 用于平移 LayerNorm 的输出,在 MLP 模块中作为偏移量。
    • scale_mlp: 用于缩放 LayerNorm 的输出,在 MLP 模块中控制激活的放大或缩小。
  4. gate_mlp: 类似于 gate_msa,但它控制的是 MLP 模块的输出。它决定了 MLP 模块输出在累加到 x 之前的权重。

这六个参数是否是相同的值?

adaLN_modulation(c) 中,c 经过一个 nn.Linear 层(即 nn.Linear(hidden_size, 6 * hidden_size, bias=True)),然后被 chunk(6, dim=1) 分成六个部分,分别得到 shift_msa、scale_msa、gate_msa、shift_mlp、scale_mlp 和 gate_mlp。

虽然这些变量来自于同一个线性层的输出,但由于 nn.Linear 层的权重在训练过程中是可学习的,并且是随机初始化的,因此这些权重会在训练过程中被更新为不同的值。

代替 Cross-attention 的 adaLN-Zero Block

那么为什用 adaLN-Zero 来代替 Cross-Attention 呢?主要是因为计算资源。(DiT 原文提到 Cross-attention adds the most Gflops to the model, roughly a 15% overhead.)

  1. 什么是adaLN-Zero Block?
    adaLN-Zero Block是一种改进版的adaLN(Adaptive Layer Normalization)模块,主要用于扩散模型(Diffusion Model)中。它的核心思想是通过初始化技巧和引入额外的缩放参数,来加速模型训练并提高生成样本的质量。

  2. 为什么引入adaLN-Zero Block?

    • 加速训练: 通过将残差块初始化为恒等映射,模型在训练初期更容易收敛,从而加快训练速度。
    • 提升性能: 引入维度缩放参数,使得模型能够学习到更具表达能力的特征表示,从而生成质量更高的样本。
    • 增强稳定性: 恒等初始化有助于稳定模型的训练过程,尤其对于深层模型。
  3. adaLN-Zero Block的工作原理

    • 恒等初始化: 对于每个残差块的最后一个adaLN层,将缩放参数γ初始化为0。这使得该层在初始阶段相当于一个恒等映射,不会对输入数据进行缩放。
    • 维度缩放参数α: 在残差连接之前,引入一个维度缩放参数α,用于对特征进行缩放。这个参数是可学习的,能够自适应地调整特征的尺度。
  4. 与传统adaLN的区别

    • 初始化方式不同: adaLN-Zero对缩放参数γ进行了特殊的初始化,而传统的adaLN通常使用随机初始化。
    • 参数数量增加: adaLN-Zero引入了额外的维度缩放参数α,增加了模型的参数数量。
  5. 为什么有效?

    • 恒等初始化使得模型在训练初期能够快速学习到残差部分,从而加速训练过程。
    • 维度缩放参数提供了更大的灵活性,使得模型能够更好地适应不同尺度的特征。

最后也附上原文,便于对照理解。
在这里插入图片描述


  1. Peebles, William, and Saining Xie. “Scalable diffusion models with transformers.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023. ↩︎

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值