GLM-4 (3) - GLMBlock

Series Index

GLM-4 (1) - Inference + Overview
GLM-4 (2) - RoPE
GLM-4 (3) - GLMBlock
GLM-4 (4) - SelfAttention
GLM-4 (5) - API & Function Calling
GLM-4 (6) - KV Cache / Prefill & Decode



Preface

The previous two articles covered GLM-4 inference + overview and rotary position embedding, respectively; this one focuses on the model architecture and its components. We know that today's large language models are all Transformer-based, and that a Transformer is in turn made up of a number of Transformer block layers. In the GLM-4 code, GLMTransformer corresponds to the Transformer part, and GLMBlock corresponds to a Transformer block. These two parts are what we will look at here.

1. Model Architecture Overview

A quick note on how the components relate to one another: ChatGLMForConditionalGeneration is the complete model used for chat; its key component is ChatGLMModel, which you can think of as a full transformer. The core component of ChatGLMModel is GLMTransformer, and GLMTransformer is built by stacking multiple GLMBlock layers. Seen this way, the overall model architecture becomes quite clear.
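
As a quick check of this nesting, here is a minimal sketch; it assumes the glm-4-9b-chat checkpoint is available locally at the path that appears in the debug dump below, and loads it via transformers with trust_remote_code (just one way to inspect the modules):

from transformers import AutoModelForCausalLM

# Local checkpoint path taken from the debug dump below; adjust to your environment.
model = AutoModelForCausalLM.from_pretrained(
    "/home/ubuntu/Projects_ubuntu/glm-4-9b-chat",
    trust_remote_code=True,
    torch_dtype="auto",
)
print(type(model).__name__)                                # ChatGLMForConditionalGeneration
print(type(model.transformer).__name__)                    # ChatGLMModel
print(type(model.transformer.encoder).__name__)            # GLMTransformer
print(type(model.transformer.encoder.layers[0]).__name__)  # GLMBlock
print(len(model.transformer.encoder.layers))               # 40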

2. ChatGLMModel & GLMTransformer

Inspecting ChatGLMModel in the debugger gives the information below, which includes some configuration values as well as the "embedding" and "encoder" components; the "encoder" module structure is also printed out, and it is simply a stack of 40 GLMBlock layers. The next section draws the GLMBlock architecture diagram and walks through this structure alongside the code.

{
    "base_model": ChatGLMModel,
    "base_model_prefix": "transformer",
    "config": ChatGLMConfig,
    "dtype": torch.bfloat16,
    "dummy_inputs": {'input_ids': tensor([[7, 6, 0, 0, 1],
        [1, 2, 3, 0, 0],
        [0, 0, 0, 4, 5]])}     # why these values? (they appear to be the default DUMMY_INPUTS inherited from transformers' PreTrainedModel)
    "embedding": Embedding((word_embeddings): Embedding(151552, 4096))
    "encoder": GLMTransformer(
          (layers): ModuleList(
            (0-39): 40 x GLMBlock(
              (input_layernorm): RMSNorm()
              (self_attention): SelfAttention(
                (query_key_value): Linear(in_features=4096, out_features=4608, bias=True)
                (core_attention): SdpaAttention(
                  (attention_dropout): Dropout(p=0.0, inplace=False)
                )
                (dense): Linear(in_features=4096, out_features=4096, bias=False)
              )
              (post_attention_layernorm): RMSNorm()
              (mlp): MLP(
                (dense_h_to_4h): Linear(in_features=4096, out_features=27392, bias=False)
                (dense_4h_to_h): Linear(in_features=13696, out_features=4096, bias=False)
              )
            )
          )
          (final_layernorm): RMSNorm()
        )
    "is_gradient_checkpointing": False,
    "is_parallelizable": False,
    "kv_channels": 128,
    "main_input_name": "input_ids",
    "multi_query_group_num": 2,
    "name_or_path": "/home/ubuntu/Projects_ubuntu/glm-4-9b-chat",
    "num_layer": 40,
    "output_layer": Linear(in_features=4096, out_features=151552, bias=False),
    "rotary_pos_emb": RotaryEmbedding(),
    "seq_length": 131072,               # from the config; presumably the preset context length => 128k
    "supports_gradient_checkpointing": True,
    "training": False,
}
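
One detail worth pausing on in the dump above: query_key_value has out_features=4608 rather than 3 * 4096, because K and V are shared across a small number of groups (multi_query_group_num=2, kv_channels=128). A quick arithmetic check based on those config values:

hidden_size = 4096
kv_channels = 128                                  # per-head dimension
num_attention_heads = hidden_size // kv_channels   # 32 query heads
multi_query_group_num = 2                          # number of shared K/V groups

qkv_out = (
    num_attention_heads * kv_channels              # Q: 32 * 128 = 4096
    + 2 * multi_query_group_num * kv_channels      # K and V: 2 * 2 * 128 = 512
)
print(qkv_out)                                     # 4608 == out_features of query_key_value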

3. GLMBlock

GLMBlock differs slightly from the standard Transformer block. With the glm-4-9b-chat configuration, apply_residual_connection_post_layernorm=False, which means the residual connection is taken from before the normalization (the solid black line in the figure); if it is set to True, it matches the Transformer block (the dashed black line in the figure).
[Figure: GLMBlock architecture diagram]
For comparison with the original Transformer, here is its architecture diagram as well:
[Figure: original Transformer block architecture]
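
To make the two residual variants concrete, here is a toy snippet (random stand-in modules, not the real GLM-4 layers) showing where the residual comes from in each case; it mirrors the forward code later in this section:

import torch
import torch.nn.functional as F

hidden_size = 8
x = torch.randn(4, 1, hidden_size)                           # [s, b, h]
input_layernorm = torch.nn.LayerNorm(hidden_size)            # stand-in for RMSNorm
self_attention = torch.nn.Linear(hidden_size, hidden_size)   # stand-in for SelfAttention

layernorm_output = input_layernorm(x)
attention_output = self_attention(layernorm_output)

apply_residual_connection_post_layernorm = False             # glm-4-9b-chat setting
if apply_residual_connection_post_layernorm:
    residual = layernorm_output   # dashed line: residual taken after the layernorm
else:
    residual = x                  # solid line: residual taken before the layernorm
out = residual + F.dropout(attention_output, p=0.0)
print(out.shape)                  # torch.Size([4, 1, 8])
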
Next, let's go over a few more details alongside the code:

  • The normalization can be either RMSNorm or LayerNorm;
  • Multi-query attention is used; I will cover that part in a separate article later;
  • The dimension changes of the two linear layers in the FFN are not strictly h -> 4h and 4h -> h (h being the hidden dimension); also, the activation function used here is SwiGLU;
  • Dropout is not reflected in the figures above.
class MLP(torch.nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

        def swiglu(x):
            # split the doubled projection into (gate, value) halves and apply SiLU gating
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output
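
A quick shape check of this MLP path with the glm-4-9b-chat dimensions (hidden_size=4096, ffn_hidden_size=13696), using plain nn.Linear stand-ins with random weights just to make the 4096 -> 27392 -> 13696 -> 4096 flow concrete:

import torch
import torch.nn.functional as F

hidden_size, ffn_hidden_size = 4096, 13696
dense_h_to_4h = torch.nn.Linear(hidden_size, ffn_hidden_size * 2, bias=False)  # 4096 -> 27392
dense_4h_to_h = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=False)      # 13696 -> 4096

x = torch.randn(8, 1, hidden_size)            # [s, b, h]
up = dense_h_to_4h(x)                         # [8, 1, 27392]
gate, value = torch.chunk(up, 2, dim=-1)      # two halves of 13696 each
h = F.silu(gate) * value                      # SwiGLU
out = dense_4h_to_h(h)                        # [8, 1, 4096]
print(up.shape, h.shape, out.shape)
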
class GLMBlock(torch.nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm

        self.fp32_residual_connection = config.fp32_residual_connection

        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                             dtype=config.torch_dtype)

        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                      dtype=config.torch_dtype)

        # MLP
        self.mlp = MLP(config, device=device)

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)   # compared with the original Transformer block, GLMBlock applies layernorm right at the start
        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache
        )                # (1, 8, 4096), (1, 2, 1, 2, 8, 128)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output

        return output, kv_cache     # (1, 8, 4096), (1, 2, 1, 2, 8, 128)
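
To tie this back to section 2: GLMTransformer itself is essentially a loop over these 40 blocks plus a final RMSNorm. Below is a simplified sketch of its forward logic (not the verbatim source; gradient checkpointing and hidden-state collection are omitted):

def glm_transformer_forward(layers, final_layernorm, hidden_states,
                            attention_mask, rotary_pos_emb, kv_caches=None, use_cache=True):
    # layers: the ModuleList of 40 GLMBlocks; final_layernorm: the trailing RMSNorm
    kv_caches = kv_caches or [None] * len(layers)     # one (k, v) cache per GLMBlock
    new_kv_caches = []
    for layer, kv_cache in zip(layers, kv_caches):
        hidden_states, kv_cache = layer(
            hidden_states, attention_mask, rotary_pos_emb,
            kv_cache=kv_cache, use_cache=use_cache,
        )
        new_kv_caches.append(kv_cache)
    hidden_states = final_layernorm(hidden_states)    # applied after the last block
    return hidden_states, new_kv_caches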

Summary

In this article we analyzed the core GLM-4 components GLMTransformer and GLMBlock and explained some of the details. The next article will cover the attention part.
