tensorflow2 搭建LeNet5训练MINST手写数字数据集并用c++ opencv4.5.5 DNN加载模型预测结果

一、LeNet5网络介绍

LeNet5 这个网络包含了深度学习的基本模块:卷积层,池化层,全链接层。是其他深度学习模型的基础。LeNet-5共有7层,不包含输入,每层都包含可训练参数;每个层有多个Feature Map,每个FeatureMap通过一种卷积滤波器提取输入的一种特征,然后每个FeatureMap有多个神经元。
在这里插入图片描述

二、环境搭建

本人环境配置如下:

pycharm2021
vs2022
Anaconda3
tensorflow=2.3
opencv=4.5.5

前几个安装相对轻松,直接上官网安装即可,tensorflow使用pip命令安装,c++ opencv相对较为麻烦,可以参考本人以前的安装方法:c++ opencv 学习笔记(一) Visual Studio 2019 + OpenCV4.5.5 配置详解

三、网络搭建以及训练

3.1、加载数据集

tensorflow内置了MINST数据集,从tensorflow中导入即可

import tensorflow as tf
mnist = tf.keras.datasets.mnist
train, test = mnist.load_data()

将数据按照batch提供给网络模型

import numpy as np

class MNISTData:
    def __init__(self, data, need_shuffle, batch_size=128):
         """
        :param datas: 数据集,格式为 data,label
        :param shuffle: 是否随机打乱数据 True or False
        :param batch_size: 一批数据大小
        """
        self._data = data[0]
        self._labels = data[1]
        self.num_examples = self._data.shape[0]
        self._need_shuffle = need_shuffle
        self._indicator = 0
        self._batch_size = batch_size
        if self._need_shuffle:
            self._shuffle_data()

    def __iter__(self):
        return self

    def _shuffle_data(self):
        p = np.random.permutation(self.num_examples)
        self._data = self._data[p]
        self._labels = self._labels[p]

    def next_batch(self):
        end_indicator = self._indicator + self._batch_size
        if end_indicator > self.num_examples:
            if self._need_shuffle:
                self._shuffle_data()
                self._indicator = 0
                end_indicator = self._batch_size
            else:
                self._indicator = 0
                end_indicator = self._batch_size
        if end_indicator > self.num_examples:
            raise StopIteration
        batch_data = self._data[self._indicator: end_indicator] / 255.0 # 归一化
        batch_labels = self._labels[self._indicator: end_indicator]
        self._indicator = end_indicator

        return batch_data, batch_labels

    def __next__(self):
        return self.next_batch()
        
train_dataset = dataset.MNISTData(train, True)
test_dateset = dataset.MNISTData(test, False)

查看数据集

def display(train_images, train_labels):
    plt.figure(figsize=(10,10))
    for i in range(25):
        plt.subplot(5,5,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images[i], cmap=plt.cm.binary)
        plt.xlabel(train_labels[i])
    plt.show()
    
for data in train_dataset:
    display(*data)

在这里插入图片描述

3.2、网络搭建

使用tensorflow中的keras搭建网络结构,激活函数使用Mish
在这里插入图片描述

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
from tensorflow.keras.utils import get_custom_objects


class Mish(Activation):
    def __init__(self, activate, **kwargs):
        super(Mish, self).__init__(activate, **kwargs)
        self.__name__ = "Mish"


def mish(inputs):
    return inputs * tf.math.tanh(tf.math.softplus(inputs))


def LeNet5(input_shape=[32, 32, 3]):
	get_custom_objects().update({'Mish': Mish(mish)})
	#输入层
    inputs = Input(shape=input_shape)
	#第一个卷积-池化层
    conv1 = Conv2D(6, 5, activation="relu", padding='same')(inputs)
    pool1 = MaxPooling2D((2, 2))(conv1)
    #第二个卷积-池化层
    conv2 = Conv2D(16, 5, activation="relu", padding='same')(pool1)
    pool2 = MaxPooling2D((2, 2))(conv2)
    #第三个卷积层
    conv2 = Conv2D(120, 5, activation="relu", padding='same')(pool2)
    fc = Flatten()(conv2)
	#全连接层
    fc1 = Dense(120, activation="relu")(fc)
    #输出层
    fc2 = Dense(10, activation="softmax")(fc1)
    model = Model(inputs, fc2)

    return model
model = LeNet5(input_shape=[28, 28, 1])

在这里插入图片描述

定义损失函数以及优化器

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)

保存模型

model_filepath = 'model/'
checkpoint_filepath = model_filepath + 'tmp/'
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    save_weights_only=True,
    monitor='accuracy',
    mode='max'
)

3.3、模型训练

开始训练(俗称炼丹)

# 是否使用GPU
use_gpu = True
tf.debugging.set_log_device_placement(True)
if use_gpu:
    gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(device=gpu, enable=True)
            tf.print(gpu)

    else:
        os.environ["CUDA_VISIBLE_DEVICE"] = "-1"

else:
    os.environ["CUDA_VISIBLE_DEVICE"] = "-1"

# TensorBoard可视化工具
log_path = 'logging/'
logging = tf.keras.callbacks.TensorBoard(log_dir=log_path)
model_filepath = 'model/'
checkpoint_filepath = model_filepath + 'tmp/'
history = model.fit(
    train_dataset,
    epochs=10,
    steps_per_epoch=train_dataset.num_examples // BATCH_SIZE + 1,
    validation_data=test_dateset,
    validation_steps=test_dateset.num_examples // BATCH_SIZE + 1,
    callbacks=[cp_callback, logging ]
)

model.load_weights(checkpoint_filepath)
model.save(model_filepath + 'model')

在这里插入图片描述

可视化训练过程
TensorBoard是一个可视化工具,它可以用来展示网络图、张量的指标变化、张量的分布情况等。进入logging文件夹的上一层文件夹,在DOS窗口运行命令:

tensorboard --logdir=./logging

在浏览器输入网址:http://localhost:6006,或者输入上图提示的网址,即可查看生成图。

在这里插入图片描述
在这里插入图片描述

3.4、模型固化

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

def export_frozen_graph(model, name, input_size) :
	f = tf.function(lambda x: model(x))
	f = f.get_concrete_function(x=tf.TensorSpec(shape=[None, input_size[0], input_size[1], input_size[2]], dtype=tf.float32))
	f2 = convert_variables_to_constants_v2(f)
	graph_def = f2.graph.as_graph_def()

	# Export frozen graph
	with tf.io.gfile.GFile(name, 'wb') as f:
		f.write(graph_def.SerializeToString())
		
export_frozen_graph(model, model_filepath + 'frozen_graph.pb', (input_size, input_size, 1))

四、c++ opencv加载模型

#include <opencv2/opencv.hpp>
#include <iostream>
#include <vector>

using namespace std;

//多分类问题用这个函数判断类别,二分类的话不用也行
std::vector<int> Argmax(cv::Mat x)
{
	std::vector<int> res;
	for (int i = 0; i < x.rows; i++)
	{
		int maxIdx = 0;
		float maxNum = 0.0;
		for (int j = 0; j < x.cols; j++)
		{
			float tmp = x.at<float>(i, j);
			if (tmp > maxNum)
			{
				maxIdx = j; //更新最优值序号
				maxNum = tmp; //更新最优值
			}
		}
		res.push_back(maxIdx); //最优预测值的序号
	}
	return res;
}

int main()
{
	//cv加载模型
	cv::dnn::Net net = cv::dnn::readNetFromTensorflow("frozen_graph.pb");
	//加载图片
	cv::Mat src = cv::imread("8.jpg", cv::IMREAD_COLOR);
	cv::Mat img = src;
	cv::cvtColor(img, img, cv::COLOR_BGR2GRAY);
	//调整图片大小
	cv::resize(img, img, cv::Size(28, 28));
	//归一化 0-1之间
	img.convertTo(img, CV_32FC1, 1.f / 255.f, -1.f);
	//格式转化
	cv::dnn::blobFromImage(img, img, 1.0, cv::Size(), cv::Scalar(), false, false, CV_32F);
	//将数据喂给网络
	net.setInput(img);
	//前向传播,得到传播结果
	cv::Mat pred = net.forward();
	//输出结果
	vector<int> res = Argmax(pred);
	
	//输出标签
	stringstream ss;
	string str;
	ss << "label:" << res[0];
	ss >> str;
	//放大图片便于观察
	cv::resize(src, src, cv::Size(280, 280));
	cv::putText(src, str, cv::Size(0, 40), cv::FONT_HERSHEY_COMPLEX, 1, cv::Scalar(0, 255, 0), 1);
	cv::imshow("", src);
	cv::waitKey();
}

结果如下:
在这里插入图片描述
有需要的可以下载完整项目链接进行测试:

GitHub:https://github.com/small-guang/LeNet5
CSDN:https://download.csdn.net/download/qq_45723275/77992089
其他项目链接:tensorflow2.3 搭建 vgg16训练cifar10数据集

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值