winform调用pytorch上训练好的unet模型:
项目是写一个辅助诊断系统软件,用winform写软件,调用pytorch和matlab的模型。这篇博客只包含调用pytorch模型的部分。
1.c++(libtorch)调用模型
2.c++生成动态链接库
3.c#调用dll
1. libtorch(cpu)调用gpu模型
首先把pytorch的模型转成libtorch的。如果模型没有控制流(if-else语句),就用简单的trace方式进行转换(trace需要提供一个示例输入,torch会运行一次模型并记录实际执行的算子)。
pth文件转为pt文件
# Convert a GPU-trained PyTorch UNet checkpoint (.pth) into a TorchScript
# module (.pt) that libtorch can load on a CPU-only machine.
import torch
import torchvision
import numpy as np
import cv2
import os
from unet_parts import UNet

device = torch.device('cpu')
model = UNet(1, 1)
# The model was trained on a GPU server; map_location='cpu' remaps the
# saved weights so they can be loaded without CUDA.
model.load_state_dict(torch.load("best_model.pth", map_location='cpu'))
model.eval()
# BUG FIX: torch.jit.trace requires an example input — tracing executes
# the model once and records the ops that actually run. The shape matches
# what the C++ caller feeds it: batch=1, channels=1, 256x256 grayscale.
example_input = torch.rand(1, 1, 256, 256)
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("best_model.pt")
VS2019部署libtorch+opencv
下载libtorch、opencv
配置项目属性
右键项目,配置属性。VC++目录添加包含目录、库目录:
包含目录
D:\Tools\libtorch\include\torch\csrc\api\include
D:\Tools\libtorch\include
D:\Tools\opencv\build\include\opencv2
D:\Tools\opencv\build\include
库目录
D:\Tools\libtorch\lib
D:\Tools\opencv\build\x64\vc15\lib
添加链接器->输入->附加依赖项,把libtorch、opencv库目录下的所有.lib文件全都复制过来。
asmjit.lib
c10.lib
caffe2_detectron_ops.lib
caffe2_module_test_dynamic.lib
Caffe2_perfkernels_avx.lib
Caffe2_perfkernels_avx2.lib
Caffe2_perfkernels_avx512.lib
clog.lib
cpuinfo.lib
dnnl.lib
fbgemm.lib
fbjni.lib
kineto.lib
XNNPACK.lib
torch_cpu.lib
torch.lib
pytorch_jni.lib
libprotoc.lib
opencv_world455.lib
项目是release版本就调用release版本的dll:opencv_xxxx.lib,debug版本opencv_xxxd.lib,注意区别。
最后跑一下模型验证效果
2.c++生成动态链接库
创建项目:Windows桌面向导
选择应用类型:DLL
选择预编译头文件。其余配置同上
#include "pch.h"
#include<iostream>
#include <torch/torch.h>
#include <torch/script.h>
#include <opencv2/opencv.hpp>
#include <opencv2/imgproc/types_c.h>
using namespace cv;
using namespace std;
// Exported entry point: runs UNet segmentation on the image at `file_name`
// and writes the thresholded binary mask to "test.png" (256x256, 8-bit).
// NOTE(review): the model path is hard-coded and the TorchScript module is
// reloaded on every call — consider loading once and caching for speed.
void __stdcall LoadModel(char* file_name) // interface: path of the input image
{
    torch::jit::script::Module module;
    try {
        module = torch::jit::load("E:/Project/octa/pytorch/best_model.pt");
    } catch (const c10::Error&) {
        return; // model file missing or not a valid TorchScript module
    }
    module.eval();
    torch::Device device(torch::kCPU);
    module.to(device);
    torch::NoGradGuard no_grad; // inference only — skip autograd bookkeeping

    Mat image = imread(file_name);
    if (image.empty())
        return; // unreadable path or unsupported image format
    cvtColor(image, image, CV_BGR2GRAY);
    resize(image, image, Size(256, 256));

    // Mat (H x W x C, uint8) -> Tensor, then reorder to C x H x W.
    torch::Tensor tensor_image = torch::from_blob(
        image.data, { image.rows, image.cols, image.channels() }, torch::kByte);
    tensor_image = tensor_image.permute({ 2, 0, 1 });
    tensor_image = tensor_image.toType(torch::kFloat);
    //tensor_image = tensor_image.div(255); // normalization (disabled in original)
    tensor_image = tensor_image.unsqueeze(0); // -> 1 x 1 x 256 x 256
    // BUG FIX: Tensor::to() is NOT in-place — the original discarded the
    // result, so the tensor was never actually moved to `device`.
    tensor_image = tensor_image.to(device);

    torch::Tensor output = module.forward({ tensor_image }).toTensor();
    torch::Tensor output_max = output.squeeze();
    // Threshold at 0.5, then scale the boolean mask to 0/255 uint8.
    output_max = (output_max >= 0.5);
    output_max = output_max.mul(255).clamp(0, 255).to(torch::kU8);
    // contiguous() guarantees data_ptr() points at a dense row-major buffer
    // before the raw memcpy below.
    output_max = output_max.to(torch::kCPU).contiguous();

    Mat result_img(Size(256, 256), CV_8UC1); // 8-bit, single channel
    // BUG FIX: sizeof(torch::kU8) is the size of the ScalarType enum value,
    // not of a tensor element; each kU8 element is exactly one byte.
    memcpy(result_img.data, output_max.data_ptr(),
           output_max.numel() * sizeof(uint8_t));

    vector<int> compression_params;
    compression_params.push_back(IMWRITE_PNG_COMPRESSION);
    compression_params.push_back(0); // uncompressed PNG
    compression_params.push_back(IMWRITE_PNG_STRATEGY);
    compression_params.push_back(IMWRITE_PNG_STRATEGY_DEFAULT);
    imwrite("test.png", result_img, compression_params);
}
头文件
#pragma once
// Export/import switch. When building the DLL itself, define
// UNET_DLL_EXPORTS in the project settings (C/C++ -> Preprocessor) so the
// symbol is exported; consumers, who leave it undefined, get dllimport.
// BUG FIX: the original tested and then redefined the SAME macro name
// (UNET_DLL), which triggers a macro-redefinition warning and makes the
// export/import choice depend on redefining the macro it just tested.
// Use a separate build-time define instead.
#ifdef UNET_DLL_EXPORTS
#define UNET_DLL __declspec(dllexport)
#else
#define UNET_DLL __declspec(dllimport)
#endif
// extern "C" disables C++ name mangling so the entry point can be found
// by its plain name via GetProcAddress / [DllImport].
extern "C" UNET_DLL void __stdcall LoadModel(char* file_name);
注意!!不用提供的模板pch.h,pch.cpp时,进行如下操作:
右键项目 --> 属性 --> C/C++ --> 预编译头 -->预编译头 改为创建
3. winform调用dll
把libtorch中所有dll、刚刚生成的dll文件放到项目文件夹的bin目录下。创建Windows窗体应用程序。
using System.Runtime.InteropServices;

namespace WindowsFormsApp1
{
    public partial class Form1 : Form
    {
        // Static P/Invoke binding to the native segmentation DLL.
        // CharSet.Ansi marshals the C# string to the char* the C side expects;
        // CallingConvention.Winapi resolves to StdCall on Windows, matching
        // the __stdcall on the native LoadModel.
        [DllImport(@"E:\vs\winform_test\WindowsFormsApp1\bin\octa.dll", EntryPoint = "LoadModel", CharSet = CharSet.Ansi, CallingConvention = CallingConvention.Winapi, ExactSpelling = false)]
        public static extern void LoadModel(string s);

        // Runs segmentation on the file the user picked in openFileDialog1.
        private void analyzeButton_Click(object sender, EventArgs e)
        {
            // BUG FIX: removed the leftover debug call
            // LoadModel("E:/Project/octa/dataset/test/0.png") that
            // unconditionally ran the model on a hard-coded image before the
            // user's selection was even checked (running inference twice).
            string ss = openFileDialog1.FileName;
            if (!string.IsNullOrEmpty(ss))
            {
                LoadModel(ss);
                MessageBox.Show("Region segmentation succeeded!");
            }
        }
    }
}