Series Table of Contents
GLM-4 (1) - Inference + Overview
GLM-4 (2) - RoPE
GLM-4 (3) - GLMBlock
GLM-4 (4) - SelfAttention
GLM-4 (5) - API & Function Calling
Preface
The previous post recorded some notes on GLM-4 model inference and the like, but did not touch on the details of the network itself. This post analyzes the rotary positional encoding, RoPE. There are plenty of blog posts covering the theory, so this one will not spend much space on it and will instead start from the code. The original paper is RoFormer (https://arxiv.org/abs/2104.09864); readers who care mainly about the theory can turn to other tutorial posts.
1. RoPE Overview
First we give a brief overview of RoPE. Then, to better understand how the glm series implements positional encoding, we analyze the chatglm-6b source code before moving on to glm-4-9b-chat, since the two share the same lineage but also differ in a few places.
This part is adapted from labmlai. Taking query as an example, here is how RoPE adds positional information to it. Suppose the feature dimension of query is $d$; the features then form $d/2$ pairs. At sequence position $m$, the first two features of query are $x^{(0)}_m$ and $x^{(1)}_m$, and the positional encoding is applied to them as follows ($\theta$ can be treated as a constant here; in practice it depends on the dimension index):
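$$
\begin{pmatrix} \tilde{x}^{(0)}_m \\ \tilde{x}^{(1)}_m \end{pmatrix}
=
\begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}
\begin{pmatrix} x^{(0)}_m \\ x^{(1)}_m \end{pmatrix}
=
\begin{pmatrix} x^{(0)}_m \cos m\theta - x^{(1)}_m \sin m\theta \\ x^{(1)}_m \cos m\theta + x^{(0)}_m \sin m\theta \end{pmatrix}
$$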
The dot-product attention score between positions $m$ and $n$ is then:
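$$
\begin{aligned}
\big\langle \tilde{x}_m, \tilde{x}_n \big\rangle
&= \big(x^{(0)}_m \cos m\theta - x^{(1)}_m \sin m\theta\big)\big(x^{(0)}_n \cos n\theta - x^{(1)}_n \sin n\theta\big) \\
&\quad + \big(x^{(1)}_m \cos m\theta + x^{(0)}_m \sin m\theta\big)\big(x^{(1)}_n \cos n\theta + x^{(0)}_n \sin n\theta\big) \\
&= \big(x^{(0)}_m x^{(0)}_n + x^{(1)}_m x^{(1)}_n\big)\cos (n - m)\theta + \big(x^{(1)}_m x^{(0)}_n - x^{(0)}_m x^{(1)}_n\big)\sin (n - m)\theta
\end{aligned}
$$

which depends on the two positions only through the difference $n - m$.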
The derivation above shows that, for dot-product attention, rotary positional encoding behaves as a relative positional encoding. It was done for the simple two-dimensional case; extending it to $d$ dimensions ($d$ even) and pairing feature dimension $i$ with dimension $i + \frac{d}{2}$ (note that this differs from the paper, which pairs dimension $i$ with $i + 1$), the positional encoding is expressed as:
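$$
\begin{pmatrix} \tilde{x}^{(i)}_m \\ \tilde{x}^{(i+\frac{d}{2})}_m \end{pmatrix}
=
\begin{pmatrix} x^{(i)}_m \cos m\theta_i - x^{(i+\frac{d}{2})}_m \sin m\theta_i \\ x^{(i+\frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m\theta_i \end{pmatrix},
\qquad
\theta_i = 10000^{-\frac{2i}{d}},\; i \in \{0, 1, \ldots, \tfrac{d}{2}-1\}
$$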
The code is as follows:
# https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py
"""
---
title: Rotary Positional Embeddings (RoPE)
summary: >
  Annotated implementation of RoPE from paper
  RoFormer: Enhanced Transformer with Rotary Position Embedding
---

# Rotary Positional Embeddings (RoPE)

This is an implementation of
[Rotary Positional Embeddings (RoPE)](https://arxiv.org/abs/2104.09864)
in [PyTorch](https://pytorch.org).

Rotary Positional Embeddings (RoPE) encode position information of tokens
with a rotation matrix that naturally incorporates explicit relative position
dependency.

Here's [the training code](experiment.html) for training a transformer model with RoPE
on Tiny Shakespeare dataset.
"""
import torch
from torch import nn

# from labml.logger import inspect
# from labml_nn.transformers.mha import MultiHeadAttention
from mha import MultiHeadAttention


class RotaryPositionalEmbeddings(nn.Module):
    """
    ## RoPE module
    """

    def __init__(self, d: int, base: int = 10_000):
        """
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()
        self.base = base
        self.d = d
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        """
        Cache $\cos$ and $\sin$ values
        """
        # Return if cache is already built
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)  # (dim // 2, )

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)  # (seq_len, )

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)  # (seq_len, dim // 2)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)  # (seq_len, dim)

        # Cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]  # (seq_len, 1, 1, dim)
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
        """
        # Cache $\cos$ and $\sin$ values
        self._build_cache(x)

        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., :self.d], x[..., self.d:]

        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        # Calculate
        #
        # \begin{align}
        # \begin{pmatrix}
        # x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i \\
        # x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m \theta_i \\
        # \end{pmatrix} \\
        # \end{align}
        #
        # for $i \in {1, 2, ..., \frac{d}{2}}$
        x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])

        #
        return torch.cat((x_rope, x_pass), dim=-1)


class RotaryPEMultiHeadAttention(MultiHeadAttention):
    """
    ## Multi-head attention with rotary positional embeddings

    We override [multi-head attention from original transformer](../mha.html).
    """

    def __init__(self, heads: int, d_model: int, rope_percentage: float = 0.5, dropout_prob: float = 0.0):
        super().__init__(heads, d_model, dropout_prob)

        # Rotary positional embedding layers
        d_rope = int(self.d_k * rope_percentage)
        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope)
        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope)

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        ### Calculate scores between queries and keys
        """
        # Calculate dot-product with RoPE
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))


def _test_rotary():
    """
    Testing RoPE with a simple example
    """
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float)  # (seq_len, dim)
    x = x[:, None, None, :]  # (seq_len, batch_size, num_head, dim)
    # inspect(x)

    rotary_pe = RotaryPositionalEmbeddings(4)
    rotary_pe(x)
    # inspect(rotary_pe(x))


if __name__ == '__main__':
    _test_rotary()
Let's walk through what each step of forward() does.

- self._build_cache(x): caches $\cos m\theta_i$ and $\sin m\theta_i$, both of shape (seq_len, 1, 1, dim). Concretely, for position $m$ the cosine values are laid out as $[\cos m\theta_0, \cos m\theta_1, \ldots, \cos m\theta_{\frac{d}{2}-1}, \cos m\theta_0, \cos m\theta_1, \ldots, \cos m\theta_{\frac{d}{2}-1}]$, and likewise for the sine values.
- x_rope, x_pass = x[..., :self.d], x[..., self.d:]: here x is the key or query, of shape (seq_len, 1, 1, dim). It is split into two parts: the first receives the positional encoding, the second is left untouched. Normally the second part is empty, i.e. all feature dimensions take part in the positional encoding.
- neg_half_x = self._neg_half(x_rope): rearranges the feature dimensions to make the subsequent computation convenient, producing $[-x^{(\frac{d}{2})}, -x^{(\frac{d}{2} + 1)}, \ldots, -x^{(d-1)}, x^{(0)}, x^{(1)}, \ldots, x^{(\frac{d}{2} - 1)}]$.
- x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]]): element-wise multiplication, with output shape (seq_len, 1, 1, dim); concretely, the computation takes the following form:
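$$
\tilde{x}^{(i)}_m =
\begin{cases}
x^{(i)}_m \cos m\theta_i - x^{(i+\frac{d}{2})}_m \sin m\theta_i, & 0 \le i < \frac{d}{2} \\
x^{(i)}_m \cos m\theta_{i-\frac{d}{2}} + x^{(i-\frac{d}{2})}_m \sin m\theta_{i-\frac{d}{2}}, & \frac{d}{2} \le i < d
\end{cases}
$$

i.e. exactly the rotation of the pairs $(i, i + \frac{d}{2})$ from the overview above.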
2. chatglm-6b
This part draws on two reference blog posts (blog 1 and blog 2). Note that, compared with the RoPE implementation in the original paper, the way chatglm-6b combines the components of q is different, matching the RoPE overview above. I have added fairly extensive comments to the code snippet.
# rope_chatglm_6b.py
# Source: https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py
# Walkthrough reference: https://blog.csdn.net/qq_41496421/article/details/139013470
import torch
import torch.nn.functional as F


class RotaryEmbedding(torch.nn.Module):
    def __init__(self,
                 dim,                    # embedding dimension
                 base=10000,             # base of the exponent
                 precision=torch.half,   # precision
                 learnable=False):       # whether inv_freq is learnable
        super().__init__()
        # Build inv_freq, i.e. theta_i = 1 / (10000 ^ (2i / d)) in the formula
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))  # shape: (dim // 2)
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            # cos and sin values are cached
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def forward(self,
                x,             # input
                seq_dim=1,     # which dimension of x is the sequence dimension
                seq_len=None): # sequence length
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        # Rebuild only when the cached max sequence length is still None or smaller than seq_len;
        # otherwise the cached values are reused
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            # Build m in the formula, i.e. the position index of each token
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            # Build the frequencies m * theta_i, shape: (seq_len) x (dim // 2) -> (seq_len, dim // 2)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # (See the walkthrough reference; differs from the original RoPE) concatenate the frequencies, shape: (seq_len, dim)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # Compute the cos and sin values
            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]  # shape: (seq_len, 1, dim)
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        """
        model.half() ends up calling module._apply(fn)
        """
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)


def rotate_half(x):
    """
    Build the part that is multiplied element-wise with sin(m * theta_i)
    """
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions


@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    # sq: seq_length, b: batch_size, np: number of attention heads, hn: dimension per head
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
        F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k
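To make the shapes concrete, here is a minimal usage sketch; this driver is my own addition and is not part of modeling_chatglm.py, and the tiny sizes are arbitrary:

# Hypothetical driver, not from the original file; shapes follow the comments above.
if __name__ == '__main__':
    sq, b, np_, hn = 8, 1, 2, 16            # seq_length, batch_size, num heads, dim per head
    rope = RotaryEmbedding(dim=hn)          # inv_freq is kept in half precision internally
    q = torch.randn(sq, b, np_, hn)         # [sq, b, np, hn]
    k = torch.randn(sq, b, np_, hn)
    cos, sin = rope(q, seq_dim=0)           # each: (sq, 1, hn)
    position_id = torch.arange(sq).unsqueeze(1).expand(sq, b)  # [sq, b]
    q_rot, k_rot = apply_rotary_pos_emb_index(q, k, cos, sin, position_id)
    print(q_rot.shape, k_rot.shape)         # torch.Size([8, 1, 2, 16]) for both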
In addition, chatglm-6b uses a 2D positional encoding, which shows up in SelfAttention, as shown in the article referenced earlier. I will not go into it again here.
3. glm-4-9b-chat
The implementation here follows the same logic as chatglm-6b, with slight changes to variable names. The differences:

- the base is multiplied by rope_ratio;
- the cos and sin values are stacked together into a single cache;
- the 2D positional encoding is dropped; comparing the two SelfAttention implementations makes this obvious (the attention parts will be covered in a separate post).
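In other words, with rope_ratio folded into the base, the per-dimension frequencies become

$$
\theta_i = (\text{base} \cdot \text{rope\_ratio})^{-\frac{2i}{d}}, \qquad \text{base} = 10000,
$$

and glm-4-9b-chat ships with rope_ratio = 500 (see the config in the test code below), which lengthens the rotation periods, a common recipe for long-context models.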
# rope_glm_4_9b_chat.py
import torch
from torch import nn


class RotaryEmbedding(nn.Module):
    def __init__(self,
                 dim,                  # embedding dimension
                 rope_ratio=1,         # rope ratio (scales the base)
                 original_impl=False,  # whether to use the original implementation
                 device=None,          # device
                 dtype=None):          # data type
        super().__init__()
        # Build theta_i
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl
        self.rope_ratio = rope_ratio

    def forward_impl(
            self,
            seq_len: int,       # sequence length
            n_elem: int,        # embedding dimension
            dtype: torch.dtype,
            device: torch.device,
            base: int = 10000   # base in the formula
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # Build theta_i; same as chatglm-6b except that the base is scaled by rope_ratio
        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        base = base * self.rope_ratio
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

        # Build the position indexes, m in the formula
        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

        # Outer product to build the frequencies m * theta_i, shape: (seq_len) x (dim // 2) -> (seq_len, dim // 2)
        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        # Stack cos and sin together as the cache, shape: (seq_len, dim // 2, 2)
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )


# @torch.jit.script  # comment this line out when debugging
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """
    x: query or key,
        shape: (batch_size, num_head or num_group, seq_len, dim_per_head),
        e.g. (1, 32, 8, 128) for query
    rope_cache: the rotary cos & sin values built above,
        shape: (batch_size, seq_len, rotary_dim // 2, 2); the trailing 2 holds cos and sin,
        a typical shape is (1, 8, 32, 2)
    """
    # x: [b, np, sq, hn]
    b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    # rot_dim = 64; with dim_per_head = 128, x is split into two parts, each of shape (1, 32, 8, 64).
    # Same operation as in the RoPE overview, except rot_dim is half of dim_per_head, so there are
    # dimensions (x_pass) that receive no positional encoding.
    # Per the reference blog: glm originally used a 2D positional encoding; glm4 drops the second
    # dimension, so only the first 64 dims (x) are processed.
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:, :sq]
    xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)  # (1, 32, 8, 32, 2)
    rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
    # rope_cache: (1, 1, 8, 32, 2), i.e. (batch_size, 1, seq_len, rotary_dim // 2, 2)
    # Compute q_0 * cos(m * theta_0) - q_1 * sin(m * theta_0), etc.
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )  # shape: (1, 32, 8, 32, 2)
    x_out2 = x_out2.flatten(3)  # (1, 32, 8, 64)
    return torch.cat((x_out2, x_pass), dim=-1)  # (1, 32, 8, 128)


if __name__ == '__main__':
    # Partial config of glm-4-9b-chat
    config_glm_4_9b_chat = {
        "seq_length": 131072,
        "hidden_size": 4096,
        "num_attention_heads": 32,
        "kv_channels": 128,
        "original_rope": True,
        "rope_ratio": 500,
        "device": "cuda",
        "torch_dtype": torch.bfloat16
    }
    # Config used for this test
    config = {
        "seq_length": 256,
        "hidden_size": 128,
        "num_attention_heads": 4,
        "kv_channels": 32,
        "original_rope": True,
        "rope_ratio": 500,
        "device": "cpu",
        "torch_dtype": torch.float32
    }
    seq_length = config["seq_length"]
    rotary_dim = config["hidden_size"] // config["num_attention_heads"] \
        if config["kv_channels"] is None else config["kv_channels"]  # 32
    net = RotaryEmbedding(dim=rotary_dim // 2,
                          rope_ratio=config["rope_ratio"],
                          original_impl=config["original_rope"],
                          device=config["device"],
                          dtype=config["torch_dtype"])
    net(128)
    print("done")
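As a quick end-to-end check of apply_rotary_pos_emb, the following lines can be appended to the __main__ block above; this is my own sketch, not part of the original file, and it reuses net and the test config defined there:

    # Hypothetical continuation of the __main__ block above (not in the original file)
    rope_cache = net(8)                 # (seq_len, rotary_dim // 2, 2) = (8, 8, 2)
    rope_cache = rope_cache[None, ...]  # add the batch dim: (1, 8, 8, 2)
    # Dummy query: (batch_size, num_head, seq_len, dim_per_head) = (1, 4, 8, 32)
    q = torch.randn(1, config["num_attention_heads"], 8, config["kv_channels"])
    q_rot = apply_rotary_pos_emb(q, rope_cache)
    print(q_rot.shape)                  # torch.Size([1, 4, 8, 32]); first 16 dims rotated, last 16 passed through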
Summary
This post first gave a brief overview of rotary positional encoding (RoPE), then analyzed and compared the chatglm-6b and glm-4-9b-chat implementations side by side, hopefully building a deeper understanding of RoPE.