Large language models and text-to-image generation have been very popular lately. I've seen Rust show real value in high-performance inference, and also some potential in quantization, so I'm planning to explore it as a side project.
Installation
Create a new project with cargo, then add the burn crate:
cargo new test_burn
cargo add burn --features wgpu,tch,cuda,dataset,train,candle
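After `cargo add`, the dependency entry in Cargo.toml should look roughly like the following (the version number here is illustrative; use whatever `cargo add` resolved for you):

```toml
[dependencies]
burn = { version = "0.13", features = ["wgpu", "tch", "cuda", "dataset", "train", "candle"] }
```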
burn-tch wraps the C++ version of the torch library (libtorch); you need to download the C++ archive manually from the PyTorch website.
Backends
burn supports several backends for the actual computation, such as wgpu, LibTorch, and Candle. If you want cross-platform portability, consider wgpu; if you want raw performance, LibTorch is the first choice; Candle is still fairly new, so it is hard to judge.
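What makes this work is that burn code is written against a backend trait rather than a concrete device. The sketch below is not burn's actual API; it is a minimal stand-alone illustration of the pattern, with a hypothetical `Backend` trait and a single CPU reference implementation, showing how the same generic code can target any backend:

```rust
// Illustrative sketch (NOT burn's real API): a backend trait lets tensor
// code be written once and dispatched to different compute backends.
trait Backend {
    fn name() -> &'static str;
    fn add(a: &[f32], b: &[f32]) -> Vec<f32>;
}

// A trivial CPU reference backend; a real framework would also provide
// GPU-accelerated implementations behind the same trait.
struct CpuRef;

impl Backend for CpuRef {
    fn name() -> &'static str {
        "cpu-reference"
    }
    fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
        a.iter().zip(b).map(|(x, y)| x + y).collect()
    }
}

// Generic code: written once, runs on any backend implementing the trait.
fn add_one<B: Backend>(data: &[f32]) -> Vec<f32> {
    let ones = vec![1.0; data.len()];
    B::add(data, &ones)
}

fn main() {
    let out = add_one::<CpuRef>(&[2.0, 3.0, 4.0, 5.0]);
    println!("{:?} via {}", out, CpuRef::name());
}
```

Swapping the backend is then just a matter of changing the type parameter, which is exactly how the burn examples below switch between Wgpu, Candle, and LibTorch.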
Testing
Here we test the wgpu, LibTorch, and Candle backends.
use burn::backend::{libtorch, LibTorch};
use burn::backend::{wgpu, Wgpu};
use burn::backend::{candle, Candle};
use burn::tensor::Tensor;

fn main() {
    // wgpu backend via Vulkan: runs on any GPU with a Vulkan driver.
    {
        type MyBackend = Wgpu<wgpu::Vulkan, f32, i32>;
        let device = wgpu::WgpuDevice::DiscreteGpu(0);
        println!("{:#?}", device);
        let tensor_1 = Tensor::<MyBackend, 2>::from_data([[2., 3.], [4., 5.]], &device);
        let tensor_2 = tensor_1.ones_like();
        println!("{}", tensor_1 + tensor_2);
    }
    // Candle backend on CUDA device 0.
    {
        type MyBackend = Candle;
        let device = candle::CandleDevice::Cuda(0);
        println!("{:#?}", device);
        let tensor_1 = Tensor::<MyBackend, 2>::from_data([[2., 3.], [4., 5.]], &device);
        let tensor_2 = tensor_1.ones_like();
        println!("{}", tensor_1 + tensor_2);
    }
    // LibTorch backend on CUDA device 0 (requires the libtorch C++ library).
    {
        type MyBackend = LibTorch;
        let device = libtorch::LibTorchDevice::Cuda(0);
        println!("{:#?}", device);
        let tensor_1 = Tensor::<MyBackend, 2>::from_data([[2., 3.], [4., 5.]], &device);
        let tensor_2 = tensor_1.ones_like();
        println!("{}", tensor_1 + tensor_2);
    }
}
At runtime you need to set the torch-related environment variables; see the tch-rs documentation.
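As documented by tch-rs, that means pointing `LIBTORCH` at the extracted C++ package and making its shared libraries visible to the dynamic linker. A minimal sketch, assuming libtorch was unpacked to `/opt/libtorch` (adjust the path to your own install location):

```shell
# Hypothetical install location -- adjust to wherever you unpacked libtorch.
export LIBTORCH=/opt/libtorch
# Let the dynamic linker find the libtorch shared libraries at runtime.
export LD_LIBRARY_PATH=$LIBTORCH/lib:$LD_LIBRARY_PATH
cargo run
```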