今天学习用Rust来实现开源 LLM代表LLaMA模型。 本次使用的是karpathy/llama2.c: Inference Llama 2 in one file of pure C 的 Rust 实现的版本中的:danielgrittner/llama2-rs: LLaMA2 + Rust。仅涉及推理部份。
配置
struct Config {
dim: usize, // transformer dimension
hidden_dim: usize, // for ffn layers
n_layers: usize, // number of layers
n_heads: usize, // number of query heads
head_size: usize, // size of each head (dim / n_heads)
n_kv_heads: usize, // number of key/value heads
shared_weights: bool,
vocab_size: usize, // vocabulary size
seq_len: usize, // max. sequence length
}
在上述代码中,我们定义了一个名为 Config
的结构体(struct),用于表示某种配置信息。结构体包含了多个字段,每个字段都有对应的字段名和类型注释。
dim: usize
:transformer 的维度。hidden_dim: usize
:用于 ffn 层(feed-forward network,前馈神经网络)的隐藏层维度。n_layers: usize
:层数。n_heads: usize
:查询头的数量。head_size: usize
:每个查询头的大小(dim / n_heads
)。n_kv_heads: usize
:键/值头的数量。shared_weights: bool
:指示是否使用共享权重。vocab_size: usize
:词汇表大小。seq_len: usize
:最大序列长度。
该结构体定义了一个用于存储具有不同配置信息的对象。通过创建 Config
的实例,并为每个字段提供适当的值,我们可以在代码中使用配置对象来管理和传递相关的设置。
dim
就是上面一直说的 Dim,hidden_dim
仅在 FFN 层,因为 FFN 层需要先扩大再缩小。n_heads
和 n_kv_heads
是 Query 的 Head 数和 KV 的 Head 数,简单起见可以认为它们是相等的。
参数
struct TransformerWeights {
// Token Embedding Table
token_embedding_table: Vec<f32>, // (vocab_size, dim)
// Weights for RMSNorm
rms_att_weight: Vec<f32>, // (layer, dim)
rms_ffn_weight: Vec<f32>, // (layer, dim)
// Weights for matmuls in attn
wq: Vec<f32>, // (layer, dim, dim)
wk: Vec<f32>, // (layer, dim, dim)
wv: Vec<f32>, // (layer, dim, dim)
wo: Vec<f32>, // (layer, dim, dim)
// Weights for ffn
w1: Vec<f32>, // (layer, hidden_dim, dim)
w2: Vec<f32>, // (layer, dim, hidden_dim)
w3: Vec<f32>, // (layer, hidden_dim, dim)
// final RMSNorm
rms_final_weights: Vec<f32>, // (dim)
// freq_cis for RoPE relatively positional embeddings
freq_cis_real: Vec<f32>, // (seq_len, head_size/2)
freq_cis_imag: Vec<f32>, // (seq_len, head_size/2)
// (optional) classifier weights for the logits, on the last layer
wcls: Vec<f32>, // (vocab_size, dim)
}
上述代码定义了一个名为 TransformerWeights
的结构体(struct),用于存储一组转换器(transformer)的权重值。结构体包含了多个字段,每个字段都指定了对应的字段名和类型。
token_embedding_table: Vec<f32>
:用于存储标记嵌入表的权重值,类型为Vec<f32>
。表的形状为(vocab_size, dim)
,其中vocab_size
代表词汇表大小,dim
代表维度。rms_att_weight: Vec<f32>
:用于存储 RMSNorm 中注意力权重的值,类型为Vec<f32>
。形状为(layer, dim)
,其中layer
代表层数,dim
代表维度。rms_ffn_weight: Vec<f32>
:用于存储 RMSNorm 中前馈神经网络权重的值,类型为Vec<f32>
。形状为(layer, dim)
,其中layer
代表层数,dim
代表维度。wq: Vec<f32>
、wk: Vec<f32>
、wv: Vec<f32>
、wo: Vec<f32>
:用于存储注意力机制中矩阵相乘操作(matmuls)的权重值,类型均为Vec<f32>
。它们的形状为(layer, dim, dim)
,其中layer
代表层数,dim
代表维度。w1: Vec<f32>
、w2: Vec<f32>
、w3: Vec<f32>
:用于存储前馈神经网络中的权重值,类型均为Vec<f32>
。它们的形状为(layer, hidden_dim, dim)
或(layer, dim, hidden_dim)
,其中layer
代表层数,dim
代表维度,hidden_dim
代表隐藏层维度。rms_final_weights: Vec<f32>
:用于存储最终的 RMSNorm 权重值,类型为Vec<f32>
。形状为(dim)
,其中dim
代表维度。freq_cis_real: Vec<f32>
和freq_cis_imag: Vec<f32>
:用于相对位置编码的频率系数(RoPE)的实部和虚部的权重值,类型均为Vec<f32>
。它们的形状为(seq_len, head_size/2)
,其中seq_len
代表序列长度,head_size
代表每个注意力头的大小。wcls: Vec<f32>
:(可选)用于最后一层的 logits 分类器的权重值,类型为Vec<f32>
。形状为(vocab_size, dim)
,其中vocab_size
代表词汇表大小,dim
代表维度。
该结构体定义了一个用于存储转换器权重值的对象。通过创建 TransformerWeights
的实例,并为每个字段提供适当的值,我们可以在代码中使用该对象来管理和传递相关的权重。
其中:freq_
开头的两个参数,它们是和位置编码有关的参数,也就是说,我们每次生成一个 Token 时,都需要传入当前位置的位置信息。位置编码在 Transformer 中是比较重要的,因为 Self Attention 本质上是无序的,而语言的先后顺序在有些时候是很重要的
加载参数
fn byte_chunk_to_vec<T>(byte_chunk: &[u8], number_elements: usize) -> Vec<T>
where
T: Clone,
{
unsafe {
// 获取起始位置的原始指针
let data = byte_chunk.as_ptr() as *const T;
// 从原始指针创建一个 T 类型的切片,注意number_elements是element的数量,而不是bytes
// 这句是 unsafe 的
let slice_data: &[T] = std::slice::from_raw_parts(data, number_elements);
// 将切片转为 Vec,需要 T 可以 Clone
slice_data.to_vec()
}
}
fn byte_chunk_to_vec<T>(byte_chunk: &[u8], number_elements: usize) -> Vec<T>
:定义了一个名为byte_chunk_to_vec
的函数,以泛型T
作为类型参数。它接受一个byte_chunk
字节切片和一个number_elements
元素数量作为参数,并返回一个Vec<T>
向量。where T: Clone
:使用where
关键字,指定T
必须实现Clone
trait,以便能够克隆值。这是为了在转换切片到向量时,能够复制每个元素。unsafe
:标记块中的代码为不安全的代码。let data = byte_chunk.as_ptr() as *const T;
:将byte_chunk
的指针转换为T
类型的常量指针,并将其赋值给data
变量。这允许我们在指针级别对字节进行操作。let slice_data: &[T] = std::slice::from_raw_parts(data, number_elements);
:使用from_raw_parts
方法,根据起始指针data
和元素数量number_elements
创建一个T
类型的切片slice_data
。请注意,这实际上并不执行数据复制。slice_data.to_vec()
:将切片slice_data
转换为Vec<T>
向量。此方法将从切片中复制每个元素,并创建一个新的向量。- 最后,函数返回转换后的向量。
这段代码的目的是将一个字节切片转换为元素类型为 T
的向量。但由于涉及到指针操作和不安全的代码,因此要特别小心使用。where T: Clone
约束确保元素类型 T
可以被克隆来进行复制操作。byte_chunk
表示原始的字节切片,number_elements
表示结果向量中元素的个数。
unsafe的用法
-
获取原始指针:代码中使用
byte_chunk.as_ptr()
方法获取byte_chunk
字节切片的原始指针,然后将其转换为T
类型的常量指针。这个操作涉及到底层的指针操作,访问和操作内存的原始指针需要使用unsafe
代码块。 -
从原始指针创建切片:使用
std::slice::from_raw_parts()
方法根据原始指针和元素数量创建了一个T
类型的切片。这个方法也需要使用unsafe
代码块,因为它接触到了指针和内存操作。
需要注意的是,使用 unsafe
关键字会打开 Rust 中的一些安全性限制。在使用 unsafe
代码块时,需要确保代码正确地处理了指针和内存操作,以避免造成内存安全和未定义行为。
在这段代码中使用 unsafe
主要是为了利用底层的指针操作和内存操作,来直接操作原始数据。这样可以避免不必要的数据复制,并提高性能。但同时,需要非常小心,确保代码在使用指针和访问内存时不会引发潜在的错误。
as *const T的用法
在 let data = byte_chunk.as_ptr() as *const T 这段代码中的 * 运算符是指针类型转换中的解引用运算符,访问指针所指向的数据。
在这段代码中,byte_chunk.as_ptr() 返回了 byte_chunk 字节切片的原始指针(raw pointer),然后使用 as *const T 将其转换为 *const T 类型的常量指针,以便与类型 T 相匹配。
需要注意的是,在这个上下文中的 * 运算符并不是乘法运算符,它是指针类型转换语法的一部分。解引用运算符不能在 Safe Rust 中直接使用,因为它是一个不安全操作,需要使用 unsafe 关键字包裹。例如,let value = *ptr; 将会将 value 变量设置为指针 ptr 所指向位置的值。
加载模型
读取原始的 bin 文件并指定对应的参数大小
et token_embedding_table_size = config.vocab_size * config.dim;
// offset.. 表示从 offset 往后的所有元素
let token_embedding_table: Vec<f32> = byte_chunk_to_vec(&mmap[offset..], token_embedding_table_size);
- 这行代码计算了
token_embedding_table
的大小,即词嵌入表的大小。它将config.vocab_size
(词汇表大小)乘以config.dim
(维度),并将结果赋值给变量token_embedding_table_size
。 - 这行代码创建了一个名为
token_embedding_table
的Vec<f32>
向量。它调用了byte_chunk_to_vec
函数,并传递了从mmap[offset..]
中的指定偏移位置开始的字节切片,以及token_embedding_table_size
作为参数。
因为向量数据通常取决于从某个字节切片转换而来,所以需要借助 byte_chunk_to_vec
函数来执行转换。
简单的内容就到这里了,后面上硬菜,未完待续~~