Below, I walk through each part of the code in detail and express the mathematical operation of each step as a LaTeX formula.
The GLMBlock class
This class inherits from torch.nn.Module and represents a single Transformer layer. The layer takes an input of size [s, b, h] and returns an output of the same size.
class GLMBlock(torch.nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """
The __init__ method
The initialization method builds the layer's components: the input layer normalization, the self-attention module, the post-attention layer normalization, and the MLP.
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
    super(GLMBlock, self).__init__()
    self.layer_number = layer_number
    self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
    self.fp32_residual_connection = config.fp32_residual_connection

    # Choose RMSNorm or classic LayerNorm depending on the configuration.
    LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm

    # Layernorm applied to the block input.
    self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                         dtype=config.torch_dtype)

    # Self-attention module.
    self.self_attention = SelfAttention(config, layer_number, device=device)
    self.hidden_dropout = config.hidden_dropout

    # Layernorm applied to the attention output.
    self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                  dtype=config.torch_dtype)

    # Feed-forward MLP.
    self.mlp = MLP(config, device=device)
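Note the LayerNormFunc line: config.rmsnorm selects between RMSNorm and classic LayerNorm. For reference, the standard RMSNorm computation can be sketched as below; this is only a minimal illustration of the idea, not the RMSNorm class from the ChatGLM model file, whose details (dtype handling, epsilon default) may differ.

import torch
from torch import nn

class SimpleRMSNorm(nn.Module):
    """Minimal RMSNorm sketch: rescale the last dimension by its inverse
    root-mean-square, then apply a learned per-channel weight (no bias and
    no mean subtraction, unlike LayerNorm)."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        variance = x.float().pow(2).mean(-1, keepdim=True)
        x_normed = x.float() * torch.rsqrt(variance + self.eps)
        return (self.weight * x_normed).to(x.dtype)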
The forward method
The forward method defines how data flows through these components.
def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):
    # hidden_states: [s, b, h]
1. Input layer normalization
Apply layer normalization to the input.
\text{layernorm\_output} = \text{LayerNorm}(\text{hidden\_states})
layernorm_output = self.input_layernorm(hidden_states)
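For reference, for a hidden vector x the classic LayerNorm (used when config.rmsnorm is False) computes the following, with \mu and \sigma^2 the mean and variance over the hidden dimension and \varepsilon equal to layernorm_epsilon; RMSNorm drops the mean subtraction and the bias \beta:

\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^{2} + \varepsilon}} + \beta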
2. Self-attention
Pass the normalized output to the self-attention layer and obtain the attention output together with the updated key/value cache.
\text{attention\_output}, \text{kv\_cache} = \text{SelfAttention}(\text{layernorm\_output}, \text{attention\_mask}, \text{rotary\_pos\_emb}, \text{kv\_cache}=\text{kv\_cache}, \text{use\_cache}=\text{use\_cache})
attention_output, kv_cache = self.self_attention(
    layernorm_output,
    attention_mask,
    rotary_pos_emb,
    kv_cache=kv_cache,
    use_cache=use_cache
)
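The internals of SelfAttention are defined elsewhere in the model file and are stated here only as background: at its core it computes standard scaled dot-product attention, with rotary_pos_emb applied to the queries and keys and attention_mask restricting which positions may attend to which, in the familiar form:

\text{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + \text{attention\_mask}\right) V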
3. Residual connection
The configuration determines which tensor serves as the residual branch: when apply_residual_connection_post_layernorm is set, the normalized output is used; otherwise the raw input hidden_states is used (the usual pre-norm residual).
\text{residual} = \begin{cases} \text{layernorm\_output} & \text{if apply\_residual\_connection\_post\_layernorm} \\ \text{hidden\_states} & \text{otherwise} \end{cases}
if self.apply_residual_connection_post_layernorm:
    residual = layernorm_output
else:
    residual = hidden_states
4. Dropout and the second layer normalization
Apply dropout to the attention output, add the residual, and normalize the result.
\text{layernorm\_input} = \text{Dropout}(\text{attention\_output}, p=\text{self.hidden\_dropout})
\text{layernorm\_input} = \text{residual} + \text{layernorm\_input}
\text{layernorm\_output} = \text{LayerNorm}(\text{layernorm\_input})
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
layernorm_input = residual + layernorm_input
layernorm_output = self.post_attention_layernorm(layernorm_input)
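A side note on torch.nn.functional.dropout as used above (a standalone check, not part of the model code): with training=False the call is an identity, so hidden_dropout only takes effect during training, where the surviving elements are rescaled by 1/(1-p).

import torch
import torch.nn.functional as F

x = torch.randn(4, 2, 8)
# At inference time (training=False) dropout returns the input unchanged.
assert torch.equal(F.dropout(x, p=0.1, training=False), x)
# During training, elements are zeroed with probability p and the rest
# are scaled by 1 / (1 - p) so the expected value is preserved.
y = F.dropout(x, p=0.5, training=True)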
5. MLP layer
Pass the normalized output through the feed-forward MLP.
\text{mlp\_output} = \text{MLP}(\text{layernorm\_output})
mlp_output = self.mlp(layernorm_output)
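The MLP is also defined elsewhere in the model file. As background (an assumption about that implementation, not shown in this excerpt), ChatGLM2/3 uses a SwiGLU-style gated feed-forward network, which can be sketched as:

[x_1, x_2] = \text{dense\_h\_to\_4h}(x), \qquad \text{MLP}(x) = \text{dense\_4h\_to\_h}\big(\operatorname{SiLU}(x_1) \odot x_2\big)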
6. Second residual connection and output dropout
Select the residual branch again, apply dropout to the MLP output, and add the two.
\text{residual} = \begin{cases} \text{layernorm\_output} & \text{if apply\_residual\_connection\_post\_layernorm} \\ \text{layernorm\_input} & \text{otherwise} \end{cases}
\text{output} = \text{Dropout}(\text{mlp\_output}, p=\text{self.hidden\_dropout})
\text{output} = \text{residual} + \text{output}
if self.apply_residual_connection_post_layernorm:
    residual = layernorm_output
else:
    residual = layernorm_input
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
output = residual + output
Return the output and the cache
Finally, the block returns the output tensor and the updated kv_cache.
return output, kv_cache
Summary
In this way, the GLMBlock class implements a complete Transformer layer consisting of layer normalization, self-attention, residual connections, dropout, and an MLP. The steps are summarized by the following formulas:
- Input layer normalization:
  \text{layernorm\_output} = \text{LayerNorm}(\text{hidden\_states})
- Self-attention:
  \text{attention\_output}, \text{kv\_cache} = \text{SelfAttention}(\text{layernorm\_output}, \text{attention\_mask}, \text{rotary\_pos\_emb}, \text{kv\_cache}=\text{kv\_cache}, \text{use\_cache}=\text{use\_cache})
- Residual connection:
  \text{residual} = \begin{cases} \text{layernorm\_output} & \text{if apply\_residual\_connection\_post\_layernorm} \\ \text{hidden\_states} & \text{otherwise} \end{cases}
- Dropout and the second layer normalization:
  \text{layernorm\_input} = \text{Dropout}(\text{attention\_output}, p=\text{self.hidden\_dropout})
  \text{layernorm\_input} = \text{residual} + \text{layernorm\_input}
  \text{layernorm\_output} = \text{LayerNorm}(\text{layernorm\_input})
- MLP layer:
  \text{mlp\_output} = \text{MLP}(\text{layernorm\_output})
- Second residual connection and output dropout:
  \text{residual} = \begin{cases} \text{layernorm\_output} & \text{if apply\_residual\_connection\_post\_layernorm} \\ \text{layernorm\_input} & \text{otherwise} \end{cases}
  \text{output} = \text{Dropout}(\text{mlp\_output}, p=\text{self.hidden\_dropout})
  \text{output} = \text{residual} + \text{output}
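To tie the steps together, here is a self-contained, simplified sketch that mirrors GLMBlock's control flow (pre-norm, attention, dropout + residual, post-norm, MLP, dropout + residual). It is only an illustration: it uses torch.nn.MultiheadAttention and a plain GELU MLP as stand-ins for SelfAttention and MLP, omits rotary embeddings and kv_cache, and is not the ChatGLM implementation.

import torch
import torch.nn.functional as F
from torch import nn


class TinyBlock(nn.Module):
    """Simplified sketch of GLMBlock's control flow, for illustration only."""

    def __init__(self, hidden_size=64, num_heads=4, hidden_dropout=0.0,
                 apply_residual_connection_post_layernorm=False):
        super().__init__()
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.input_layernorm = nn.LayerNorm(hidden_size)
        # Stand-in for SelfAttention (no rotary embeddings, no kv_cache).
        self.self_attention = nn.MultiheadAttention(hidden_size, num_heads)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        # Stand-in for the ChatGLM MLP.
        self.mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size),
                                 nn.GELU(),
                                 nn.Linear(4 * hidden_size, hidden_size))

    def forward(self, hidden_states):  # hidden_states: [s, b, h]
        # 1. Input layer normalization.
        layernorm_output = self.input_layernorm(hidden_states)
        # 2. Self-attention (nn.MultiheadAttention expects [s, b, h] by default).
        attention_output, _ = self.self_attention(layernorm_output, layernorm_output, layernorm_output)
        # 3. First residual branch.
        residual = layernorm_output if self.apply_residual_connection_post_layernorm else hidden_states
        # 4. Dropout, residual addition, second layer normalization.
        layernorm_input = F.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        # 5. MLP.
        mlp_output = self.mlp(layernorm_output)
        # 6. Second residual branch and output dropout.
        residual = layernorm_output if self.apply_residual_connection_post_layernorm else layernorm_input
        output = F.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        return residual + output


x = torch.randn(8, 2, 64)               # [s, b, h]
assert TinyBlock()(x).shape == x.shape  # output has the same size as the input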