c++测试pytorch训练的模型

6 篇文章 0 订阅
2 篇文章 0 订阅

c++测试pytorch训练的模型

pytorch训练的pth模型转换成onnx模型,使用c++测试onnx模型。

1. 模型.pth转.onnx

化繁为简

写那么多废话不如简单明了

import torch
from Unet import Unet


def pth2onnx(input, pth_path, onnx_path):
    model = Unet()  # 导入自己的网络模型
    model.load_state_dict(torch.load(pth_path))  # 初始化权重
    model.eval()

    torch.onnx.export(model, input, onnx_path, verbose=True)


if __name__ == '__main__':
    pth_path = r'./best_model.pth'  # 训练的pth路径
    onnx_path = r'./best_model.onnx'  # 保存onnx的路径
    model_input = torch.randn(1, 1, 512, 512)  # 模型输入[B,C,H,W]
    pth2onnx(input=model_input, pth_path=pth_path, onnx_path=onnx_path)

(可选)2. 测试.onnx模型转换是否正确

如果第3步模型测试不正确,才需要用第2步来检查是不是模型转换出了问题。

import cv2
import onnxruntime
import numpy as np

onnx_path = './best_model.onnx'  # 上一步生成的onnx模型
image_path = './data/test/1.bmp'  # 测试图像

image = cv2.imread(image_path)  # 读取图像
image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_LINEAR)  # resize成相应尺寸
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # 灰度处理,使用的是单通道图像预测

# 处理成模型需要的格式,[B,C,H,W]
input = image.reshape(1, 1, image.shape[0], image.shape[1]).astype(np.float32)

session = onnxruntime.InferenceSession(onnx_path)
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: input})
print(outputs[0].shape)  # (1, 1, 512, 512)

pred = np.array(outputs[0])[0][0]
pred[pred > 0] = 255
pred[pred <= 0] = 0
cv2.imwrite("pred.bmp", pred)

3. C++测试

使用的是Qt进行的测试。

#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
using namespace cv;
using namespace cv::dnn;
using namespace std;


int main()
{
    int h = 512; int w = 512;
    String modelFile = "F:/project/Unet_model/best_model.onnx";
    String imageFile = "F:/project/Unet_model/data/test/1.bmp";

    Mat img = imread(imageFile); // 读取测试图片
    cvtColor(img, img, cv::COLOR_BGR2GRAY);  // 灰度化
    resize(img, img, Size(h, w));

    Mat inputBolb = blobFromImage(img);  // 转换输入图像的格式[B,C,H,W]

    dnn::Net net = cv::dnn::readNetFromONNX(modelFile); //读取网络和参数
    net.setInput(inputBolb);
    Mat output = net.forward();  // 输出4D mat

    int B = inputBolb.size[0];
    int C = inputBolb.size[1];
    int H = inputBolb.size[2];
    int W = inputBolb.size[3];

    Mat predMat = Mat::zeros(h, w, CV_32F);

    for(int i = 0; i < B; i++){
        for(int j = 0; j < C; j++){
            for(int m = 0; m < H; m++){
                for(int n =0; n < W; n++){

                    float pred = output.ptr<float>(i,j,m)[n];

                    if(pred > 0){
                        predMat.at<float>(m,n) = 255;
                    }
                    else{
                        predMat.at<float>(m,n) = 0;
                    }

                }
            }
        }
    }

    cv::imwrite("F:/QtProject/pred.bmp", predMat);
}
  • 2
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 17
    评论
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小吕同学吖

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值