#include <cuda_runtime_api.h>
std::string filename = "centernet.pt"//模型路径
int gpu_id = 1; //gpu id 0代表第一块可见gpu
cudaSetDevice(gpu_id); //切换显卡
torch::jit::script::Module module = torch::jit::load(filename,torch::Device(torch::DeviceType::CUDA,gpu_id));//加载模型
libtorch 加载torchscript模型有三个重载函数
TORCH_API script::Module load(
std::istream& in,
c10::optional<c10::Device> device = c10::nullopt,
script::ExtraFilesMap& extra_files = default_extra_files);
TORCH_API script::Module load(
const std::string& filename,
c10::optional<c10::Device> device = c10::nullopt,
script::ExtraFilesMap& extra_files = default_extra_files);
TORCH_API script::Module load(
std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
c10::optional<c10::Device> device = c10::nullopt,
script::ExtraFilesMap& extra_files = default_extra_files);
目前我是从文件加载模型,用第二个函数,选择设备这里主要关注第二个参数
c10::optional<c10::Device> device
这里我们需要构造一个device类传入,我们看Device类定义
Device(DeviceType type, DeviceIndex index = -1)
这里很显然第一个是设备类型,第二个是设备索引
第一个是枚举类:我们选择torch::DeviceType::CUDA 也就是nvidia显卡计算平台
第一个就是显卡id 我们填0代表第一块显卡,1代表第二块显卡