GLM-4 (2) - RoPE

Series Index

GLM-4 (1) - Inference + Overview
GLM-4 (2) - RoPE
GLM-4 (3) - GLMBlock
GLM-4 (4) - SelfAttention
GLM-4 (5) - API & Function Calling



Preface

The previous article took some notes on GLM-4 inference and related topics, but did not touch on the details of the network itself. This article analyzes rotary position embedding (RoPE). There are already plenty of blogs covering the theory, so this article will not spend much space on it and will mainly start from the code. The original paper is here; readers who care more about the theory can head over to other tutorial posts.



1. RoPE Overview

First, we give a brief overview of RoPE. Then, to better understand the positional-encoding implementation in the GLM family, we analyze the chatglm-6b source code before moving on to glm-4-9b-chat, since the two share the same lineage but differ in a few places.

This part is adapted from labmlai. Taking the query as an example, here is how RoPE adds positional information to it. Suppose the query has feature dimension $d$; the features then form $d/2$ pairs. At sequence position $m$, the first two query features are $x^{(0)}_m$ and $x^{(1)}_m$, and the positional encoding rotates them as follows ($\theta$ can be treated as a constant here; in practice it depends on the dimension index):

$$\begin{pmatrix} x^{(0)}_m \\ x^{(1)}_m \end{pmatrix} \mapsto \begin{pmatrix} x^{(0)}_m \cos m\theta - x^{(1)}_m \sin m\theta \\ x^{(1)}_m \cos m\theta + x^{(0)}_m \sin m\theta \end{pmatrix}$$

The dot-product attention score between positions $m$ and $n$ is then

$$\langle q_m, k_n \rangle = \left(q^{(0)}_m k^{(0)}_n + q^{(1)}_m k^{(1)}_n\right)\cos\big((n-m)\theta\big) + \left(q^{(1)}_m k^{(0)}_n - q^{(0)}_m k^{(1)}_n\right)\sin\big((n-m)\theta\big),$$

which depends on the two positions only through $n - m$. So for dot-product attention, rotary positional encoding behaves as a relative positional encoding. The derivation above covers the simple two-dimensional case; extending it to higher (even) dimensions and pairing feature dimension $i$ with $i + \frac{d}{2}$ (note that this differs from the paper, which pairs dimension $i$ with $i + 1$), the positional encoding becomes:

$$\begin{pmatrix} x^{(i)}_m \\ x^{(i + \frac{d}{2})}_m \end{pmatrix} \mapsto \begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i + \frac{d}{2})}_m \sin m\theta_i \\ x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix}, \quad i \in \{0, 1, ..., \tfrac{d}{2} - 1\}$$
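To make the relative-position property concrete, here is a quick numerical check of mine (not from the original post): rotate a 2-D query and key by their position angles and verify that the dot product depends only on the offset between the positions.

# Numerical check of the relative-position property in the 2-D case (my own sketch).
import torch

def rotate(v, angle):
    """Rotate a 2-D vector by `angle` radians."""
    c, s = torch.cos(angle), torch.sin(angle)
    return torch.stack([v[0] * c - v[1] * s, v[0] * s + v[1] * c])

theta = torch.tensor(0.3)        # an arbitrary constant theta
q = torch.tensor([1.0, 2.0])
k = torch.tensor([0.5, -1.0])

score_1 = rotate(q, 3 * theta) @ rotate(k, 1 * theta)   # positions m = 3, n = 1
score_2 = rotate(q, 7 * theta) @ rotate(k, 5 * theta)   # positions m = 7, n = 5 (same offset)
print(torch.allclose(score_1, score_2))                 # True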
The labml reference implementation is as follows:

# https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py


"""
---
title: Rotary Positional Embeddings (RoPE)
summary: >
  Annotated implementation of RoPE from paper
  RoFormer: Enhanced Transformer with Rotary Position Embedding
---

# Rotary Positional Embeddings (RoPE)

This is an implementation of
[Rotary Positional Embeddings (RoPE)](https://arxiv.org/abs/2104.09864)
in [PyTorch](https://pytorch.org).

Rotary Positional Embeddings (RoPE) encode position information of tokens
with a rotation matrix that naturally incorporates explicit relative position
dependency.

Here's [the training code](experiment.html) for training a transformer model with RoPE
 on Tiny Shakespeare dataset.
"""

import torch
from torch import nn

# from labml.logger import inspect
# from labml_nn.transformers.mha import MultiHeadAttention
from mha import MultiHeadAttention


class RotaryPositionalEmbeddings(nn.Module):
    """
    ## RoPE module
    """
    def __init__(self, d: int, base: int = 10_000):
        """
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        """
        Cache $\cos$ and $\sin$ values
        """
        # Return if cache is already built
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)  # (dim // 2, )

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)  # (seq_len, )

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)   # (seq_len, dim // 2)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)   # (seq_len, dim)

        # Cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]   # (seq_len, 1, 1, dim)
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
        """
        # Cache $\cos$ and $\sin$ values
        self._build_cache(x)

        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        # Calculate
        #
        # \begin{align}
        # \begin{pmatrix}
        # x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
        # x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m \theta_i \\
        # \end{pmatrix} \\
        # \end{align}
        #
        # for $i \in {1, 2, ..., \frac{d}{2}}$
        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

        #
        return torch.cat((x_rope, x_pass), dim=-1)


class RotaryPEMultiHeadAttention(MultiHeadAttention):
    """
    ## Multi-head attention with rotary positional embeddings

    We override [multi-head attention from original transformer](../mha.html).
    """

    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
        super().__init__(heads, d_model, dropout_prob)

        # Rotary positional embedding layers
        d_rope = int(self.d_k * rope_percentage)
        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        ### Calculate scores between queries and keys
        """

        # Calculate dot-product with RoPE
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))


def _test_rotary():
    """
    Testing RoPE with a simple example
    """
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)  # (seq_len, dim)
    x = x[:, None, None, :]    # (seq_len, batch_size, num_head, dim)
    # inspect(x)

    rotary_pe = RotaryPositionalEmbeddings(4)
    rotary_pe(x)
    # inspect(rotary_pe(x))


if __name__ == '__main__':
    _test_rotary()

Let's walk through what each step in forward() does.

  • self._build_cache(x): caches $\cos m\theta_i$ and $\sin m\theta_i$, both with shape (seq_len, 1, 1, dim). Concretely, for position $m$ the cosine values are laid out as $[\cos m\theta_0, \cos m\theta_1, ..., \cos m\theta_{\frac{d}{2}-1}, \cos m\theta_0, \cos m\theta_1, ..., \cos m\theta_{\frac{d}{2}-1}]$, and likewise for the sine values;
  • x_rope, x_pass = x[..., :self.d], x[..., self.d:]: here x is a key or query with shape (seq_len, 1, 1, dim), split into two parts; the first part receives the positional encoding and the second is left untouched. Normally the second part is empty, i.e. all feature dimensions take part in the positional encoding;
  • neg_half_x = self._neg_half(x_rope): rearranges the feature dimensions so that the positional encoding can be applied element-wise: $[-x^{(\frac{d}{2})}, -x^{(\frac{d}{2}+1)}, ..., -x^{(d-1)}, x^{(0)}, x^{(1)}, ..., x^{(\frac{d}{2}-1)}]$;
  • x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]]): element-wise multiplication, output shape (seq_len, 1, 1, dim); concretely, for position $m$ and $i \in \{0, ..., \frac{d}{2}-1\}$, the output is $x^{(i)}_m \cos m\theta_i - x^{(i+\frac{d}{2})}_m \sin m\theta_i$ at index $i$ and $x^{(i+\frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i$ at index $i + \frac{d}{2}$ (a small numeric check follows this list).
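To double-check the walkthrough, here is a small sketch of mine (not part of the labml code) that feeds a toy tensor through RotaryPositionalEmbeddings and recomputes one entry by hand; it assumes the class defined above is importable (the mha import is only needed by the attention wrapper).

# Consistency check for the forward() walkthrough above (my own sketch).
import torch

d, seq_len = 4, 3
rope = RotaryPositionalEmbeddings(d)                       # the class defined above
x = torch.arange(seq_len * d, dtype=torch.float).reshape(seq_len, 1, 1, d)
out = rope(x)

# Recompute position m = 2, pair index i = 0 by hand.
m, i = 2, 0
theta_i = torch.tensor(1. / (10_000 ** (2 * i / d)))       # theta_0 = 1.0
first = x[m, 0, 0, i] * torch.cos(m * theta_i) - x[m, 0, 0, i + d // 2] * torch.sin(m * theta_i)
second = x[m, 0, 0, i + d // 2] * torch.cos(m * theta_i) + x[m, 0, 0, i] * torch.sin(m * theta_i)
print(torch.allclose(out[m, 0, 0, i], first))              # True
print(torch.allclose(out[m, 0, 0, i + d // 2], second))    # True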

2. chatglm-6b

This part references blog 1 and blog 2. Note that, compared with the RoPE implementation in the original paper, chatglm-6b pairs the feature dimensions differently, consistent with the RoPE overview above. I have added fairly extensive comments to the code snippets.
Original RoPE implementation
RoPE implementation in chatglm-6b
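To make that difference concrete, here is a small sketch of mine (not from either codebase): the paper's interleaved pairing and chatglm-6b's half-split pairing differ only by a fixed permutation of the feature dimensions.

# A toy comparison (my own sketch) of the two pairing conventions for d = 4.
import torch

def rope_paper(x, m, theta):
    """Paper-style pairing: (x0, x1) rotated by theta_0, (x2, x3) by theta_1."""
    out = torch.empty_like(x)
    for i in range(x.shape[-1] // 2):
        c, s = torch.cos(m * theta[i]), torch.sin(m * theta[i])
        out[2 * i] = x[2 * i] * c - x[2 * i + 1] * s
        out[2 * i + 1] = x[2 * i + 1] * c + x[2 * i] * s
    return out

def rope_half(x, m, theta):
    """chatglm-6b-style pairing: (x0, x2) rotated by theta_0, (x1, x3) by theta_1."""
    d2 = x.shape[-1] // 2
    c, s = torch.cos(m * theta), torch.sin(m * theta)
    return torch.cat([x[:d2] * c - x[d2:] * s, x[d2:] * c + x[:d2] * s])

d, m = 4, 5
theta = 1. / (10000 ** (torch.arange(0, d, 2).float() / d))   # theta_0, theta_1
x = torch.randn(d)
perm = torch.tensor([0, 2, 1, 3])                             # interleaved layout -> half-split layout
print(torch.allclose(rope_half(x[perm], m, theta), rope_paper(x, m, theta)[perm]))   # True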

# rope_chatglm_6b.py
# Source: https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py
# Commentary reference: https://blog.csdn.net/qq_41496421/article/details/139013470

import torch
import torch.nn.functional as F

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, 
                 dim,                    # embedding dimension
                 base=10000,             # base of the exponent
                 precision=torch.half,   # numeric precision
                 learnable=False):       # whether inv_freq is learnable
        super().__init__()
        # Build inv_freq, i.e. theta_i = 1 / (10000 ^ (2i / d)) from the formula
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))   # shape: (dim // 2)
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            # the cos and sin values are cached
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def forward(self, 
                x,              # input tensor
                seq_dim=1,      # which dimension of x is the sequence dimension
                seq_len=None):  # sequence length
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        
        # Only (re)build when the cached max sequence length is None or smaller than seq_len; otherwise reuse the cached values
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            
            self.max_seq_len_cached = None if self.learnable else seq_len
            # Build m from the formula, i.e. the position index of each token
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)

            # Build the frequencies m * theta_i, shape: (seq_len) * (dim // 2) -> (seq_len, dim // 2)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            
            # (see the commentary reference; differs from the original RoPE implementation) concatenate the frequencies, shape: (seq_len, dim)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # Compute the cos and sin values
            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]    # shape: (seq_len, 1, dim)
            sin_cached = emb.sin()[:, None, :]   
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        """
        使用model.half()时,会调用module._apply(fn)
        """
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)


def rotate_half(x):
    """
    构造与sin(m * theta_i) 逐位相乘的部分
    """
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions


@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    # sq: seq_length, b: batch_size, np: number of attention heads, hn: dimension per head
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
        F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k
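Here is a minimal usage sketch of mine (not from the chatglm-6b repo): it builds the cos/sin cache by hand in float32, mirroring RotaryEmbedding.forward, and feeds it to apply_rotary_pos_emb_index defined above, just to confirm the shapes.

# Minimal usage sketch (mine): hand-built float32 cos/sin cache + apply_rotary_pos_emb_index.
import torch

sq, b, np_, hn = 6, 1, 2, 8                                        # seq_len, batch, heads, dim per head
inv_freq = 1. / (10000 ** (torch.arange(0, hn, 2).float() / hn))   # (hn // 2,)
t = torch.arange(sq).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)                       # (sq, hn // 2)
emb = torch.cat((freqs, freqs), dim=-1)                            # (sq, hn)
cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]            # (sq, 1, hn)

q = torch.randn(sq, b, np_, hn)                                    # [sq, b, np, hn]
k = torch.randn(sq, b, np_, hn)
position_id = torch.arange(sq).unsqueeze(1)                        # [sq, b]
q_rot, k_rot = apply_rotary_pos_emb_index(q, k, cos, sin, position_id)
print(q_rot.shape, k_rot.shape)                                    # both torch.Size([6, 1, 2, 8])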

In addition, chatglm-6b uses a 2-D positional encoding, which shows up in its SelfAttention, as described in this article; I will not repeat it here.

3. glm-4-9b-chat

The implementation here follows the same logic as chatglm-6b, with slightly different variable names. The differences are:

  • the base is multiplied by rope_ratio (a short sketch of its effect follows this list);
  • the cos and sin values are stacked together into a single cache;
  • the 2-D positional encoding is no longer used; comparing the two SelfAttention implementations makes this obvious (the attention-related parts will get their own article).
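On the first point, rope_ratio simply rescales the base: $\theta_i = (\text{base} \cdot \text{rope\_ratio})^{-2i/d}$, so the higher-index $\theta_i$ get smaller and their rotation periods get longer. A quick sketch of mine:

# rope_ratio rescales the base; larger base -> smaller theta_i for i > 0 -> longer periods.
import torch

dim = 8
for rope_ratio in (1, 500):
    base = 10000 * rope_ratio
    theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    print(rope_ratio, theta)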
# rope_glm_4_9b_chat.py
import torch
from torch import nn


class RotaryEmbedding(nn.Module):
    def __init__(self, 
                 dim,                   # embedding dimension
                 rope_ratio=1,          # rope ratio (scales the base)
                 original_impl=False,   # whether to use the original implementation
                 device=None,           # device
                 dtype=None):           # data type
        super().__init__()
        # Build theta_i
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl
        self.rope_ratio = rope_ratio

    def forward_impl(
            self, 
            seq_len: int,               # sequence length
            n_elem: int,                # embedding dimension
            dtype: torch.dtype, 
            device: torch.device, 
            base: int = 10000           # base of the exponent in the formula
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # Build theta_i; same as chatglm-6b except that the base is scaled by rope_ratio
        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        base = base * self.rope_ratio
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

        # Build the position indexes, i.e. m in the formula
        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

        # Outer product to build the frequencies m * theta_i, shape: (seq_len) * (dim // 2) -> (seq_len, dim // 2)
        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        # Stack cos and sin as the cache, shape: (seq_len, dim // 2, 2)
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )

# @torch.jit.script    # for debugging, keep this line commented out
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """
    x: query或者key
    shape: (batch_size, num_head or num_group, seq_len, dim_per_head),
    以query为例就是(1, 32, 8, 128)
    
    rope_cache: 此前获取的旋转位置编码 cos & sin
    shape: (batch_size, seq_len, rotary_dim // 2, 2),最后一个2是因为有cos和sin两个数据,
    典型的shape: (1, 8, 32, 2)
    """
    # x: [b, np, sq, hn]
    b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)

    # Get rot_dim = 64: given dim_per_head = 128, split x into two parts, each with shape (1, 32, 8, 64)
    # This is the same split as in the RoPE overview, except that rot_dim is only half of dim_per_head,
    # so there are dimensions (x_pass) that receive no positional encoding.
    # Per the reference blog: chatglm originally used a 2-D positional encoding; glm-4 drops it,
    # so only the first 64 dimensions (i.e. x) are rotated.
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:, :sq]
    
    xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)   # (1, 32, 8, 32, 2)
    rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)   
    # rope_cache: (1, 1, 8, 32, 2), i.e. (batch_size, 1, seq_len, rotary_dim // 2, 2)
    
    # Compute terms like q_0 * cos(m * theta_0) - q_1 * sin(m * theta_0)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )    # shape: (1, 32, 8, 32, 2)
    x_out2 = x_out2.flatten(3)      # (1, 32, 8, 64)
    return torch.cat((x_out2, x_pass), dim=-1)   # (1, 32, 8, 128)


if __name__ == '__main__':
    # part of the glm_4_9b_chat config
    config_glm_4_9b_chat = {
        "seq_length": 131072,
        "hidden_size": 4096,
        "num_attention_heads": 32,
        "kv_channels": 128,
        "original_rope": True,
        "rope_ratio": 500,
        "device": "cuda",
        "torch_dtype": torch.bfloat16
    }
    # config used for this test
    config = {
        "seq_length": 256,
        "hidden_size": 128,
        "num_attention_heads": 4,
        "kv_channels": 32,
        "original_rope": True,
        "rope_ratio": 500,
        "device": "cpu",
        "torch_dtype": torch.float32
    }
    seq_length = config["seq_length"]
    rotary_dim = config["hidden_size"] // config["num_attention_heads"] \
        if config["kv_channels"] is None else config["kv_channels"]   # 32
    net = RotaryEmbedding(dim=rotary_dim // 2,
                          rope_ratio=config["rope_ratio"],
                          original_impl=config["original_rope"],
                          device=config["device"],
                          dtype=config["torch_dtype"])
    net(128)
    print("done")


Summary

This article first gave a brief overview of rotary position embedding (RoPE) and then analyzed the chatglm-6b and glm-4-9b-chat implementations side by side, hopefully building a deeper understanding of RoPE.
