c++测试pytorch训练的模型
pytorch训练的pth模型转换成onnx模型,使用c++测试onnx模型。
1. 模型.pth转.onnx
化繁为简
写那么多废话不如简单明了
import torch
from Unet import Unet
def pth2onnx(input, pth_path, onnx_path):
model = Unet() # 导入自己的网络模型
model.load_state_dict(torch.load(pth_path)) # 初始化权重
model.eval()
torch.onnx.export(model, input, onnx_path, verbose=True)
if __name__ == '__main__':
pth_path = r'./best_model.pth' # 训练的pth路径
onnx_path = r'./best_model.onnx' # 保存onnx的路径
model_input = torch.randn(1, 1, 512, 512) # 模型输入[B,C,H,W]
pth2onnx(input=model_input, pth_path=pth_path, onnx_path=onnx_path)
(可选)2. 测试.onnx模型转换是否正确
如果第3步模型测试不正确,才需要用第2步来检查是不是模型转换出了问题。
import cv2
import onnxruntime
import numpy as np
onnx_path = './best_model.onnx' # 上一步生成的onnx模型
image_path = './data/test/1.bmp' # 测试图像
image = cv2.imread(image_path) # 读取图像
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LINEAR) # resize成相应尺寸
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # 灰度处理,使用的是单通道图像预测
# 处理成模型需要的格式,[B,C,H,W]
input = image.reshape(1, 1, image.shape[0], image.shape[1]).astype(np.float32)
session = onnxruntime.InferenceSession(onnx_path)
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: input})
print(outputs[0].shape) # (1, 1, 512, 512)
pred = np.array(outputs[0])[0][0]
pred[pred > 0] = 255
pred[pred <= 0] = 0
cv2.imwrite("pred.bmp", pred)
3. C++测试
使用的是Qt进行的测试。
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
using namespace cv::dnn;
using namespace std;
int main()
{
int h = 512; int w = 512;
String modelFile = "F:/project/Unet_model/best_model.onnx";
String imageFile = "F:/project/Unet_model/data/test/1.bmp";
Mat img = imread(imageFile); // 读取测试图片
cvtColor(img, img, cv::COLOR_BGR2GRAY); // 灰度化
resize(img, img, Size(h, w));
Mat inputBolb = blobFromImage(img); // 转换输入图像的格式[B,C,H,W]
dnn::Net net = cv::dnn::readNetFromONNX(modelFile); //读取网络和参数
net.setInput(inputBolb);
Mat output = net.forward(); // 输出4D mat
int B = inputBolb.size[0];
int C = inputBolb.size[1];
int H = inputBolb.size[2];
int W = inputBolb.size[3];
Mat predMat = Mat::zeros(h, w, CV_32F);
for(int i = 0; i < B; i++){
for(int j = 0; j < C; j++){
for(int m = 0; m < H; m++){
for(int n =0; n < W; n++){
float pred = output.ptr<float>(i,j,m)[n];
if(pred > 0){
predMat.at<float>(m,n) = 255;
}
else{
predMat.at<float>(m,n) = 0;
}
}
}
}
}
cv::imwrite("F:/QtProject/pred.bmp", predMat);
}