pytorch千问模型源码分析


# 规范化技术,旨在替代传统的 Layer Normalization(LN)
# 核心思想是对输入张量的每个样本的每个特征进行规范化,使其均值为 0,方差为 1
class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6): # 隐藏层的大小
        super().__init__()
        # 一个可学习的权重参数,初始化为全 1 张量。
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # 用于防止除零错误的小常数。
        self.variance_epsilon = eps
    def forward(self, hidden_states):
        # 记录输入张量的数据类型,以便最终转换回原始类型。
        input_dtype = hidden_states.dtype
        # 转换为 torch.float32 类型,以确保数值稳定性。
        hidden_states = hidden_states.to(torch.float32)
        # 计算每个样本的方差
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # 计算每个样本的 RMS 值,并对每个样本进行规范化
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # 应用可学习的权重,其中 γγ 是一个可学习的参数,用于缩放规范化后的张量。
        return self.weight * hidden_states.to(input_dtype)
# 用于生成旋转位置嵌入。这种嵌入方法在 Transformer 模型中用于捕捉序列中的位置信息,尤其适用于长序列任务。
# 通过旋转的方式将位置信息编码到嵌入向量中。具体步骤如下:
# 生成频率:通过指数函数生成一系列频率值。计算正弦和余弦:利用生成的频率计算正弦和余弦值
# ,旋转嵌入:将输入向量按一定规则旋转,以嵌入位置信息。
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        # 最大位置嵌入的长度,默认为 2048,base:基数,默认为 10000。。
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq:计算频率的逆值。
        # 位置列表先归一化(从绝对位置变成相对位置),之后取指数(1--接近10000),之后取倒数,位置从1--越来越小
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # register_buffer:将 inv_freq 注册为缓冲区,以便在模型保存和加载时保持不变。
        # register_buffer 方法用于注册一个非训练的缓冲区(buffer),这意味着它不会被梯度更新。当你使用 register_buffer 注册一个缓
        # 冲区时,它会被保存在模型的状态字典(state dict)中,并且在模型保存和加载时也会被序列化。
        # persistent=True:缓冲区会出现在模型的状态字典中,并且会被序列化和加载。
        # persistent=False:缓冲区不会出现在模型的状态字典中,但在实际保存和加载时,仍然会被序列化并加载。
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build here to make `torch.jit.trace` work.生成正弦和余弦缓存
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # t 是一个包含位置索引的张量,形状为 (seq_len,)。
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        # torch.outer:计算外积,得到一个形状为 (seq_len, dim/2) 的张量
        freqs = torch.outer(t, self.inv_freq) # 计算频率。
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        # 拼接频率。emb 的形状为 (seq_len, dim)。
        # 在旋转位置嵌入(RoPE)中,我们通常将嵌入向量分为两个部分,并分别应用正弦和余弦变换。具体来说:
        # 对于每个位置 tt,计算频率 ff,得到一个形状为 (seq_len, dim/2) 的张量。
        # 将频率张量拼接两次,得到一个形状为 (seq_len, dim) 的张量。
        # 这样做的原因是,我们将嵌入向量分为两部分,每部分对应一个频率值。
        emb = torch.cat((freqs, freqs), dim=-1)
        # cos_cached 和 sin_cached:注册正弦和余弦缓存。
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None): # x:输入张量。
        # x: [bs, num_attention_heads, seq_len, head_size]
        # 如果 seq_len 大于已缓存的最大长度,则重新生成缓存。
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return ( # 返回正弦和余弦缓存的切片。
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size # d
        self.intermediate_size = config.intermediate_size # hd
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) # d-->hd
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# d-->hd
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) # hd-->d
        self.act_fn = ACT2FN[config.hidden_act]
    def forward(self, hidden_state): # (h,s,d)
        # 门控信号生成:gate_proj(hidden_state) 生成门控信号
        # 特征调整:gate_output 与 up_output 相乘,将门控信号应用于特征表示。
        # 门控机制的作用:通过门控信号动态调整哪些特征应该通过哪些特征应该被抑制。
        # 激活函数的选择:如果 config.hidden_act 是 "sigmoid",那么激活函数将是 sigmoid
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

class Qwen2Attention(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        super().__init__() # 调用父类的初始化方法
        self.config = config # 配置类实例
        self.layer_idx = layer_idx # 层索引
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        
        self.hidden_size = config.hidden_size # d
        self.num_heads = config.num_attention_heads # q_h
        self.head_dim = self.hidden_size // self.num_heads # dk
        self.num_key_value_heads = config.num_key_value_heads # kv_h
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads # 比例
        self.max_position_embeddings = config.max_position_embeddings # p
        self.rope_theta = config.rope_theta # base
        self.is_causal = True # 是否用因果掩码
        self.attention_dropout = config.attention_dropout # dropout
        # 嵌入维度必须能被整除
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # 线性投影
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        #需要注意的是这里的投影维度可能和q的投影维度不同
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        # 最后一个线性转换层
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        # 旋转位置嵌入层
        self.rotary_emb = Qwen2RotaryEmbedding(
            self.head_dim, # dk
            max_position_embeddings=self.max_position_embeddings,# max_position
            base=self.rope_theta, # base
        )
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,# 可选
        position_ids: Optional[torch.LongTensor] = None,# 可选
        past_key_value: Optional[Cache] = None, # 可选参数:缓存
        output_attentions: bool = False,# 是否输出注意力权重
        use_cache: bool = False, # 是否使用缓存
        cache_position: Optional[torch.LongTensor] = None, # 缓存位置
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size() # b,s,d
        # 投影
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # (b,q_len,q_h,dk)-->(b,q_h,q_len,dk),transpose:换轴(转置)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        # (b,k_h,k_len,dk)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        kv_seq_len = key_states.shape[-2] # k_len
        # 缓存上个时间步的key,value表示
        if past_key_value is not None: # 如果设置了缓存
            if self.layer_idx is None: # 就必须有layer_idx,不然报错
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # 旋转位置嵌入,传kv_len
        # 键/值序列长度:kv_seq_len 是键和值向量的长度,这是因为键和值向量代表的是相同的序列。
        # 查询序列长度:q_len 是查询向量的长度,这可能不同于键/值向量的长度。
        # 旋转位置嵌入:在计算旋转位置嵌入时,使用键/值序列长度是为了确保位置信息与键和值向量一致。
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        # 返回带位置信息的嵌入表示
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # 如果past_key_value is not None
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            # 更新
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # repeat k/v heads if n_kv_heads < n_heads
        # 如果键值头数量少于查询头数量,则重复键值头以匹配查询头数量。
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # (b,q_h,q_len,dk)@(b,k_h,dk,k_len)-->(b,h,q_len,k_len)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        # 切片,在最后一个维度切出q_len的长度
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            # 相加,一般遮挡的地方是很大的负数
            attn_weights = attn_weights + causal_mask
        # upcast attention to fp32
        # 在q_len上归一化,得到query序列中每个token对应key中token的一系列权重,这些权重中较大的值表示和当前query中的token
        # 相似度较近,较小的表示离当前query中token较远
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        # dropout
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        # (b,h,q_len,k_len)@(b,h,v_len,dk)-->(b,h,q_len,dk)
        attn_output = torch.matmul(attn_weights, value_states)
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        # (b,h,q_len,dk)-->(b,h,q_len,h,dk),之后.contiguous()转为内存连续存储
        attn_output = attn_output.transpose(1, 2).contiguous()
        # (b,h,q_len,h,dk)-->(b,h,d)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        # 最后经过线性转换
        attn_output = self.o_proj(attn_output)
        # 不输出注意力权重
        if not output_attentions:
            attn_weights = None
        # 返回多头注意力的输出,注意力权重,上个时间步的key_value的缓存
        return attn_output, attn_weights, past_key_value
class Qwen2FlashAttention2(Qwen2Attention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs) # 调用父类初始化方法
        # 如果大于2_10,这个是False
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
    def forward(
        self, # 当前实例
        hidden_states: torch.Tensor,# 上一层的输入,或者第一次传人的嵌入
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        bsz, q_len, _ = hidden_states.size() # b,s,d
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # (b,h,q_len,dk)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        
        kv_seq_len = key_states.shape[-2] # k_len
        # 如果有k_v缓存,就必须有层索引
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # Because the input can be padded, the absolute sequence length depends on the max position id.
        
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
        # 获取cos,sin
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
        # 获取带位置信息的q,k状态
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # 使用滑动窗口的条件是这么几个,配置中有这个属性,kv长度大于窗口大小等
        use_sliding_windows = (
            _flash_supports_window_size
            and getattr(self.config, "sliding_window", None) is not None
            and kv_seq_len > self.config.sliding_window
            and self.config.use_sliding_window
        )
        # 不支持就报警告信息
        if not _flash_supports_window_size:
            logger.warning_once(
                "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
                " make sure to upgrade flash-attn library."
            )
        # 如果设置了past_key_value
        if past_key_value is not None:
            # Activate slicing cache only if the config has a value `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads,设置q,k,v的头相同
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # 如果是训练模式,就用dropout,评估模式不用dropout
        dropout_rate = 0.0 if not self.training else self.attention_dropout
        # 在PEFT(Prompt-Encoder Fine-Tuning)中,通常我们将层归一化(LayerNorm)用浮点数32位(float32)进行训练以
        # 保证稳定性。因此,输入的隐藏状态会被默默地转换为浮点数32位。所以,我们需要将它们转换回浮点数16位(float16),
        # 只是为了确保一切按预期工作。
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reashape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2) # (b,q_len,h,dk)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=use_sliding_windows,
        )
        # (b,s,d)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Decide whether to use SWA or not by layer index.
        # 超过配置的使用滑动窗口的最大层数的话,设置use_sliding_windows = False
        if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
            use_sliding_windows = False
        # 在序列中至少包含一个填充标记
        if attention_mask is not None:
            batch_size = query_states.shape[0] # b
            # query_states:(len(indices_q),h,dk)
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
            # 不用滑动窗口的情况
            if not use_sliding_windows:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )
            # attn_output_unpad:(len(indices_q),h,dk),indices_q:一维张量,非填充token索引
            # 这个可以把去填充的张量恢复成去填充之前的张量
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else: # 没有设置attention_mask的情况
            if not use_sliding_windows: # 不用滑动窗口的情况
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else: # 用滑动窗口的情况
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )
        # 返回注意力输出,形状(b,s,h,dk)
        return attn_output
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        # b,k_len,h,dk
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
        # 在第一次迭代时,我们需要通过在正确的位置进行切片来正确地重新创建填充掩码
        if kv_seq_len != attention_mask.shape[-1]:
            # 获取掩码的最后一维的大小
            attention_mask_num_tokens = attention_mask.shape[-1]
            # 切出kv_len的大小
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
        # 非填充的token索引,每一批次中每个序列长度的累积和,表示当前批次中最长序列的长度。
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        # index_first_axis = IndexFirstAxis.apply
        # 调用 .apply 方法:IndexFirstAxis.apply(input_tensor, indices) 实际上调用的是 forward 方法。
        # 前向传播:forward 方法执行具体的计算,并返回结果。
        # 反向传播:当计算梯度时,PyTorch 自动调用 backward 方法来计算梯度。
        # 返回的key_layer形状是(len(indices_k),h,dk),去了填充token的嵌入
        # len(indices_k):表示去除了填充后的有效序列长度。
        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        if query_length == kv_seq_len:
            # 索引在第一个轴
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            # 这时给q符相应的值
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            # 每一批次中每个序列长度的累积和
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            # 非填充token索引
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1) # (b,h,dk)
        else:
            # The -q_len: 切片操作假设是左填充
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,# (len(indices_q),h,dk),
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
class Qwen2SdpaAttention(Qwen2Attention):
    # Adapted from Qwen2Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # 不支持输出注意力权重的回退警告,和回退到父类的forward调用
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
        bsz, q_len, _ = hidden_states.size() # b,q_len,d
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # (b,s,d)-->(b,s,h,dk)-->(b,h,s,dk)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2] # k_len
        # 如果使用key_value缓存
        if past_key_value is not None:
            # 设置新的kv_seq_len
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # 旋转位置嵌入
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        # 带上位置信息的嵌入
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # 设置缓存的情况
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # 设置k,v和q具有相同的头数
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        causal_mask = attention_mask
        # 如果有传人掩码,切取k_len长度
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            # 设置内存连续状态
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()
        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        # q_len等于1时不需要因果掩码,编码器自注意力也不需要因果掩码
        is_causal = True if causal_mask is None and q_len > 1 else False
        
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        # (b,h,s,dk)-->(b,s,h,dk)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value

# 千问解码器层,继承自nn.Module
class Qwen2DecoderLayer(nn.Module):
    # 构造函数参数:self,当前实例对象,config,Qwen2Config实例对象,layer_idx,层索引
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__() # 调用父类的初始化方法
        self.hidden_size = config.hidden_size # d
        # 如果设置了使用滑动窗口,就必须设置_attn_implementation为"flash_attention_2,不然会警告
        if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )
        # 多头注意力层
        self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        # 前馈全连接层
        self.mlp = Qwen2MLP(config)
        # 改良的标准化层
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    def forward(
        self, # 当前实例对象
        hidden_states: torch.Tensor, # 上次的解码器层输出或者第一次的嵌入层
        # 可选参数:注意力掩码
        attention_mask: Optional[torch.Tensor] = None,
        # 位置ids
        position_ids: Optional[torch.LongTensor] = None,
        # 可选参数:元组(里面元素是Tensor)类型,一般用于推理时,指当前token
        # 之前的所有token表示,这里是作为key,value
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        # 可选参数:是否输出注意力权重,布尔类型
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False, # 是否使用缓存
        # 缓存的位置ids,可选,是LongTensor类型
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs, #其他参数,以上是类型注解,规范性的代码就应该有这个
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states #残差前段,如果第一次,是嵌入,否则是上一次解码器的输出
        hidden_states = self.input_layernorm(hidden_states) # 标准化
        # 目标序列自注意力
        # 自注意力输出,自注意力权重,返回的key_value
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask, # 注意力掩码
            position_ids=position_ids, # 位置ids
            past_key_value=past_key_value, #上个时间步的key_value
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        # 自注意力前后残差连接
        hidden_states = residual + hidden_states
        #重新设定残差连接前段
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states) #标准化
        hidden_states = self.mlp(hidden_states) # 前馈全连接层
        hidden_states = residual + hidden_states # 前馈前后残差
        #解码器输出,放入一个元组中
        outputs = (hidden_states,)
        if output_attentions:  # 如果要输出注意力权重
            outputs += (self_attn_weights,)
        if use_cache: # 如果使用缓存
            outputs += (present_key_value,)
        return outputs # 返回元组

PyTorch是一个流行的深度学习框架,其源码主要由Python语言编写。通过对PyTorch源码进行分析,可以了解其内部实现机制和算法原理,有助于深入理解和掌握PyTorch的特性和性能。 PyTorch源码主要由以下几个部分组成: 1. 核心模块:PyTorch的核心模块包括神经网络模型、优化器、数据加载器等。这些模块负责实现PyTorch的核心功能,如前向传播、反向传播、梯度计算等。 2. 底层实现:PyTorch的底层实现包括数据类型、内存管理、设备管理等。这些实现细节直接影响PyTorch的性能和稳定性。 3. 扩展模块:PyTorch提供了许多扩展模块,如文本处理、图像处理、强化学习等。这些模块扩展了PyTorch的功能,方便用户进行各种类型的数据分析和处理。 在对PyTorch源码进行分析时,可以从以下几个方面入手: 1. 模块结构:了解PyTorch的模块结构,分析各个模块的功能和相互关系。 2. 数据类型和内存管理:分析PyTorch的数据类型和内存管理机制,了解如何高效地处理数据和计算。 3. 算法实现:分析PyTorch中常用的算法实现,如反向传播、优化器等,了解其原理和实现方式。 4. 性能优化:分析PyTorch的性能优化策略,如动态计算图、分布式训练等,了解如何提高PyTorch的性能和稳定性。 需要注意的是,PyTorch源码较为庞大和复杂,需要具备一定的编程经验和深度学习基础知识才能较好地进行分析。同时,分析PyTorch源码需要耐心和细心,需要不断尝试和调试,才能更好地理解其内部实现机制。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值