安装libtorch记录
环境
- win10 + vs2017 + pytorch1.6.0 + cmake3.12.1 + opencv3.4.5 + cuda10.2
import torch as t
print(t.version.cuda)
print(t.__version__)
10.2
1.6.0
步骤
-
下载对应版本的libtorch
-
创建工程文件夹,包含build文件夹、example-app.cpp、CMakeLists.txt,使用cmake创建VS工程,需要设置对opencv和libtorch路径
cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
project(example-app)
set(CMAKE_PREFIX_PATH "D:\\Program Files (x86)\\libtorch\\share\\cmake\\Torch")
find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)
if(NOT Torch_FOUND)
message(FATAL_ERROR "Pytorch Not Found!")
endif(NOT Torch_FOUND)
message(STATUS "Pytorch status:")
message(STATUS " libraries: ${TORCH_LIBRARIES}")
message(STATUS "OpenCV library status:")
message(STATUS " version: ${OpenCV_VERSION}")
message(STATUS " libraries: ${OpenCV_LIBS}")
message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}")
add_executable(example-app example-app.cpp)
target_link_libraries(example-app ${TORCH_LIBRARIES} ${OpenCV_LIBS})
set_property(TARGET example-app PROPERTY CXX_STANDARD 11)
- 打开build文件夹下的vs工程,将example-app设为启动项
验证
转换权重
import torch as t
from minist_experiment.models import LeNet
model_save_path = './checkpoints/checkpoints_epoch99.pth'
model = LeNet()
model.load_state_dict(t.load(model_save_path))
model.eval()
example = t.rand(1,1,28,28)
traced_script_module = t.jit.trace(model, example)
output = traced_script_module(t.ones(1, 1, 28, 28))
print(output.data)
traced_script_module.save("./checkpoints/checkpoints_Script.pt")
C++中读取模型
#include <memory>
#include <torch/torch.h>
#include <torch/script.h> // One-stop header.
int main() {
// Deserialize the ScriptModule from a file using torch::jit::load().
torch::jit::script::Module module = torch::jit::load("D:/Code/python/pytorch/minist_experiment/checkpoints/checkpoints_Script.pt");
//assert(module != nullptr);
std::cout << "ok\n";
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({ 1, 1, 28, 28 }));
// Execute the model and turn its output into a tensor.
torch::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/10) << '\n';
while (1);
}
遇到的问题
- vs报错
error C2210: “_Ty”: 包扩展不能被用作别名模板中非打包参数的自变量
- 问题原因:C compiler版本太老
- 解决办法:更新一下VS
- 运行代码报错VCRUNTIME140_1D.dll找不到
- 解决办法:在https://www.dll-files.com/中下载缺少的dll文件,放到C:\Windows\SysWOW64下
- 缺少torch.dll等dll文件
- 解决办法:在libtorch安装路径里的lib(D:\Program Files (x86)\libtorch\lib)里找到对应文件添加到你工程文件夹下的这个目录下D:\Program Files (x86)\libtorchproject\build\x64\Debug
过程中参考了网上很多大哥的博客,这里就不一一列举了