C++ OpenCV 调用 TensorFlow 模型(以 MNIST 为例)

7 篇文章 0 订阅
2 篇文章 0 订阅

C++ OpenCV 调用 TensorFlow 模型(以 MNIST 为例)

0. 环境

Windows 10 + Visual Studio 2019 + OpenCV 3.4.6 + Keras(TensorFlow 后端;Python 代码使用了 tf.reset_default_graph、K.get_session 等 TensorFlow 1.x 接口)

1. 准备数据

从 http://yann.lecun.com/exdb/mnist/ 下载以下 4 个文件(保存到 data/ 目录):
- data/t10k-labels-idx1-ubyte.gz
- data/t10k-images-idx3-ubyte.gz
- data/train-labels-idx1-ubyte.gz
- data/train-images-idx3-ubyte.gz

用解压软件直接解压,得到以下文件(C++ 程序从这些路径读取数据):
- data/t10k-labels-idx1-ubyte
- data/t10k-images-idx3-ubyte
- data/train-labels-idx1-ubyte
- data/train-images-idx3-ubyte

2. 读取mnist文件

参考《使用C++和OpenCV读取MNIST文件》[https://blog.csdn.net/sheng_ai/article/details/23267039],其代码可以直接使用;本文用到的完整代码见第 5 节的 MNIST.h 与 MNIST.cpp。

3. 在 Python 中用 Keras 训练模型并导出 TensorFlow pb 模型

#%%
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Conv2D, MaxPool2D, ReLU, Input, Softmax, Reshape
from keras import backend as K
import tensorflow as tf 

#%%
tf.reset_default_graph()
#%%

def net(input_size, optimizer):
    """Build and compile a small fully-convolutional digit classifier.

    Architecture: two (5x5 conv -> max-pool 2 -> ReLU) stages with 20 and 50
    filters, then a 5x5 conv with 10 filters, a max-pool of size 7, a reshape
    to 10 values and a softmax. With a 28x28 input the two 2x2 pools leave a
    7x7 map, which the final 7x7 pool reduces to 1x1x10.

    The input layer is named "x" and the softmax layer "output"; the export
    step freezes the graph at the op "output/Softmax".

    Args:
        input_size: shape of one sample, e.g. [28, 28, 1].
        optimizer: a Keras optimizer instance passed to ``compile``.

    Returns:
        A compiled ``keras.models.Model`` (categorical cross-entropy loss,
        accuracy metric).
    """

    def conv_stage(tensor, filters, pool_size, with_relu=True):
        # Conv -> pool (-> ReLU); layers are created in the same order as before.
        tensor = Conv2D(filters, 5, padding='same', kernel_initializer='he_normal')(tensor)
        tensor = MaxPool2D(pool_size)(tensor)
        if with_relu:
            tensor = ReLU()(tensor)
        return tensor

    inputs = Input(input_size, name="x")
    features = conv_stage(inputs, 20, 2)
    features = conv_stage(features, 50, 2)
    scores = conv_stage(features, 10, 7, with_relu=False)

    flat = Reshape([10])(scores)
    probabilities = Softmax(name="output")(flat)

    model = Model(inputs=inputs, outputs=probabilities)
    model.compile(optimizer=optimizer,
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model


# Load MNIST, reshape every image to (28, 28, 1) and scale pixels by 1/255.
(train_x, train_y), (test_x, test_y) = mnist.load_data()

train_x = train_x.reshape(train_x.shape[0], 28, 28, 1) / 255
test_x = test_x.reshape(test_x.shape[0], 28, 28, 1) / 255

# One-hot encode the digit labels for the categorical_crossentropy loss.
train_y = np_utils.to_categorical(train_y, num_classes=10)
test_y = np_utils.to_categorical(test_y, num_classes=10)

rmsprop = keras.optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)
model = net([28, 28, 1], rmsprop)

#%%
# summary() prints the table itself and returns None; wrapping it in print()
# used to emit a stray "None" line.
model.summary()

#%%
print("Training --------------")
model.fit(train_x, train_y, epochs=4, batch_size=32)

print("Testing --------------")
loss, accuracy = model.evaluate(test_x, test_y)

print("test loss: ", loss)
print("test accuracy: ", accuracy)

# List every node in the graph (handy for finding the output node name):
# tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
# print(tensor_name_list)

#%%
# Export a frozen .pb model: convert the trained variables into constants so
# the graph file is self-contained. "output/Softmax" is the op of the Softmax
# layer named "output" in net().
sess = K.get_session()
frozen_graph = tf.compat.v1.graph_util.convert_variables_to_constants(
    sess,
    sess.graph_def,
    output_node_names=["output/Softmax"]
)

# Writes log/1.2/mnist.pb. The C++ program opens "mnist.pb" relative to its
# working directory, so copy the file there (or adjust the path) before running it.
tf.train.write_graph(frozen_graph, "log/1.2", "mnist.pb", as_text=False)

输出结果为:

Epoch 1/4
60000/60000 [==============================] - 13s 216us/step - loss: 0.2229 - acc: 0.9322
Epoch 2/4
60000/60000 [==============================] - 10s 167us/step - loss: 0.0754 - acc: 0.9767
Epoch 3/4
60000/60000 [==============================] - 10s 167us/step - loss: 0.0550 - acc: 0.9828
Epoch 4/4
60000/60000 [==============================] - 10s 168us/step - loss: 0.0450 - acc: 0.9859
Testing --------------
10000/10000 [==============================] - 1s 62us/step
test loss:  0.05498976424507564
test accuracy:  0.9838

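4. 在 C++ 中用 OpenCV 读取 pb 模型,并在测试集上验证

注意:第 3 步的 Python 脚本把模型保存为 log/1.2/mnist.pb,而下面的程序从当前工作目录读取 mnist.pb,运行前需要把该文件复制过去(或修改 pbfile 路径)。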

#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>

#include "MNIST.h"

std::vector<int> Argmax(cv::Mat x)
{
	std::vector<int> res;
	for (int i = 0; i < x.rows; i++)
	{
		int maxIdx = 0;
		float maxNum = 0.0;
		for (int j = 0; j < x.cols; j++)
		{
			float tmp = x.at<float>(i, j);
			if (tmp > maxNum)
			{
				maxIdx = j;
				maxNum = tmp;

			}
		}
		res.push_back(maxIdx);
	}

	return res;
}

/**
 * @brief Accuracy  Classify MNIST images with a frozen TensorFlow graph and
 *                  return the fraction of predictions that match the labels.
 * @param x       CV_8UC1 Mat with one 28*28 image per row (as returned by ReadImages).
 * @param y       Single-channel labels, one per image. Any depth is accepted;
 *                the labels are converted to 32-bit ints internally.
 * @param pbfile  Path of the frozen .pb graph.
 * @return Number of correct predictions divided by the number of images.
 * @throws cv::Exception if the inputs are malformed or the network cannot be
 *         loaded or run.
 */
float Accuracy(cv::Mat x, cv::Mat y, std::string pbfile)
{
	CV_Assert(!x.empty() && x.type() == CV_8UC1 && x.cols == 28 * 28);
	CV_Assert(y.channels() == 1 && y.total() == static_cast<size_t>(x.rows));

	cv::dnn::Net net = cv::dnn::readNetFromTensorflow(pbfile);
	CV_Assert(!net.empty());

	// The 3-D header below aliases the pixel buffer, so it must be contiguous.
	cv::Mat pixels = x.isContinuous() ? x : x.clone();

	// blobFromImages needs data with at least 3 dims: viewed as
	// (number of images, 28, 28) it is split along the first axis into one
	// 28x28 image per sample.
	int size[] = { pixels.rows, 28, 28 };
	cv::Mat imgs(3, size, CV_8UC1, pixels.data);

	cv::Mat blob = cv::dnn::blobFromImages(imgs, 1.0 / 255.0, cv::Size(28, 28), cv::Scalar(), false, false);
	net.setInput(blob);
	cv::Mat pred = net.forward();
	CV_Assert(pred.rows == x.rows);

	// Convert once so raw uint8 labels from ReadLabels() are read correctly;
	// the freshly allocated result is contiguous, so flat indexing is safe.
	cv::Mat labels;
	y.convertTo(labels, CV_32S);
	const int* labelPtr = labels.ptr<int>(0);

	std::vector<int> res = Argmax(pred);

	int correct = 0;
	for (size_t i = 0; i < res.size(); i++)
	{
		if (labelPtr[i] == res[i])
		{
			correct++;
		}
	}
	return static_cast<float>(correct) / x.rows;
}

int main()
{
	std::string testLableFile = "data/t10k-labels-idx1-ubyte";
	std::string testImageFile = "data/t10k-images-idx3-ubyte";

	std::string trainLableFile = "data/train-labels-idx1-ubyte";
	std::string trainImageFile = "data/train-images-idx3-ubyte";


	cv::Mat trainY = ReadLabels(trainLableFile);
	cv::Mat testY = ReadLabels(testLableFile);

	cv::Mat trainX = ReadImages(trainImageFile);	
	cv::Mat testX = ReadImages(testImageFile);
	

	testY.convertTo(testY, CV_32SC1);

	std::string pbfile = "mnist.pb";
	//testX.convertTo(testX, CV_32FC1, 1.0/255.0, 0);
	float acc = Accuracy(testX, testY, pbfile);
	std::cout << acc;
	return 0;
}

控制台输出如下(准确率 0.9838,与 Python 端 Keras 得到的测试准确率一致):

[ INFO:0] Initialize OpenCL runtime...
0.9838

5. MNIST.h 与 MNIST.cpp

#pragma once
#ifndef MNIST_H
#define MNIST_H

// Helpers that read the raw MNIST idx files (images and labels) into cv::Mat.

#include <iostream>
#include <fstream>
#include <opencv2/opencv.hpp>

/// Header at the start of an MNIST image file: four 4-byte fields, each a
/// big-endian integer kept as raw bytes (decode with ConvertCharArrayToInt).
struct MNISTImageFileHeader
{
	unsigned char MagicNumber[4];
	unsigned char NumberOfImages[4];
	unsigned char NumberOfRows[4];
	unsigned char NumberOfColums[4];
};

/// Header at the start of an MNIST label file: magic number and label count,
/// both 4-byte big-endian integers kept as raw bytes.
struct MNISTLabelFileHeader
{
	unsigned char MagicNumber[4];
	unsigned char NumberOfLabels[4];
};

/// Magic numbers that identify the two MNIST file types.
const int MAGICNUMBEROFIMAGE = 2051;
const int MAGICNUMBEROFLABEL = 2049;

// ---- Header decoding --------------------------------------------------------

/// Decode LengthOfArray big-endian bytes into an int.
int ConvertCharArrayToInt(unsigned char* array, int LengthOfArray);

/// True when MagicNumber decodes to MAGICNUMBEROFIMAGE.
bool IsImageDataFile(unsigned char* MagicNumber, int LengthOfArray);

/// True when MagicNumber decodes to MAGICNUMBEROFLABEL.
bool IsLabelDataFile(unsigned char* MagicNumber, int LengthOfArray);

// ---- Payload readers (stream already positioned after the header) -----------

cv::Mat ReadData(std::fstream& DataFile, int NumberOfData, int DataSizeInBytes);
cv::Mat ReadImageData(std::fstream& ImageDataFile, int NumberOfImages);
cv::Mat ReadLabelData(std::fstream& LabelDataFile, int NumberOfLabel);

// ---- Whole-file entry points -------------------------------------------------

cv::Mat ReadImages(std::string& FileName);
cv::Mat ReadLabels(std::string& FileName);

#endif // MNIST_H


#include "MNIST.h"

/**
 * @brief ConvertCharArrayToInt  Decode a big-endian integer stored as raw bytes
 *                               (most significant byte first).
 * @param array          The bytes to decode (MNIST header fields are 4 bytes).
 * @param LengthOfArray  Number of bytes to decode.
 * @return The decoded value, 0 when LengthOfArray is 0, or -1 when
 *         LengthOfArray is negative.
 */
int ConvertCharArrayToInt(unsigned char* array, int LengthOfArray)
{
	if (LengthOfArray < 0)
	{
		return -1;
	}

	// Accumulate in an unsigned type: left-shifting a signed int past its
	// range is undefined behaviour. Starting from 0 also means an empty array
	// is never dereferenced.
	unsigned int result = 0;
	for (int i = 0; i < LengthOfArray; i++)
	{
		result = (result << 8) | array[i];
	}
	return static_cast<int>(result);
}



/**
 * @brief IsImageDataFile  Check whether MagicNumber decodes to MAGICNUMBEROFIMAGE.
 * @param MagicNumber      Big-endian bytes of the magic number to check.
 * @param LengthOfArray    Number of bytes in MagicNumber.
 * @return true if the magic number identifies an MNIST image file; false otherwise.
 */
bool IsImageDataFile(unsigned char* MagicNumber, int LengthOfArray)
{
	const int decoded = ConvertCharArrayToInt(MagicNumber, LengthOfArray);
	return decoded == MAGICNUMBEROFIMAGE;
}




/**
 * @brief IsLabelDataFile  Check whether MagicNumber decodes to MAGICNUMBEROFLABEL.
 * @param MagicNumber      Big-endian bytes of the magic number to check.
 * @param LengthOfArray    Number of bytes in MagicNumber.
 * @return true if the magic number identifies an MNIST label file; false otherwise.
 *
 * @author sheng
 * @version 1.0.0
 * @date  2014-04-08
 *
 * @history     <author>      <date>      <version>      <description>
 *               sheng      2014-04-08      1.0.0      build the function
 */
bool IsLabelDataFile(unsigned char* MagicNumber, int LengthOfArray)
{
	return MAGICNUMBEROFLABEL == ConvertCharArrayToInt(MagicNumber, LengthOfArray);
}




/**
 * @brief ReadData  Read NumberOfData records of DataSizeInBytes bytes each from
 *                  an already-open file, then close the file.
 * @param DataFile         Open binary stream positioned at the first record.
 * @param NumberOfData     Number of records (rows of the result).
 * @param DataSizeInBytes  Size of every record in bytes (columns of the result).
 * @return A CV_8UC1 Mat with one record per row. Returns an empty Mat in three
 *         cases: the file is not open, either size is not positive (e.g. a
 *         corrupt header), or the file ends before all bytes could be read.
 *
 * @author sheng
 * @version 1.0.0
 * @date  2014-04-08
 *
 * @history     <author>      <date>      <version>      <description>
 *               sheng      2014-04-08      1.0.0      build the function
 */

cv::Mat ReadData(std::fstream& DataFile, int NumberOfData, int DataSizeInBytes)
{
	cv::Mat DataMat;

	if (!DataFile.is_open())
	{
		return DataMat;
	}

	if (NumberOfData > 0 && DataSizeInBytes > 0)
	{
		// Compute the total in a 64-bit-capable type so a large header cannot overflow int.
		const std::streamsize AllDataSizeInBytes =
			static_cast<std::streamsize>(NumberOfData) * DataSizeInBytes;

		// Read straight into the (contiguous) Mat buffer: no temporary array, no clone.
		DataMat.create(NumberOfData, DataSizeInBytes, CV_8UC1);
		DataFile.read(reinterpret_cast<char*>(DataMat.data), AllDataSizeInBytes);

		// A short read would otherwise leave uninitialised bytes in the result.
		if (DataFile.gcount() != AllDataSizeInBytes)
		{
			DataMat.release();
		}
	}

	DataFile.close();
	return DataMat;
}

/**
 * @brief ReadImageData  Read the pixel block of an MNIST image file whose header
 *                       has already been consumed.
 * @param ImageDataFile  Open stream positioned at the first pixel.
 * @param NumberOfImages Number of images to read.
 * @return The result of ReadData with one 28*28-byte row per image (ReadData
 *         closes the stream and documents the empty-Mat cases).
 *
 * @author sheng
 * @version 1.0.0
 * @date  2014-04-08
 *
 * @history     <author>      <date>      <version>      <description>
 *               sheng      2014-04-08      1.0.0      build the function
 */
cv::Mat ReadImageData(std::fstream& ImageDataFile, int NumberOfImages)
{
	const int kImageRows = 28;
	const int kImageCols = 28;
	return ReadData(ImageDataFile, NumberOfImages, kImageRows * kImageCols);
}



/**
 * @brief ReadLabelData  Read the label bytes of an MNIST label file whose header
 *                       has already been consumed.
 * @param LabelDataFile  Open stream positioned at the first label.
 * @param NumberOfLabel  Number of labels to read.
 * @return The result of ReadData with one 1-byte row per label (ReadData closes
 *         the stream and documents the empty-Mat cases).
 *
 * @author sheng
 * @version 1.0.0
 * @date  2014-04-08
 *
 * @history     <author>      <date>      <version>      <description>
 *               sheng      2014-04-08      1.0.0      build the function
 */
cv::Mat ReadLabelData(std::fstream& LabelDataFile, int NumberOfLabel)
{
	// Every label is stored as a single byte.
	return ReadData(LabelDataFile, NumberOfLabel, 1);
}




/**
 * @brief ReadImages  Read all images from an MNIST image file (train or t10k).
 * @param FileName  Path of the idx3-ubyte file.
 * @return The result of ReadImageData: one row of 28*28 bytes per image.
 *         Returns an empty Mat before reading any pixels in four cases: the
 *         file cannot be opened, the 16-byte header cannot be read, the magic
 *         number is wrong, or the images are not 28x28.
 *
 * @author sheng
 * @version 1.0.0
 * @date  2014-04-08
 *
 * @history     <author>      <date>      <version>      <description>
 *               sheng      2014-04-08      1.0.0      build the function
 */
cv::Mat ReadImages(std::string& FileName)
{
	std::fstream File(FileName.c_str(), std::ios_base::in | std::ios_base::binary);

	if (!File.is_open())
	{
		return cv::Mat();
	}

	// Zero-initialise and check the read so a short file never exposes indeterminate bytes.
	MNISTImageFileHeader FileHeader = {};
	if (!File.read(reinterpret_cast<char*>(&FileHeader), sizeof(FileHeader)))
	{
		return cv::Mat();
	}

	if (!IsImageDataFile(FileHeader.MagicNumber, 4))
	{
		return cv::Mat();
	}

	// ReadImageData() assumes 28x28 images; reject other sizes instead of
	// silently slicing the pixel stream into wrong rows.
	const int NumberOfRows = ConvertCharArrayToInt(FileHeader.NumberOfRows, 4);
	const int NumberOfColums = ConvertCharArrayToInt(FileHeader.NumberOfColums, 4);
	if (NumberOfRows != 28 || NumberOfColums != 28)
	{
		return cv::Mat();
	}

	const int NumberOfImage = ConvertCharArrayToInt(FileHeader.NumberOfImages, 4);

	return ReadImageData(File, NumberOfImage);
}




/**
 * @brief ReadLabels  Read all labels from an MNIST label file (train or t10k).
 * @param FileName  Path of the idx1-ubyte file.
 * @return The result of ReadLabelData: one row (a single byte) per label.
 *         Returns an empty Mat before reading any labels in three cases: the
 *         file cannot be opened, the 8-byte header cannot be read, or the
 *         magic number is wrong.
 *
 * @author sheng
 * @version 1.0.0
 * @date  2014-04-08
 *
 * @history     <author>      <date>      <version>      <description>
 *               sheng      2014-04-08      1.0.0      build the function
 */
cv::Mat ReadLabels(std::string& FileName)
{
	std::fstream File(FileName.c_str(), std::ios_base::in | std::ios_base::binary);

	if (!File.is_open())
	{
		return cv::Mat();
	}

	// Zero-initialise and check the read so a short file never exposes indeterminate bytes.
	MNISTLabelFileHeader FileHeader = {};
	if (!File.read(reinterpret_cast<char*>(&FileHeader), sizeof(FileHeader)))
	{
		return cv::Mat();
	}

	if (!IsLabelDataFile(FileHeader.MagicNumber, 4))
	{
		return cv::Mat();
	}

	const int NumberOfLabels = ConvertCharArrayToInt(FileHeader.NumberOfLabels, 4);

	return ReadLabelData(File, NumberOfLabels);
}

  • 1
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值