【Diffusion Models (8)】Stable Diffusion 3 diffusers Source Code Explained 2 - DiT and MMDiT Code (Part 2)




MMDiT

[Figure: MMDiT architecture, panels (a) and (b)]

Four-layer code structure

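  • In the figure above, (a) is the first layer and (b) covers the second and third layers; the Attention inside (b) is implemented in a separate code file (the fourth layer).
  • Text and image are fused in the fourth layer (JointAttnProcessor2_0).
  • The full structure of the fourth layer is shown below, with the focus on the concrete implementation of Joint Attention.
    [Figure: full structure of the fourth layer (Joint Attention)]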
Layer 1

The code for figure (a) is in /path/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py. Execution enters the whole transformer (MM-DiT blocks 1 to d) at noise_pred = self.transformer.

The following snippet is taken from the __call__ function in pipeline_stable_diffusion_3.py:

noise_pred = self.transformer(
    hidden_states=latent_model_input,
    timestep=timestep,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    joint_attention_kwargs=self.joint_attention_kwargs,
    return_dict=False,
)[0]
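To get a feel for the tensors entering self.transformer, the back-of-the-envelope sketch below counts the tokens that the MM-DiT blocks end up processing. It assumes SD3 defaults as I understand them (VAE downsampling factor 8, patch size 2 inside the transformer, and prompt_embeds made of 77 CLIP tokens plus 256 T5 tokens concatenated along the sequence); the helper name sd3_token_counts is hypothetical and not part of diffusers.

```python
def sd3_token_counts(height, width, vae_scale_factor=8, patch_size=2,
                     clip_tokens=77, t5_tokens=256):
    """Hypothetical helper: token counts per sample, under assumed SD3 defaults."""
    # Pixel space -> latent space (VAE downsampling)
    latent_h, latent_w = height // vae_scale_factor, width // vae_scale_factor
    # Latent -> patch tokens (each patch_size x patch_size latent patch becomes one token)
    image_tokens = (latent_h // patch_size) * (latent_w // patch_size)
    # Assumed: CLIP and T5 embeddings concatenated along the sequence dimension
    text_tokens = clip_tokens + t5_tokens
    return image_tokens, text_tokens


img_tokens, txt_tokens = sd3_token_counts(1024, 1024)
print(img_tokens, txt_tokens, img_tokens + txt_tokens)  # 4096 333 4429
```

Under these assumptions a 1024×1024 image gives 4096 image tokens and 333 text tokens, so the joint attention in layer 4 runs over 4429 tokens per sample. With classifier-free guidance the pipeline additionally doubles the batch.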
Layers 2 and 3

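The code for (b) lives inside the transformer (/path/lib/python3.12/site-packages/diffusers/models/transformers/transformer_sd3.py): a for loop enters each MM-DiT block (JointTransformerBlock) in turn.

Layer 2: the following snippet from the forward function in transformer_sd3.py runs the for loop. The first branch is taken only when the model is in training mode with gradient checkpointing enabled; otherwise (for example at inference time, when the backbone is not being trained) execution takes the else branch and calls the block directly.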

for index_block, block in enumerate(self.transformer_blocks):
    if self.training and self.gradient_checkpointing:

        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                else:
                    return module(*inputs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block),
            hidden_states,
            encoder_hidden_states,
            temb,
            **ckpt_kwargs,
        )

    else:
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
        )
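The two branches above should produce the same outputs; gradient checkpointing only trades memory for recomputation during the backward pass. A minimal sketch with a hypothetical ToyJointBlock (not the real JointTransformerBlock, only the same call signature and the same (encoder_hidden_states, hidden_states) return order) makes this concrete. It passes the module straight to torch.utils.checkpoint.checkpoint instead of going through create_custom_forward, which in the snippet above only wraps the call because return_dict is left as None.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyJointBlock(nn.Module):
    """Hypothetical stand-in: same inputs and return order as JointTransformerBlock."""

    def __init__(self, dim):
        super().__init__()
        self.img = nn.Linear(dim, dim)
        self.txt = nn.Linear(dim, dim)
        self.t = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, temb):
        shift = self.t(temb).unsqueeze(1)  # (B, 1, dim), broadcast over tokens
        hidden_states = hidden_states + self.img(hidden_states) + shift
        encoder_hidden_states = encoder_hidden_states + self.txt(encoder_hidden_states) + shift
        # Text stream first, image stream second, matching the loop above
        return encoder_hidden_states, hidden_states


def run_blocks(blocks, hidden_states, encoder_hidden_states, temb, use_checkpoint):
    for block in blocks:
        if use_checkpoint:
            # Activations inside the block are recomputed during backward instead of stored
            encoder_hidden_states, hidden_states = checkpoint(
                block, hidden_states, encoder_hidden_states, temb, use_reentrant=False
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
            )
    return encoder_hidden_states, hidden_states


torch.manual_seed(0)
blocks = nn.ModuleList([ToyJointBlock(8) for _ in range(3)])
x = torch.randn(2, 16, 8, requires_grad=True)  # image tokens
c = torch.randn(2, 5, 8)                       # text tokens
t = torch.randn(2, 8)                          # conditioning embedding
c_ckpt, x_ckpt = run_blocks(blocks, x, c, t, use_checkpoint=True)
c_plain, x_plain = run_blocks(blocks, x, c, t, use_checkpoint=False)
```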

Layer 3: the block is implemented by the JointTransformerBlock class in /path/lib/python3.12/site-packages/diffusers/models/attention.py. There, hidden_states (the noisy latent) and encoder_hidden_states (the text prompt embeddings) pass through norm1 and norm1_context respectively, and then enter the fourth layer, self.attn:

norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.context_pre_only:
    norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
else:
    norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
        encoder_hidden_states, emb=temb
    )

# Attention.
attn_output, context_attn_output = self.attn(
    hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)
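norm1 returns five tensors because it is an adaLN-Zero style layer: one linear layer maps temb to six modulation vectors, two of which (shift and scale for attention) are applied right away, while the attention gate and the MLP shift/scale/gate are handed back for later use in the block. In the last block (context_pre_only is True), norm1_context returns only the normalized tensor, which is why the code above branches. The sketch below is written from my understanding of diffusers' AdaLayerNormZero rather than copied from it; the class name AdaLNZeroSketch is hypothetical.

```python
import torch
import torch.nn as nn


class AdaLNZeroSketch(nn.Module):
    """Hypothetical adaLN-Zero layer returning the same five outputs as norm1 above."""

    def __init__(self, dim):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, 6 * dim)  # one projection -> six modulation vectors
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        # x: (B, N, dim) tokens; emb: (B, dim) conditioning such as temb
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
        # Modulate the normalized tokens; [:, None] broadcasts over the token axis
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


norm1 = AdaLNZeroSketch(8)
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm1(torch.randn(2, 5, 8), torch.randn(2, 8))
```

The "Zero" part refers to initializing the modulation projection to zero, so that each block starts out close to an identity mapping; with zero weights the layer reduces to a plain LayerNorm and all gates are zero.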
Layer 4

From the __init__ of self.attn, we can see that the actual implementation is the JointAttnProcessor2_0 class in /path/lib/python3.12/site-packages/diffusers/models/attention_processor.py.

Below is the initialization of self.attn in __init__:

if hasattr(F, "scaled_dot_product_attention"):
    processor = JointAttnProcessor2_0()
else:
    raise ValueError(
        "The current PyTorch version does not support the `scaled_dot_product_attention` function."
    )
self.attn = Attention(
    query_dim=dim,
    cross_attention_dim=None,
    added_kv_proj_dim=dim,
    dim_head=attention_head_dim // num_attention_heads,
    heads=num_attention_heads,
    out_dim=attention_head_dim,
    context_pre_only=context_pre_only,
    bias=True,
    processor=processor,
)
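Setting added_kv_proj_dim=dim is what gives the text stream its own set of weights. The sketch below collects, in a hypothetical module, the projection layers that JointAttnProcessor2_0 accesses on attn (to_q/to_k/to_v, add_q_proj/add_k_proj/add_v_proj, to_out, to_add_out). It only mirrors those attribute names and assumes square dim × dim projections with bias; it is not diffusers' Attention class.

```python
import torch.nn as nn


class JointAttentionWeightsSketch(nn.Module):
    """Hypothetical container mirroring the attribute names used by JointAttnProcessor2_0."""

    def __init__(self, dim, heads, context_pre_only=False):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.context_pre_only = context_pre_only
        # Image ("sample") stream projections
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        # Text ("context") stream projections: a separate set of weights
        self.add_q_proj = nn.Linear(dim, dim)
        self.add_k_proj = nn.Linear(dim, dim)
        self.add_v_proj = nn.Linear(dim, dim)
        # Image output: linear projection followed by dropout (to_out[0], to_out[1])
        self.to_out = nn.ModuleList([nn.Linear(dim, dim), nn.Dropout(0.0)])
        # Text output projection, skipped by the processor when context_pre_only is True
        self.to_add_out = None if context_pre_only else nn.Linear(dim, dim)


attn = JointAttentionWeightsSketch(dim=64, heads=4)
print(sum(p.numel() for p in attn.parameters()))  # 33280
```

Each of the eight linear layers has 64 × 64 weights plus 64 biases, hence 8 × 4160 = 33280 parameters; with context_pre_only=True (the last block) the text output projection is absent.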

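The figure and the code below are the core of the text-image fusion. The original paper [1] explains this part of the architecture as being "equivalent to having two independent transformers, one for each modality (text/image), but joining the two modalities for the attention operation". The paper's original description is reproduced below for easier understanding.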
[Figure: excerpt of the paper's description of MMDiT]

[Figure: diagram of the joint attention computation]

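from typing import Optional

import torch
import torch.nn.functional as F

# Inside attention_processor.py, Attention is defined in the same file;
# when running this excerpt elsewhere, import it from diffusers.
from diffusers.models.attention_processor import Attention


class JointAttnProcessor2_0: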
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # Joint attention: concatenate image and text tokens along the sequence dimension (dim=1)
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states
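Stripped of the 4D-input handling and the module plumbing, the processor reduces to a handful of tensor operations. The function below is a minimal functional sketch of those steps, assuming plain weight matrices without biases stored in a dict with hypothetical keys; it illustrates the concatenate, attend, split pattern and is not a replacement for JointAttnProcessor2_0.

```python
import torch
import torch.nn.functional as F


def joint_attention_sketch(img, txt, proj, heads):
    """Concatenate -> attend -> split, following the steps of JointAttnProcessor2_0.

    img:   (B, N_img, dim) image ("sample") tokens
    txt:   (B, N_txt, dim) text ("context") tokens
    proj:  dict of (dim, dim) weight matrices with hypothetical keys
           "q", "k", "v" (image stream) and "add_q", "add_k", "add_v" (text stream)
    heads: number of attention heads; dim must be divisible by heads
    """
    batch_size, n_img, dim = img.shape
    head_dim = dim // heads

    # Each modality uses its own projection weights ...
    q = torch.cat([img @ proj["q"], txt @ proj["add_q"]], dim=1)
    k = torch.cat([img @ proj["k"], txt @ proj["add_k"]], dim=1)
    v = torch.cat([img @ proj["v"], txt @ proj["add_v"]], dim=1)

    # ... but attention runs over the joint sequence: (B, N, dim) -> (B, heads, N, head_dim)
    q, k, v = (t.view(batch_size, -1, heads, head_dim).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)

    # Image tokens were concatenated first, so the first n_img positions belong to the image
    return out[:, :n_img], out[:, n_img:]


torch.manual_seed(0)
dim, heads = 8, 2
proj = {name: torch.randn(dim, dim) / dim**0.5 for name in ["q", "k", "v", "add_q", "add_k", "add_v"]}
img_out, txt_out = joint_attention_sketch(torch.randn(1, 16, dim), torch.randn(1, 4, dim), proj, heads)
print(img_out.shape, txt_out.shape)  # torch.Size([1, 16, 8]) torch.Size([1, 4, 8])
```

Because image and text queries, keys and values share one softmax, every image token can attend to every text token and vice versa. Passing a block-diagonal attn_mask to F.scaled_dot_product_attention would cut this cross-modal interaction and leave two independent self-attentions, which is the "two independent transformers" view from the paper without the joining step.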

  1. Scaling Rectified Flow Transformers for High-Resolution Image Synthesis
