2021-01-06

PyTorch模型转化到可部署模型

import torch
import torchvision.models as models
from PIL import Image
import numpy as np
image = Image.open("build/airliner.jpg") #图片发在了build文件夹下
image = image.resize((224, 224),Image.ANTIALIAS)
image = np.asarray(image)
image = image / 255
image = torch.Tensor(image).unsqueeze_(dim=0)
image = image.permute((0, 3, 1, 2)).float()

model = models.resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
# output=resnet(torch.ones(1,3,224,224))
output = resnet(image)
max_index = torch.max(output, 1)[1].item()
print(max_index) # ImageNet1000类的类别序
resnet.save('resnet.pt')

C++调用训练好的模型

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)

project(example_torch)
set(CMAKE_PREFIX_PATH "XXX/libtorch") //注意这里填自己解压libtorch时的路径

find_package(Torch REQUIRED)
find_package(OpenCV 3.0 QUIET)
if(NOT OpenCV_FOUND)
    find_package(OpenCV 2.4.3 QUIET)
    if(NOT OpenCV_FOUND)
        message(FATAL_ERROR "OpenCV > 2.4.3 not found.")
    endif()
endif()
add_executable(${PROJECT_NAME} "main.cpp")
target_link_libraries(${PROJECT_NAME} ${TORCH_LIBRARIES} ${OpenCV_LIBS})
set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 11)```
#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>
#include <vector>
#include <opencv2/highgui.hpp>
#include <opencv2/core/core.hpp>
#include <opencv2/opencv.hpp>

void TorchTest(){
    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../resnet.pt");
    assert(module != nullptr);
    std::cout << "Load model successful!" << std::endl;
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::zeros({1,3,224,224}));
    at::Tensor output = module->forward(inputs).toTensor();
    auto max_result = output.max(1, true);
    auto max_index = std::get<1>(max_result).item<float>();
    std::cout << max_index << std::endl;
}

void Classfier(cv::Mat &image){
    torch::Tensor img_tensor = torch::from_blob(image.data, {1, image.rows, image.cols, 3}, torch::kByte);
    img_tensor = img_tensor.permute({0, 3, 1, 2});
    img_tensor = img_tensor.toType(torch::kFloat);
    img_tensor = img_tensor.div(255);
    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../Train/resnet.pt");
    torch::Tensor output = module->forward({img_tensor}).toTensor();
    auto max_result = output.max(1, true);
    auto max_index = std::get<1>(max_result).item<float>();
    std::cout << max_index << std::endl;

}

int main() {
//    TorchTest();
    cv::Mat image = cv::imread("airliner.jpg");
    cv::resize(image,image, cv::Size(224,224));
    std::cout << image.rows <<" " << image.cols <<" " << image.channels() << std::endl;
    Classfier(image);
    return 0;
}

参考文章

  • https://zhuanlan.zhihu.com/p/72750321
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值