tch-rs指南 - Tensor的基本操作

1 概述

在使用rust进行torch模型部署时,不可避免地会用到tch-rs。但是tch-rs文档太过简洁,和没有一样,网上的资料也少得可怜,很多操作需要我们自己去试。这些内容虽然简单,但是自己找起来很费时间。

这篇文章总结了如何使用tch-rs进行tensor的基本操作。讲述的内容参考了pytorch的tensor教程

运行环境:

[dependencies]
tch = "0.7.0"
opencv = "0.63"

2 Tensor的基本操作

用到的库

use std::iter;

use opencv::prelude::*;
use opencv::core::{Mat, Scalar};
use opencv::core::{CV_8UC3};
use tch::IndexOp;
use tch::{Device, Tensor};

2.1 Tensor的初始化

(1)通过数组创建
let t = Tensor::of_slice::<i32>(&[1, 2, 3, 4, 5]);
t.print();
// vector也是一样的
let v = vec![1,2,3];
let t = Tensor::of_slice::<i32>(&v);
t.print();
// 2d vector
let v = vec![[1.5,2.0,3.9,4.4], [3.1,4.3,5.1,6.9]];
let v:Vec<f32> = v
    .iter()
    .flat_map(|array| array.iter())
    .cloned()
    .collect();
let data = unsafe{
    std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * std::mem::size_of::<f32>())
};
let t = Tensor::of_data_size(data, &[2,4], tch::Kind::Float);
t.print();

print的结果是

 1
 2
 3
 4
 5
[ CPUIntType{5} ]
 1
 2
 3
[ CPUIntType{3} ]
 1.5000  2.0000  3.9000  4.4000
 3.1000  4.3000  5.1000  6.9000
[ CPUFloatType{2,4} ]
(2)通过默认方法创建
let t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t = Tensor::ones(&[2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t = Tensor::zeros(&[2, 3], (tch::Kind::Float, Device::Cpu));
t.print();
let t = Tensor::arange_start(0, 2 * 3, (tch::Kind::Float, Device::Cpu)).view([2, 3]);
t.print();

print的结果是

 1.0522  0.6981  0.9236
 0.2324 -1.1048 -2.5820
[ CPUFloatType{2,3} ]
 1  1  1
 1  1  1
[ CPUFloatType{2,3} ]
 0  0  0
 0  0  0
[ CPUFloatType{2,3} ]
 0  1  2
 3  4  5
[ CPUFloatType{2,3} ]
(3)通过其他的tensor创建
let t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
let t = t.rand_like();
t.print();

print的结果是

 0.3376  0.1885  0.3415
 0.5135  0.8321  0.4140
[ CPUFloatType{2,3} ]
(4)通过opencv::core::Mat创建

这可以用在opencv读取图像后,转为torch tensor。当然tch-rs本身也有各种读取图片的方式,可见tch::vision::image。这里介绍两种方法,一种通过tch::Tensor::f_of_blob,一种通过tch::Tensor::of_data_size

// 创建一个(row, col, channel)=(2, 3, 3)=(height, width, channel)的Mat
let mat = Mat::new_rows_cols_with_default(
    2, 3, CV_8UC3, Scalar::from((3.0, 2.0, 1.0))
).unwrap();
// 获取mat的size,这里的结果是[2, 3, 3]
let size: Vec<_> = mat.mat_size().iter().cloned().map(|dim| dim as i64).chain(iter::once(mat.channels() as i64)).collect();
// 获取每个dimension的stride,这里的结果是[9, 3, 1]
let strides = {
    let mut strides: Vec<_> = size
        .iter()
        .rev()
        .cloned()
        .scan(1, |prev, dim| {
            let stride = *prev;
            *prev *= dim;
            Some(stride)
        })
        .collect();
    strides.reverse();
    strides
};
// 构建tensor
let t = unsafe {
    let ptr = mat.ptr(0).unwrap() as *const u8;
    tch::Tensor::f_of_blob(ptr, &size, &strides, tch::Kind::Uint8, tch::Device::Cpu).unwrap()
};
t.print();

print的结果是

(1,.,.) = 
  3  2  1
  3  2  1
  3  2  1

(2,.,.) = 
  3  2  1
  3  2  1
  3  2  1
[ CPUByteType{2,3,3} ]

还有一种比较简洁的转换方法

let mut mat = Mat::new_rows_cols_with_default(
    2, 3, CV_8UC3, Scalar::from((3.0, 2.0, 1.0))
).unwrap();
let h = mat.size().unwrap().height;
let w = mat.size().unwrap().width;   
let data = mat.data_bytes_mut().unwrap(); 
let t = tch::Tensor::of_data_size(data, &[h as i64, w as i64, 3], tch::Kind::Uint8);
t.print();

print的结果也是

(1,.,.) = 
  3  2  1
  3  2  1
  3  2  1

(2,.,.) = 
  3  2  1
  3  2  1
  3  2  1
[ CPUByteType{2,3,3} ]
test tensor_ops::init_ops ... ok

2.2 Tensor的属性

用tch::Tensor的print()方法可打印出数据的所有属性,但是想要获取到这些属性,需要用其他的方法。

let t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
println!("size of the tensor: {:?}", t.size());
println!("kind of the tensor: {:?}", t.kind());
println!("device on which the tensor is located: {:?}", t.device());

打印的结果是

size of the tensor: [2, 3]
kind of the tensor: Float
device on which the tensor is located: Cpu

2.3 Tensor的运算

(1)改变device

.to().to_device()这两个方法都可以。

let mut t = Tensor::randn(&[2, 3], (tch::Kind::Float, Device::Cpu));
if tch::Cuda::is_available(){
    t = t.to(Device::Cuda(0));
    println!("change device to {:?}", t.device());
}
t = t.to_device(Device::Cpu);
println!("change device to {:?}", t.device());

如果是有cuda,且安装了cuda版本的tch-rs的话,就会打印出

change device to Cuda(0)
change device to Cpu
(2)获取值(indexing and slicing)

这个在tch-rs的例子中有很多,详见tests/tensor_indexing.rs。这里列几种常用的。

通过.i()进行索引

let tensor = Tensor::arange_start(0, 2 * 3, (tch::Kind::Float, Device::Cpu)).view([2, 3]);
println!("original tensor:");
tensor.print();
println!("tensor.i(0):");
tensor.i(0).print();
println!("tensor.i((1, 1)):");
tensor.i((1, 1)).print();
println!("tensor.i((.., 2)):");
tensor.i((.., 2)).print();
println!("tensor.i((.., -1)):");
tensor.i((.., -1)).print();
println!("tensor.i((.., [2, 0])):");
let index: &[_] = &[2, 0];
tensor.i((.., index)).print();

打印的结果是

original tensor:
 0  1  2
 3  4  5
[ CPUFloatType{2,3} ]
tensor.i(0):
 0
 1
 2
[ CPUFloatType{3} ]
tensor.i((1, 1)):
4
[ CPUFloatType{} ]
tensor.i((.., 2)):
 2
 5
[ CPUFloatType{2} ]
tensor.i((.., -1)):
 2
 5
[ CPUFloatType{2} ]
tensor.i((.., [2, 0])):
 2  0
 5  3
[ CPUFloatType{2,2} ]

通过.index()进行索引

let tensor = Tensor::arange(6, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
println!("original tensor:");
tensor.print();
let rows_select = Tensor::of_slice(&[0i64, 1, 0]);
let column_select = Tensor::of_slice(&[1i64, 2, 2]);
let selected = tensor.index(&[Some(rows_select), Some(column_select)]);
println!("selecte by row and column:");
selected.print();

打印的结果是

original tensor:
 0  1  2
 3  4  5
[ CPULongType{2,3} ]
selecte by row and column:
 1
 5
 2
[ CPULongType{3} ]
(3)合并tensors

Tensor::f_cat不会生成新的axis,而Tensor::stack会生成新的axis。

let t1 = Tensor::arange(6, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
let t2 = Tensor::arange_start(6, 12, (tch::Kind::Int64, Device::Cpu)).view((2, 3));
let tensor = Tensor::f_cat(&[t1.copy(), t2.copy()], 1).unwrap();
println!("using Tensor::f_cat");
tensor.print();
let tensor = Tensor::stack(&[t1.copy(), t2.copy()], 1);
println!("using Tensor::stack");
tensor.print();

打印的结果是

using Tensor::f_cat
  0   1   2   6   7   8
  3   4   5   9  10  11
[ CPULongType{2,6} ]
using Tensor::stack
(1,.,.) = 
  0  1  2
  6  7  8

(2,.,.) = 
   3   4   5
   9  10  11
[ CPULongType{2,2,3} ]
(4)四则运算

tch-rs对[+, -, *, /]都进行了重载,可以实现和标量的直接运算。涉及到dim的复杂运算可以用tensor来处理。下面以加法为例,其他与f_add对应的分别是f_subf_mulf_div

let tensor = Tensor::ones(&[2, 4, 3], (tch::Kind::Float, Device::Cpu));
tensor.print();
// add with scalar
let add_tensor = &tensor + 0.5;
add_tensor.print();
// add with tensor
let add_tensor = Tensor::of_slice::<f32>(&[1.0,2.0,3.0]).view((1,1,3));
let add_tensor = &tensor.f_add(&add_tensor).unwrap();
add_tensor.print();

打印的结果为

original tensor:
(1,.,.) = 
  1  1  1
  1  1  1
  1  1  1
  1  1  1

(2,.,.) = 
  1  1  1
  1  1  1
  1  1  1
  1  1  1
[ CPUFloatType{2,4,3} ]
add with scalar:
(1,.,.) = 
  1.5000  1.5000  1.5000
  1.5000  1.5000  1.5000
  1.5000  1.5000  1.5000
  1.5000  1.5000  1.5000

(2,.,.) = 
  1.5000  1.5000  1.5000
  1.5000  1.5000  1.5000
  1.5000  1.5000  1.5000
  1.5000  1.5000  1.5000
[ CPUFloatType{2,4,3} ]
add with tensor:
(1,.,.) = 
  2  3  4
  2  3  4
  2  3  4
  2  3  4

(2,.,.) = 
  2  3  4
  2  3  4
  2  3  4
  2  3  4
[ CPUFloatType{2,4,3} ]

参考资料

[1] https://github.com/LaurentMazare/tch-rs
[2] https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html#

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

七元权

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值