腾讯HunyuanDit代码解析

注意:本文仅供自己记录学习过程使用。

训练

全参训练过程

  1. 输入图像用VAE编码得到输入的x_start(1,4,128,128);文本的两个特征:bert的encoder feature(1,77,1024)和T5 的feature(1,256,2048),和旋转位置编码freqs_cis_img: cos (4096,88),sin (4096,88)。
  2. 生成随机的时间步长t;生成随机的噪声(1,4,128,128),给输入的x_start加上噪声得到输出的x_t;
    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert_shape(noise, x_start)
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

  1. 对T5 的feature(1,256,2048)用mlp降为到(1,256,1024),然后把它和bert的feature cat起来得到text_states (1,33,1024);
  2. 对时间t编码(1,1408),x_t打成path,x(1,4096,1048);
  3. 对t5 feature进行pooling(multihead self-attention)得到extra_vec(1,1024);
  4. 时间t+mlp(extra_vec)=c(1,1408),得到condition;
  5. 上述步骤已得到以下参数:x ,c,text_states,freqs_cis_img。开始迭代处理。
x = block(x, c, text_states, freqs_cis_img)
  1. mlp(c)+x得到self-attention block的输入,把输入分成q/k/v,然后把q/k用旋转位置编码进行编码,得到新的qk。然后mlp提特征,输出x(1,4096,1408);简单来说,就是输入的x和文本的全局特征做了一次注意力提取特征的操作;
    def forward(self, x, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // 2), RoPE for image
        """
        b, s, d = x.shape

        qkv = self.Wqkv(x)
        qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim)  # [b, s, 3, h, d]
        q, k, v = qkv.unbind(dim=2) # [b, s, h, d]
        q = self.q_norm(q).half()   # [b, s, h, d]
        k = self.k_norm(k).half()

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis_img)
            assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
            q, k = qq, kk

        qkv = torch.stack([q, k, v], dim=2)     # [b, s, 3, h, d]
        context = self.inner_attn(qkv)
        out = self.out_proj(context.view(b, s, d))
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple
def apply_rotary_emb(
        xq: torch.Tensor,
        xk: Optional[torch.Tensor],
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
        head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings.   [B, S, H, D]
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.

    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)    # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)   # [S, D//2] --> [1, S, 1, D//2]
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        if xk is not None:
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out
  1. 上面得到的x,以及文本特征text_states,和旋转位置编码freqs_cis_img,作为cross attention block的输入;y是文本特征text_states;x作为q, text_states作为kv,q加上位置编码后,和kv作cross attention,得到输出x(1,4096,1408);
    def forward(self, x, y, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // num_heads), RoPE for image
        """
        b, s1, _ = x.shape     # [b, s1, D]
        _, s2, _ = y.shape     # [b, s2, 1024]

        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)       # [b, s1, h, d]
        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)  # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2)                 # [b, s2, h, d]
        q = self.q_norm(q).half()               # [b, s1, h, d]
        k = self.k_norm(k).half()               # [b, s2, h, d]

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
            q = qq                              # [b, s1, h, d]
        kv = torch.stack([k, v], dim=2)         # [b, s1, 2, h, d]
        context = self.inner_attn(q, kv)        # [b, s1, h, d]
        context = context.view(b, s1, -1)       # [b, s1, D]

        out = self.out_proj(context)
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple
  1. 最后mlp输出x(1,4096,1408)。共有19个hunyuan block,每个block输出的都是(1,4096,1408);(类似于unet encoder的操作,后续就是解码了,但是它这里“编解码”并没有分辨率的概念)。
  2. 开始“解码操作”了,其实就是前面最后输出x和前面block的输出cat起来,然后提取特征,后续步骤和前面是一样的。
    def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
        # Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat([x, skip], dim=-1)
            cat = self.skip_norm(cat)
            x = self.skip_linear(cat)

        # Self-Attention
        shift_msa = self.default_modulation(c).unsqueeze(dim=1)
        attn_inputs = (
            self.norm1(x) + shift_msa, freq_cis_img,
        )
        x = x + self.attn1(*attn_inputs)[0]

        # Cross-Attention
        cross_inputs = (
            self.norm3(x), text_states, freq_cis_img
        )
        x = x + self.attn2(*cross_inputs)[0]

        # FFN Layer
        mlp_inputs = self.norm2(x)
        x = x + self.mlp(mlp_inputs)

        return x
  1. 最后整个网络输出(1,8,128,128);
  2. 网络的输出前4个通道(1,4,128,128)和输入的纯净的x_start作mse loss,后四个通道作什么变分概率误差;至此训练完成;

lora训练过程

训练过程和全参一样,低秩矩阵调用库peft训练的,略;在这里插入图片描述

controlnet训练过程

  1. 架构和hunyuandit一致;
  2. 1-6步和全参训练一样,第六步后有个VAE编码后的control img(1,4,128,128)作为condition, 把它和x_t相加+,得到网络的输入x;其他c,text_states,freqs_cis_img和之前一样;
        condition = self.x_embedder(condition)

        # ========================= Forward pass through HunYuanDiT blocks =========================
        controls = []
        x = x + self.before_proj(condition) # add condition
        for layer, block in enumerate(self.blocks):
            x = block(x, c, text_states, freqs_cis_img)
            controls.append(self.after_proj_list[layer](x)) # zero linear for output
  1. 输出19个block的control feature;与冻结后的hunyuandit的“解码层”特征相加即可;在这里插入图片描述
        for layer, block in enumerate(self.blocks):
            if layer > self.depth // 2:
                if controls is not None:
                    skip = skips.pop() + controls.pop()
                else:
                    skip = skips.pop()
                x = block(x, c, text_states, freqs_cis_img, skip)   # (N, L, D)
            else:
                x = block(x, c, text_states, freqs_cis_img)         # (N, L, D)

            if layer < (self.depth // 2 - 1):
                skips.append(x)
  1. 损失和之前一致,训练完毕;

推理

  1. 全参推理,基本根训练一样
    a. 准备正向prompt 和负向prompt,cat起来得到prompt_embeds (2,77,1024)。t5还有个文本的embeding:prompt_embeds_t5 (2,256,2048)。
    b. 随机生成噪声(1,4,128,128);
    c. 放到unet中的得到预测的噪声(1,4,128,128);然后减去反向提示词noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    d. 根据公式得到新的latent,通过不断的去噪,得到最终的结果,放到Vae解码器中得到输出图片。
  2. lora推理,同上;lora融合权重公式,模型原始权重 = 模型原始权重+系数 * lora权重。
def load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale):
    for i in range(num_layers):
        Wqkv = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_A.weight"])  # lora权重
        q, k, v = torch.chunk(Wqkv, 3, dim=0)
        transformer_state_dict[f"blocks.{i}.attn1.to_q.weight"] += lora_scale * q # 原始权重+lora权重
        transformer_state_dict[f"blocks.{i}.attn1.to_k.weight"] += lora_scale * k
        transformer_state_dict[f"blocks.{i}.attn1.to_v.weight"] += lora_scale * v

        out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_A.weight"]) 
        transformer_state_dict[f"blocks.{i}.attn1.to_out.0.weight"] += lora_scale * out_proj

        q_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_A.weight"])
        transformer_state_dict[f"blocks.{i}.attn2.to_q.weight"] += lora_scale * q_proj

        kv_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_A.weight"])
        k, v = torch.chunk(kv_proj, 2, dim=0)
        transformer_state_dict[f"blocks.{i}.attn2.to_k.weight"] += lora_scale * k
        transformer_state_dict[f"blocks.{i}.attn2.to_v.weight"] += lora_scale * v

        out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_A.weight"]) 
        transformer_state_dict[f"blocks.{i}.attn2.to_out.0.weight"] += lora_scale * out_proj
    
    q_proj = torch.matmul(lora_state_dict["pooler.q_proj.lora_B.weight"], lora_state_dict["pooler.q_proj.lora_A.weight"])
    transformer_state_dict["time_extra_emb.pooler.q_proj.weight"] += lora_scale * q_proj
    
    return transformer_state_dict
  1. controlnet推理,和训练基本一样,不一样的在于controlnet的特征可以乘上权重,再加上原始unet的特征。
                    controls = self.controlnet(
                        latent_model_input,
                        t_expand,
                        condition,
                        encoder_hidden_states=prompt_embeds,
                        text_embedding_mask=attention_mask,
                        encoder_hidden_states_t5=prompt_embeds_t5,
                        text_embedding_mask_t5=attention_mask_t5,
                        image_meta_size=ims,
                        style=style,
                        cos_cis_img=freqs_cis_img[0],
                        sin_cis_img=freqs_cis_img[1],
                        return_dict=False,
                    )
                    if isinstance(control_weight, list):
                        assert len(control_weight) == len(controls)
                        controls = [control * weight for control, weight in zip(controls, control_weight)] # 每一层特征乘以权重
                    else:
                        controls = [control * control_weight for control in controls]
                    noise_pred = self.unet(
                        latent_model_input,
                        t_expand,
                        encoder_hidden_states=prompt_embeds,
                        text_embedding_mask=attention_mask,
                        encoder_hidden_states_t5=prompt_embeds_t5,
                        text_embedding_mask_t5=attention_mask_t5,
                        image_meta_size=ims,
                        style=style,
                        cos_cis_img=freqs_cis_img[0],
                        sin_cis_img=freqs_cis_img[1],
                        return_dict=False,
                        controls=controls
                    )
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值