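Overview
Qwen2 project: QwenLM/Qwen2
Starting with version 4.37.0, transformers ships the Qwen2 code, so using Qwen2 requires transformers>=4.37.0. The Qwen2 code lives in the transformers/models/qwen2 directory.
Like Qwen, Qwen2 is still a decoder-only transformer model. It uses RMSNorm, the SwiGLU activation function, RoPE, multi-head attention, and so on.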
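Layer Normalization
Layer normalization uses Root Mean Square Layer Normalization (RMSNorm). The idea behind RMSNorm is simple: it normalizes the output of a network layer by its root mean square (RMS), as in the following formula:
$$\bar{a}_{i}=\frac{a_{i}}{\operatorname{RMS}(\mathbf{a})}\, g_{i}, \quad \text{where}\ \operatorname{RMS}(\mathbf{a})=\sqrt{\frac{1}{n} \sum_{i=1}^{n} a_{i}^{2}}$$
Here $\mathbf{a} \in \mathbb{R}^n$ is the output vector of the network layer, and $\mathbf{g} \in \mathbb{R}^n$ is the gain parameter that rescales the normalized output; it is initialized to 1.
Below is the code of Qwen2RMSNorm. It follows the formula above closely, with two practical additions: a small epsilon is added under the square root for numerical stability, and the computation is carried out in float32 before casting back to the input dtype: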
import torch
from torch import nn

class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        # Gain g, initialized to ones
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute in float32, then cast back to the input dtype
        hidden_states = hidden_states.to(torch.float32)
        # Mean of squares over the hidden dimension, i.e. RMS(a)^2
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
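To make the formula concrete, here is a minimal dependency-free sketch of the same computation on a plain Python list. The function name `rms_norm` and the sample vector are illustrative assumptions, not part of transformers; `eps` plays the role of `variance_epsilon` in the class above.

```python
import math

def rms_norm(a, g, eps=1e-6):
    """Normalize vector `a` by its root mean square, then scale by gain `g`."""
    mean_square = sum(x * x for x in a) / len(a)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)  # eps mirrors variance_epsilon
    return [x * inv_rms * gi for x, gi in zip(a, g)]

a = [3.0, 4.0]
g = [1.0, 1.0]  # gain initialized to ones, as in Qwen2RMSNorm
out = rms_norm(a, g)

# RMS(a) = sqrt((9 + 16) / 2) = sqrt(12.5); with unit gain the output has RMS close to 1.
out_rms = math.sqrt(sum(x * x for x in out) / len(out))
print(out, out_rms)
```

Unlike standard LayerNorm, nothing is subtracted from the input: only the scale is normalized, which is why the module has a weight but no bias.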
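Where LayerNorm Is Applied
Qwen2 applies LayerNorm (that is, Qwen2RMSNorm) in three places:
- In each decoder layer, hidden_states is normalized before it is fed into the self-attention sublayer: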
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        ...
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = residual + hidden_states
        ...
- In each decoder layer, the hidden_states produced by the self-attention sublayer is normalized before it is fed into the fully connected sublayer:
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        ...
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Self Attention
        ...
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        ...
- The hidden_states output by the last decoder layer is normalized once more:
class Qwen2Model(Qwen2PreTrainedModel):
    def __init__(self, config: Qwen2Config):
        ...
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        ...

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        ...
        for decoder_layer in self.layers:
            ...
        hidden_states = self.norm(hidden_states)
        ...
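The three placements above amount to a pre-norm residual layout: each sublayer sees a normalized input, the residual path carries the un-normalized one, and a final norm follows the last layer. The sketch below traces that call order with toy stand-ins. `traced`, `identity`, and `double` are hypothetical helpers for illustration only, and the real sublayers are of course not simple doublings.

```python
calls = []  # records the order in which the sublayers run

def traced(name, fn):
    """Wrap fn so that each call appends `name` to the call log."""
    def wrapper(x):
        calls.append(name)
        return fn(x)
    return wrapper

def decoder_layer(x, norm1, attn, norm2, mlp):
    """Pre-norm block: normalize, apply the sublayer, then add the residual."""
    x = [r + h for r, h in zip(x, attn(norm1(x)))]  # self-attention sublayer
    x = [r + h for r, h in zip(x, mlp(norm2(x)))]   # fully connected sublayer
    return x

def model_forward(x, layers, final_norm):
    for layer in layers:
        x = layer(x)
    return final_norm(x)  # third placement: after the last decoder layer

identity = lambda x: list(x)               # stand-in for a norm
double = lambda x: [2.0 * v for v in x]    # stand-in for attention / MLP

layer = lambda x: decoder_layer(
    x,
    traced("input_layernorm", identity),
    traced("self_attn", double),
    traced("post_attention_layernorm", identity),
    traced("mlp", double),
)
out = model_forward([1.0, 2.0], [layer, layer], traced("norm", identity))
print(calls)
print(out)
```

Because the residual is taken before normalization, the norm never rescales the skip connection itself; only the branch feeding the sublayer is normalized.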
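Position Encoding
Position encoding uses Rotary Position Embedding (RoPE). RoPE encodes absolute positions with rotation matrices and, at the same time, builds an explicit relative-position dependency into the self-attention formulation. The position encoding therefore happens inside the attention computation, rather than, as in earlier approaches, by adding a position embedding to the token embedding at the input.
The RoPE module is shown below. Its main job is to compute the cos and sin values that RoPE needs: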
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
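To see what the cached tables contain, the following sketch rebuilds them with plain Python lists. `rope_cos_sin` is a hypothetical helper name, and the tiny sizes (`seq_len=4`, `dim=8`) are chosen only to keep the output readable. It mirrors the steps of `_set_cos_sin_cache`: compute the inverse frequencies, multiply by each position, and duplicate the result along the last axis.

```python
import math

def rope_cos_sin(seq_len, dim, base=10000.0):
    """Return cos/sin tables of shape [seq_len][dim], mirroring _set_cos_sin_cache."""
    # inv_freq[j] = base^(-2j/dim) for j = 0 .. dim/2 - 1
    inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
    cos, sin = [], []
    for t in range(seq_len):
        freqs = [t * f for f in inv_freq]
        emb = freqs + freqs  # same role as torch.cat((freqs, freqs), dim=-1)
        cos.append([math.cos(v) for v in emb])
        sin.append([math.sin(v) for v in emb])
    return cos, sin

cos, sin = rope_cos_sin(seq_len=4, dim=8)
print(len(cos), len(cos[0]))  # one row per position, one column per head dimension
print(cos[1])                 # first half and second half of each row are identical
```

Position 0 is left unrotated (cos is 1 and sin is 0 everywhere), and the duplicated halves are what allow a single element-wise multiply to serve both members of each rotated pair.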
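With the cos and sin values in hand, the following functions rotate the query and key tensors, which encodes the absolute position and builds in the relative-position dependency: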
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
Note that this implementation differs slightly from the one in the paper:
- The paper takes the $d$ elements of the input vector $\mathbf{x}_m \in \mathbb{R}^d$ in order, groups every two adjacent elements into a two-dimensional vector, and applies the corresponding 2D rotation matrix to it:
$$\begin{pmatrix} {x_m^{(2i-1)}}' \\ {x_m^{(2i)}}' \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x_m^{(2i-1)} \\ x_m^{(2i)} \end{pmatrix}, \quad i \in [1,\dots,d/2]$$
- The implementation here splits the $d$ elements of $\mathbf{x}_m \in \mathbb{R}^d$ into a first half and a second half of equal length, pairs the elements at the same offset in the two halves into a two-dimensional vector, and applies the corresponding 2D rotation matrix to it:
$$\begin{pmatrix} {x_m^{(i)}}' \\ {x_m^{(i+d/2)}}' \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x_m^{(i)} \\ x_m^{(i+d/2)} \end{pmatrix}, \quad i \in [1,\dots,d/2]$$
The two approaches differ only in which elements of the input vector are paired (the half-split pairing is probably chosen for implementation convenience). The principle is the same: the layouts are related by a fixed permutation of the dimensions, so as long as queries and keys use the same layout, the attention scores carry the same relative-position dependency.
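Both claims can be checked numerically on small vectors. The sketch below is an illustration under stated assumptions rather than the library code: the function names are hypothetical, the vectors are arbitrary, and $\theta_i = \text{base}^{-2(i-1)/d}$ is taken from the `inv_freq` expression in `Qwen2RotaryEmbedding`. It verifies that the query-key score depends only on the position offset, and that the paper's interleaved layout and the half-split layout give the same score once the inputs are permuted to match.

```python
import math

def thetas(d, base=10000.0):
    return [base ** (-2.0 * i / d) for i in range(d // 2)]

def rope_interleaved(x, m):
    """Paper layout: rotate the pairs (x[0], x[1]), (x[2], x[3]), ..."""
    out = list(x)
    for i, th in enumerate(thetas(len(x))):
        c, s = math.cos(m * th), math.sin(m * th)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i], out[2 * i + 1] = a * c - b * s, a * s + b * c
    return out

def rope_half_split(x, m):
    """Half-split layout: rotate the pairs (x[i], x[i + d/2])."""
    half = len(x) // 2
    out = list(x)
    for i, th in enumerate(thetas(len(x))):
        c, s = math.cos(m * th), math.sin(m * th)
        a, b = x[i], x[i + half]
        out[i], out[i + half] = a * c - b * s, a * s + b * c
    return out

def to_half_split(x):
    """Permute an interleaved vector so that pair i sits at (i, i + d/2)."""
    return x[0::2] + x[1::2]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.7, 2.0, -0.5, 0.1, 1.5, -0.8]
k = [1.1, 0.4, -0.9, 0.6, 0.2, -1.3, 0.5, 0.9]

# Same relative offset (m - n = 3) at two different absolute positions.
score_a = dot(rope_half_split(q, 5), rope_half_split(k, 2))
score_b = dot(rope_half_split(q, 105), rope_half_split(k, 102))

# Interleaved layout on the original vectors vs. half-split layout on permuted ones.
score_paper = dot(rope_interleaved(q, 5), rope_interleaved(k, 2))
score_code = dot(rope_half_split(to_half_split(q), 5),
                 rope_half_split(to_half_split(k), 2))
print(score_a, score_b, score_paper, score_code)
```

In practice the permutation is absorbed by the projection weights, which is why a checkpoint trained with one layout has to be reordered before it can be used with the other.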
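Activation Function
The activation function is used only in the feed-forward network of each decoder block. Qwen2 uses SwiGLU, a variant of the gated linear unit (GLU). The code is as follows: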
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
Here config.hidden_act is set to "silu" in the configuration file, that is, the SiLU (Sigmoid Linear Unit) function:
$$\text{SiLU}(x) = x \cdot \sigma(x), \quad \text{where } \sigma(x) \text{ is the sigmoid function}$$
SiLU is also known as the Swish function. Replacing the sigmoid in the original GLU with Swish gives SwiGLU:
$$\text{SwiGLU}(x, W, V) = \text{Swish}(xW) \otimes (xV), \quad \otimes \text{ denotes the element-wise product}$$
The line self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) implements the forward pass of the FFN with the SwiGLU activation, where $W$, $V$, and $W_2$ correspond to gate_proj, up_proj, and down_proj respectively:
$$\text{FFN}_{\text{SwiGLU}}(x, W, V, W_2) = (\text{Swish}(xW) \otimes xV)\, W_2$$
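As a small worked instance of the FFN formula, the sketch below evaluates it with plain Python lists. The helper names and the toy weights are assumptions made for illustration (hidden size 2, intermediate size 3). Matrices are stored as lists of rows in the `[out_features, in_features]` layout that `nn.Linear` uses for its weight, so `matvec(W, x)` plays the role of a bias-free linear layer.

```python
import math

def silu(x):
    return x / (1.0 + math.exp(-x))  # x * sigmoid(x)

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def ffn_swiglu(x, W_gate, W_up, W_down):
    gate = [silu(v) for v in matvec(W_gate, x)]  # Swish(xW)
    up = matvec(W_up, x)                         # xV
    hidden = [g * u for g, u in zip(gate, up)]   # element-wise product
    return matvec(W_down, hidden)                # (...) W_2

x = [1.0, -2.0]
W_gate = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 x 2, role of gate_proj
W_up = [[2.0, 0.0], [0.0, 0.5], [1.0, -1.0]]    # 3 x 2, role of up_proj
W_down = [[1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]     # 2 x 3, role of down_proj

out = ffn_swiglu(x, W_gate, W_up, W_down)
print(out)
```

The gate branch decides, per intermediate unit, how much of the up projection passes through, which is the difference from a plain two-layer FFN with a single activation.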
Self-Attention
Self-Attention Implementations
Qwen2 provides three implementations of the self-attention mechanism:
QWEN2_ATTENTION_CLASSES = {
    "eager": Qwen2Attention,
    "flash_attention_2": Qwen2FlashAttention2,
    "sdpa": Qwen2SdpaAttention,
}
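- eager: a custom implementation written out explicitly in PyTorch
- flash_attention_2: based on FlashAttention2 (supports sliding window attention)
- sdpa: based on PyTorch's torch.nn.functional.scaled_dot_product_attention
Self-Attention Types
Qwen2 supports three types of self-attention:
- MHA (Multi-Head Attention): standard multi-head attention
- MQA (Multi-Query Attention): all attention heads share a single Key and Value
- GQA (Grouped-Query Attention): the attention heads are divided into groups, and the Queries in each group share one Key and Value
Part of the initialization code of the self-attention layer is shown below: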
class Qwen2Attention(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        ...
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        ...
The value of config.num_key_value_heads determines which type of self-attention is used:
- config.num_key_value_heads = config.num_attention_heads: MHA
- config.num_key_value_heads = 1: MQA
- 1 < config.num_key_value_heads < config.num_attention_heads: GQA
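The rule above can be written as a small helper, together with the mapping from query heads to the key/value head they share. Both function names are hypothetical, and the head counts (28 query heads, 4 key/value heads) are only an illustrative configuration. The mapping assumes that each key/value head serves a block of consecutive query heads, which matches how the `repeat_kv` helper in transformers expands the key/value tensors.

```python
def attention_type(num_attention_heads, num_key_value_heads):
    """Classify the attention variant implied by the two head counts."""
    if num_attention_heads % num_key_value_heads != 0:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    if num_key_value_heads == num_attention_heads:
        return "MHA"
    if num_key_value_heads == 1:
        return "MQA"
    return "GQA"

def kv_head_for_query_head(q_head, num_attention_heads, num_key_value_heads):
    """Index of the key/value head shared by query head `q_head`."""
    group_size = num_attention_heads // num_key_value_heads  # num_key_value_groups
    return q_head // group_size

num_heads, num_kv_heads = 28, 4  # illustrative values
kind = attention_type(num_heads, num_kv_heads)
mapping = [kv_head_for_query_head(h, num_heads, num_kv_heads) for h in range(num_heads)]
print(kind)
print(mapping)
```

With fewer key/value heads, the KV cache shrinks by the factor `num_key_value_groups`, which is the main practical motivation for MQA and GQA.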