今天在学习使用RUST来实现机器学习中的核心Tensor。
什么是Tensor
Tensor 为我们提供了一种通用的方式来描述 n 阶数组的扩展。0阶的Tensor可以理解为标量,1阶的Tensor是向量,2 阶的Tensor就是矩阵。
Tensor类型有多个,比如PyTorch框架中的Tensor,本次学习的Tensor来自HuggingFace开发的Candle框架。
Candle框架
Candle 的核心目标是让 Serverless 推理成为可能。像 PyTorch 这样的完整机器学习框架非常大,这使得在集群上创建实例的速度很慢。Candle 允许部署轻量级二进制文件。另外,Candle 可以让用户从生产工作负载中删除 Python。Python 开销会严重影响性能,而GIL是众所周知的令人头疼的问题。
当前 Candle 已经支持如今的前沿模型,像 Llama2。经过改写的模型,比如 Llama2 能够方便、快速的运行在容器环境,甚至可以运行在浏览器中。Candle 结构包括:
- Candle-core:核心操作、设备和 Tensor 结构定义。
- Candle-nn:构建真实模型的工具。
- Candle-examples:在实际设置中使用库的示例。
- Candle-kernels:CUDA 自定义内核;
- Candle-datasets:数据集和数据加载器。
- Candle-Transformers:与 Transformers 相关的实用程序。
- Candle-flash-attn:Flash attention v2 层。
操作Tensor
Candle 在 candle_core::Tensor 中定义了 Tensor 类型,可按照以下方式使用它:
use candle_core::{Tensor, DType, Device};
let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?
let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?
let c = a.matmul(&b)?
这个命令生成了两个向量(1 阶的 Tensor),然后将它们转换为矩阵,并对它们进行矩阵乘法。其中:
先调用了 Tensor
类型的arange()
方法。该方法通常用于生成一个具有一系列数值的张量。
arange()
方法接受三个参数:范围的起始值(1.0f32),范围的结束值(7.0f32)以及张量所在的设备(Device::Cpu
)
然后arange()
方法返回一个 Result
类型的值。好奇的是命令末尾都有一个?,这是什么用法呢?
关于?的使用
上网查了下关于?在RUST里的使用方法。
有时我们只是想
unwrap
且避免产生panic
。到现在为止,对unwrap
的错误处理都在强迫我们一层层地嵌套,然而我们只是想把里面的变量拿出来。?
正是为这种情况准备的。
当找到一个 Err
时,可以采取两种行动:
panic!
,不过我们已经决定要尽可能避免 panic 了。- 返回它,因为
Err
就意味着它已经不能被处理了。
?
就等于一个会返回 Err
而不是 panic
的 unwrap
。举例来说明:
use std::num::ParseIntError;
fn multiply(first_number_str: &str, second_number_str: &str) -> Result<i32, ParseIntError> {
let first_number = first_number_str.parse::<i32>()?;
let second_number = second_number_str.parse::<i32>()?;
Ok(first_number * second_number)
}
fn print(result: Result<i32, ParseIntError>) {
match result {
Ok(n) => println!("n is {}", n),
Err(e) => println!("Error: {}", e),
}
}
fn main() {
print(multiply("10", "2"));
print(multiply("t", "2"));
}
在 ?
出现以前,相同的功能是使用 try!
宏完成的。现在官方推荐使用 ?
运算符,但是在老代码中仍然会看到 try!
。
所以上面看到的?是一种错误的处理方式。在创建张量过程中如果发生任何错误,它将从包含这行代码的函数或代码块中返回。