C++运行Pytorch模型

4 篇文章 0 订阅
1 篇文章 0 订阅

这个事,干了好几次了,不过之前是在Windows上,这次改为linux,竟然一遍成功了。

首先官网把包下载下来
在这里插入图片描述
一开始不知道怎么用,省流,它是cmake的包。
然后解压到cmake环境下。给出官网LOADING A TORCHSCRIPT MODEL IN C++
然后跟着走,竟然成功了。

贴出代码防丢失。
跑完的目录结构

.
├── build
│   ├── CMakeCache.txt
│   ├── CMakeFiles
│   │   ├── 3.22.1
│   │   │   ├── CMakeCCompiler.cmake
│   │   │   ├── CMakeCXXCompiler.cmake
│   │   │   ├── CMakeDetermineCompilerABI_C.bin
│   │   │   ├── CMakeDetermineCompilerABI_CXX.bin
│   │   │   ├── CMakeSystem.cmake
│   │   │   ├── CompilerIdC
│   │   │   │   ├── a.out
│   │   │   │   ├── CMakeCCompilerId.c
│   │   │   │   └── tmp
│   │   │   └── CompilerIdCXX
│   │   │       ├── a.out
│   │   │       ├── CMakeCXXCompilerId.cpp
│   │   │       └── tmp
│   │   ├── cmake.check_cache
│   │   ├── CMakeDirectoryInformation.cmake
│   │   ├── CMakeOutput.log
│   │   ├── CMakeTmp
│   │   ├── hcn.dir
│   │   │   ├── build.make
│   │   │   ├── cmake_clean.cmake
│   │   │   ├── compiler_depend.internal
│   │   │   ├── compiler_depend.make
│   │   │   ├── compiler_depend.ts
│   │   │   ├── DependInfo.cmake
│   │   │   ├── depend.make
│   │   │   ├── flags.make
│   │   │   ├── hcn.cpp.o
│   │   │   ├── hcn.cpp.o.d
│   │   │   ├── link.txt
│   │   │   ├── main.cpp.o
│   │   │   ├── main.cpp.o.d
│   │   │   └── progress.make
│   │   ├── Makefile2
│   │   ├── Makefile.cmake
│   │   ├── progress.marks
│   │   └── TargetDirectories.txt
│   ├── cmake_install.cmake
│   ├── compile_commands.json
│   ├── hcn
│   └── Makefile
├── CMakeLists.txt
├── hcn2.pt
├── hcn.cpp
├── hcn.h
├── hcn.pt
└── main.cpp

CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(hcn)

set(CMAKE_PREFIX_PATH "~/cmake/libtorch")
set(CXX_STANDARD 17)
set(CMAKE_BUILD_TYPE Debug)

# include_directories("./")
# LINK_DIRECTORIES("~/cmake/libtorch")

find_package(Torch REQUIRED)
# message(TORCH_LIBRARIES: "${TORCH_LIBRARIES}")

add_executable(hcn main.cpp hcn.cpp)
target_link_options(hcn PRIVATE -static-libgcc -static-libstdc++)
target_link_libraries(hcn "${TORCH_LIBRARIES}")

hcn.h

#pragma once

#include <iostream>
#include <torch/script.h>

namespace myHcn {

class Hcn
{
public:
    std::vector<torch::jit::IValue> inputs;
    c10::IValue output;

    Hcn(const std::string & modelFile);
    ~Hcn() = default;

    torch::Tensor getHc();
    torch::Tensor run(float tfValue);
    int shotOverCallback();

private:
    torch::jit::script::Module model;
    torch::Tensor hc;                       // 额外传入model的参数
    bool setHc();
    bool isRunning;
};

}   // namespace myHcn

hcn.cpp

#include <iostream>
#include <torch/script.h>
#include "hcn.h"

namespace myHcn {

using namespace std;
using namespace torch;

Hcn::Hcn(const string &modelFile)
{
    model = jit::load(modelFile);
    isRunning = false;
}

bool Hcn::setHc()
{
    if (isRunning) {
        hc = output.toTuple()->elements()[1].toTensor();
    } else {
        hc = torch::zeros({4,1,1,96});
    }
    return true;
}

Tensor Hcn::getHc()
{
    return hc;
}

Tensor Hcn::run(float tfValue)
{
    setHc();
    isRunning = true;
    // Tensor input1 = tensor({tfValue});
    // inputs.clear();
    // inputs.push_back(input1);
    // inputs.push_back(hc);
    
    inputs = {tensor({tfValue}),hc};
    output = model.forward(inputs);

    return output.toTuple()->elements()[0].toTensor();
}

int Hcn::shotOverCallback()
{
    isRunning = false;
    return 0;
}

}   // namespace myHcn

main.cpp

#include <iostream>
#include <chrono>
#include <torch/script.h>
#include "hcn.h"

using namespace std;
using namespace myHcn;

int main(int argc, const char *argv[]) 
{
    string modelFile("../hcn2.pt");
    Hcn pt = Hcn(modelFile);

    // 接收数据,运行部分
    torch::Tensor hcnValue = pt.run(9);
    cout << "==================" << endl;
    cout << hcnValue << endl;

    hcnValue = pt.run(10);
    cout << "==================" << endl;
    cout << hcnValue << endl;

    // 一炮停止标志位
    pt.shotOverCallback();

    float itValue = 0.1;
    auto start = chrono::high_resolution_clock::now();
    for (int i=0; i<10000; ++i) {
        itValue += 0.001;
        hcnValue = pt.run(itValue);
    }
    auto end = chrono::high_resolution_clock::now();
    chrono::duration<double> elapsed = end - start;
    cout << "Elapsed time: " << elapsed.count() << " seconds" << std::endl;

    // 一炮停止标志位
    pt.shotOverCallback();

    return 0;
}

在这里插入图片描述
推理10000次,用时2.5秒,比python快多了,可以可以。
下一步准备转onnx格式试试。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小小鹅卵石

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值