![2ea2ff58fbabefb73e2e2c37dfdbc2b7.png](https://img-blog.csdnimg.cn/img_convert/2ea2ff58fbabefb73e2e2c37dfdbc2b7.png)
我在Jupyter lab里面训练出的PyTorch模型,它的输入是包含两个float数值的tensor,输出是包含一个float数值的 1x1的tensor;
那么,我需要解决的具体问题如下:
(1)将C++中两个float类型的数值转变成一个 1x2 的 tensor;
(2)调用PyTorch模型,输入该1x2 的 tensor 进行推理,输出一个 1x1 的 tensor (代表含水率);
(3)将输出的 1x1 tensor 转变成一个 C++ float/double 类型的数值,方便Qt程序调用;
一、接下来贴上具体解决的代码!
1.将Qt C++中的 float 类型的 模型的输入数据 变成一个 tensor;
从数组中获取Tensor,可以理解为将数组转化为Tensor,有两种方法:
方法一 :用 torch : : from_blob( )
参考:https://blog.csdn.net/baidu_34595620/article/details/101532239
例子如下:
float
方法二 :用 torch : : tensor( )
参考:https://blog.csdn.net/qq_14975217/article/details/90512374
C++中
**
2. 调用PyTorch模型,输入该1x2 的 tensor 进行推理,输出一个 1x1 的 tensor (代表含水率);
参考:https://pytorch.apachecn.org/docs/1.0/cpp_export.html?h=LibTorch
调用PyTorch模型的步骤:
// 1.实例化一个模型对象
3. 将输出的 1x1 tensor 转变成一个 C++ float/double 类型的数值
参考: https://blog.csdn.net/baidu_34595620/article/details/101532239
方法:用 accessor 方法提取数据,函数原型:accessor<dtype,dim>
访问 Tensor 中的数据,并将数据返回为 dtype 类型
CSDN【chencision】给出的例子
torch
PyTorch官网的例子,链接:https://pytorch.org/cppdocs/notes/tensor_basics.html
![831c9c0ee829b0f7a55a34f9b0d9e862.png](https://img-blog.csdnimg.cn/img_convert/831c9c0ee829b0f7a55a34f9b0d9e862.png)
accessor<dtype,dim> 中,参数 dim 表示 tensor 有两个维度
二、我自己Qt中解决上述3个小问题的代码
double
感谢在学习调用 PyTorch 模型的过程中结识的大佬朋友 OLDPAN 、街道口扛把子、chencision 等等;
整理 码字不易,希望小伙伴们引用时注明参考出处,谢谢。
参考文章
(1)https://blog.csdn.net/baidu_34595620/article/details/101532239
(2)PyTorch C++ API - PyTorch master documentation (官网API)
(3)https://blog.csdn.net/luoyexuge/article/details/81871866 (可能是tensorflow的,但是可以学习 vector 的用法)
(4)https://blog.csdn.net/weixin_38664232/article/details/94191029
(5)使用 PyTorch C++ 前端 (看了多遍的)
(6)https://blog.csdn.net/qq_14975217/article/details/90512374
(7)https://blog.csdn.net/weixin_45415546/article/details/99639632 (torch::from_blob函数)
(8)https://blog.csdn.net/Suan2014/article/details/94559294 (似曾相识的报错)
(9)https://blog.csdn.net/qq_14975217/article/details/90512374 (类型转换,很好的文章,靠他解决了问题!)
(10)leaf:C++ load pytorch 时的数据转换 (上面CSDN博文同步的知乎页)