RUST学习笔记(Day 3)

今天学习用Rust来实现开源 LLM代表LLaMA模型。 本次使用的是karpathy/llama2.c: Inference Llama 2 in one file of pure C 的 Rust 实现的版本中的:danielgrittner/llama2-rs: LLaMA2 + Rust。仅涉及推理部份。

配置

struct Config {
    dim: usize,        // transformer dimension
    hidden_dim: usize, // for ffn layers
    n_layers: usize,   // number of layers
    n_heads: usize,    // number of query heads
    head_size: usize,  // size of each head (dim / n_heads)
    n_kv_heads: usize, // number of key/value heads
    shared_weights: bool,
    vocab_size: usize, // vocabulary size
    seq_len: usize,    // max. sequence length
}

在上述代码中,我们定义了一个名为 Config 的结构体(struct),用于表示某种配置信息。结构体包含了多个字段,每个字段都有对应的字段名和类型注释。

  • dim: usize:transformer 的维度。
  • hidden_dim: usize:用于 ffn 层(feed-forward network,前馈神经网络)的隐藏层维度。
  • n_layers: usize:层数。
  • n_heads: usize:查询头的数量。
  • head_size: usize:每个查询头的大小(dim / n_heads)。
  • n_kv_heads: usize:键/值头的数量。
  • shared_weights: bool:指示是否使用共享权重。
  • vocab_size: usize:词汇表大小。
  • seq_len: usize:最大序列长度。

该结构体定义了一个用于存储具有不同配置信息的对象。通过创建 Config 的实例,并为每个字段提供适当的值,我们可以在代码中使用配置对象来管理和传递相关的设置。

dim 就是上面一直说的 Dim,hidden_dim 仅在 FFN 层,因为 FFN 层需要先扩大再缩小。n_heads 和 n_kv_heads 是 Query 的 Head 数和 KV 的 Head 数,简单起见可以认为它们是相等的。

参数

struct TransformerWeights {
    // Token Embedding Table
    token_embedding_table: Vec<f32>, // (vocab_size, dim)
    // Weights for RMSNorm
    rms_att_weight: Vec<f32>, // (layer, dim)
    rms_ffn_weight: Vec<f32>, // (layer, dim)
    // Weights for matmuls in attn
    wq: Vec<f32>, // (layer, dim, dim)
    wk: Vec<f32>, // (layer, dim, dim)
    wv: Vec<f32>, // (layer, dim, dim)
    wo: Vec<f32>, // (layer, dim, dim)
    // Weights for ffn
    w1: Vec<f32>, // (layer, hidden_dim, dim)
    w2: Vec<f32>, // (layer, dim, hidden_dim)
    w3: Vec<f32>, // (layer, hidden_dim, dim)
    // final RMSNorm
    rms_final_weights: Vec<f32>, // (dim)
    // freq_cis for RoPE relatively positional embeddings
    freq_cis_real: Vec<f32>, // (seq_len, head_size/2)
    freq_cis_imag: Vec<f32>, // (seq_len, head_size/2)
    // (optional) classifier weights for the logits, on the last layer
    wcls: Vec<f32>, // (vocab_size, dim)
}

上述代码定义了一个名为 TransformerWeights 的结构体(struct),用于存储一组转换器(transformer)的权重值。结构体包含了多个字段,每个字段都指定了对应的字段名和类型。

  • token_embedding_table: Vec<f32>:用于存储标记嵌入表的权重值,类型为 Vec<f32>。表的形状为 (vocab_size, dim),其中 vocab_size 代表词汇表大小,dim 代表维度。
  • rms_att_weight: Vec<f32>:用于存储 RMSNorm 中注意力权重的值,类型为 Vec<f32>。形状为 (layer, dim),其中 layer 代表层数,dim 代表维度。
  • rms_ffn_weight: Vec<f32>:用于存储 RMSNorm 中前馈神经网络权重的值,类型为 Vec<f32>。形状为 (layer, dim),其中 layer 代表层数,dim 代表维度。
  • wq: Vec<f32>wk: Vec<f32>wv: Vec<f32>wo: Vec<f32>:用于存储注意力机制中矩阵相乘操作(matmuls)的权重值,类型均为 Vec<f32>。它们的形状为 (layer, dim, dim),其中 layer 代表层数,dim 代表维度。
  • w1: Vec<f32>w2: Vec<f32>w3: Vec<f32>:用于存储前馈神经网络中的权重值,类型均为 Vec<f32>。它们的形状为 (layer, hidden_dim, dim) 或 (layer, dim, hidden_dim),其中 layer 代表层数,dim 代表维度,hidden_dim 代表隐藏层维度。
  • rms_final_weights: Vec<f32>:用于存储最终的 RMSNorm 权重值,类型为 Vec<f32>。形状为 (dim),其中 dim 代表维度。
  • freq_cis_real: Vec<f32> 和 freq_cis_imag: Vec<f32>:用于相对位置编码的频率系数(RoPE)的实部和虚部的权重值,类型均为 Vec<f32>。它们的形状为 (seq_len, head_size/2),其中 seq_len 代表序列长度,head_size 代表每个注意力头的大小。
  • wcls: Vec<f32>:(可选)用于最后一层的 logits 分类器的权重值,类型为 Vec<f32>。形状为 (vocab_size, dim),其中 vocab_size 代表词汇表大小,dim 代表维度。

该结构体定义了一个用于存储转换器权重值的对象。通过创建 TransformerWeights 的实例,并为每个字段提供适当的值,我们可以在代码中使用该对象来管理和传递相关的权重。

其中:freq_ 开头的两个参数,它们是和位置编码有关的参数,也就是说,我们每次生成一个 Token 时,都需要传入当前位置的位置信息。位置编码在 Transformer 中是比较重要的,因为 Self Attention 本质上是无序的,而语言的先后顺序在有些时候是很重要的

加载参数

fn byte_chunk_to_vec<T>(byte_chunk: &[u8], number_elements: usize) -> Vec<T>
where
    T: Clone,
{
    unsafe {
        // 获取起始位置的原始指针
        let data = byte_chunk.as_ptr() as *const T;
        // 从原始指针创建一个 T 类型的切片,注意number_elements是element的数量,而不是bytes
        // 这句是 unsafe 的
        let slice_data: &[T] = std::slice::from_raw_parts(data, number_elements);
        // 将切片转为 Vec,需要 T 可以 Clone
        slice_data.to_vec()
    }
}
  1. fn byte_chunk_to_vec<T>(byte_chunk: &[u8], number_elements: usize) -> Vec<T>:定义了一个名为 byte_chunk_to_vec 的函数,以泛型 T 作为类型参数。它接受一个 byte_chunk 字节切片和一个 number_elements 元素数量作为参数,并返回一个 Vec<T> 向量。
  2. where T: Clone:使用 where 关键字,指定 T 必须实现 Clone trait,以便能够克隆值。这是为了在转换切片到向量时,能够复制每个元素。
  3. unsafe:标记块中的代码为不安全的代码。
  4. let data = byte_chunk.as_ptr() as *const T;:将 byte_chunk 的指针转换为 T 类型的常量指针,并将其赋值给 data 变量。这允许我们在指针级别对字节进行操作。
  5. let slice_data: &[T] = std::slice::from_raw_parts(data, number_elements);:使用 from_raw_parts 方法,根据起始指针 data 和元素数量 number_elements 创建一个 T 类型的切片 slice_data。请注意,这实际上并不执行数据复制。
  6. slice_data.to_vec():将切片 slice_data 转换为 Vec<T> 向量。此方法将从切片中复制每个元素,并创建一个新的向量。
  7. 最后,函数返回转换后的向量。

这段代码的目的是将一个字节切片转换为元素类型为 T 的向量。但由于涉及到指针操作和不安全的代码,因此要特别小心使用。where T: Clone 约束确保元素类型 T 可以被克隆来进行复制操作。byte_chunk 表示原始的字节切片,number_elements 表示结果向量中元素的个数。

unsafe的用法

  1. 获取原始指针:代码中使用 byte_chunk.as_ptr() 方法获取 byte_chunk 字节切片的原始指针,然后将其转换为 T 类型的常量指针。这个操作涉及到底层的指针操作,访问和操作内存的原始指针需要使用 unsafe 代码块。

  2. 从原始指针创建切片:使用 std::slice::from_raw_parts() 方法根据原始指针和元素数量创建了一个 T 类型的切片。这个方法也需要使用 unsafe 代码块,因为它接触到了指针和内存操作。

需要注意的是,使用 unsafe 关键字会打开 Rust 中的一些安全性限制。在使用 unsafe 代码块时,需要确保代码正确地处理了指针和内存操作,以避免造成内存安全和未定义行为。

在这段代码中使用 unsafe 主要是为了利用底层的指针操作和内存操作,来直接操作原始数据。这样可以避免不必要的数据复制,并提高性能。但同时,需要非常小心,确保代码在使用指针和访问内存时不会引发潜在的错误。

as *const T的用法

在 let data = byte_chunk.as_ptr() as *const T 这段代码中的 * 运算符是指针类型转换中的解引用运算符,访问指针所指向的数据。

在这段代码中,byte_chunk.as_ptr() 返回了 byte_chunk 字节切片的原始指针(raw pointer),然后使用 as *const T 将其转换为 *const T 类型的常量指针,以便与类型 T 相匹配。

需要注意的是,在这个上下文中的 * 运算符并不是乘法运算符,它是指针类型转换语法的一部分。解引用运算符不能在 Safe Rust 中直接使用,因为它是一个不安全操作,需要使用 unsafe 关键字包裹。例如,let value = *ptr; 将会将 value 变量设置为指针 ptr 所指向位置的值。

加载模型

读取原始的 bin 文件并指定对应的参数大小

et token_embedding_table_size = config.vocab_size * config.dim;
// offset.. 表示从 offset 往后的所有元素
let token_embedding_table: Vec<f32> = byte_chunk_to_vec(&mmap[offset..], token_embedding_table_size);
  1. 这行代码计算了 token_embedding_table 的大小,即词嵌入表的大小。它将 config.vocab_size(词汇表大小)乘以 config.dim(维度),并将结果赋值给变量 token_embedding_table_size
  2. 这行代码创建了一个名为 token_embedding_table 的 Vec<f32> 向量。它调用了 byte_chunk_to_vec 函数,并传递了从 mmap[offset..] 中的指定偏移位置开始的字节切片,以及 token_embedding_table_size 作为参数。

因为向量数据通常取决于从某个字节切片转换而来,所以需要借助 byte_chunk_to_vec 函数来执行转换。

简单的内容就到这里了,后面上硬菜,未完待续~~

  • 24
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值