#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
// Minimal libtorch example: load a serialized TorchScript module, run it once
// on a 1x3x224x224 all-ones tensor, and print elements [0, 5) of the output
// along dimension 1.
//
// NOTE: this listing deliberately keeps the old libtorch API
// (std::shared_ptr<Module> returned by torch::jit::load(), `module->` access).
// The notes that follow this listing report that current libtorch rejects the
// shared_ptr initialization with MSVC error C2440 and show the replacement.
int main() {
  // Deserialize the ScriptModule from a file using torch::jit::load().
  // Old-API form: the result is stored in a shared_ptr (see NOTE above).
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("D:/Desktop/pytorch1_0/model.pt");
  assert(module != nullptr);
  std::cout << "ok\n";

  // Create a vector of inputs.
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({ 1, 3, 224, 224 }));

  // Execute the model and turn its output into a tensor.
  // Old-API form: pointer-style access through the shared_ptr (see NOTE above).
  at::Tensor output = module->forward(inputs).toTensor();
  std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';

  // Block until the user presses Enter instead of spinning in an empty
  // `while (1);` loop: the busy loop pinned a CPU core and could only be
  // stopped by killing the process. Presumably the pause exists to keep the
  // console window open so the output can be read.
  std::cin.get();
  return 0;
}
编译时会出现 error C2440: “初始化”: 无法从“torch::jit::script::Module”转换为“std::shared_ptr<torch::jit::script::Module>”的错误。
原因是:用 std::shared_ptr 接收 torch::jit::load() 的返回值是 libtorch 测试版本的写法,现在该函数的返回类型已经变更为 torch::jit::script::Module 对象本身。解决方法:
将
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../xxx.pt");
修改为
torch::jit::script::Module module = torch::jit::load("../xxx.pt");
同时,module 已经不是指针,成员访问要由 -> 改为 .,因此将
at::Tensor output = module->forward(inputs).toTensor();
修改为
torch::Tensor output = module.forward(std::move(inputs)).toTensor();
即可。
参考:https://github.com/pytorch/pytorch/issues/22382