TensorFlow 训练 CNN 分类器

本文详述如何使用TensorFlow构建一个10分类的CNN模型,用于识别带有噪声的验证码图像。通过创建包含6个卷积层和3个全连接层的网络,实现了超过99%的准确率。文章涵盖数据预处理、模型定义、训练过程以及模型测试,展示了一个完整的深度学习训练流程。
摘要由CSDN通过智能技术生成

TensorFlow 训练 CNN 分类器

        前面两篇文章分别介绍了怎么安装 TensorFlow 和怎么使用 TensorFlow 自带的目标检测 API。从这边文章开始介绍怎么使用 TensorFlow 来搭建自己的网络,怎么保存训练好的模型,怎么导入保存的模型来推断,怎么使用更方便 tf.contrib.slim 来训练神经网络等。

        本文通过一个简单的分类任务来说明怎么使用 TensorFlow 训练 CNN 模型。

一、简单的 10 分类任务

        现在有一个任务,需要训练一个 10 分类器,区分图一中的图像。这些图像都是通过 Python 的一个自动生成验证码的第三方库 captcha (使用 sudo pip/pip3 install captcha 安装)随机生成的,每张图像都包含 0-9 这 10 个数字中的一个,可以看到图像带有很强的背景噪声。图像的大小为 28 x 28 像素,命名规则为 image序号_类标号.jpg。从图一可以发现,通过肉眼只有第 1 行第 4 张,第 2 行第 4 张,第 5 行第 4 张和最后一行第 4 张图像稍微容易辨认一点。
图1 由 captcha 生成的 0-9 数字图像

        为了能够训练出一个准确率比较高的分类器,需要准备大量的训练数据,使用如下代码生成 50000 张训练图像:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 22 13:43:34 2018

@author: shirhe-lyh
"""

import cv2
import numpy as np

from captcha.image import ImageCaptcha


def generate_captcha(text='1'):
    """Generate a digit image."""
    capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
    image = capt.generate_image(text)
    image = np.array(image, dtype=np.uint8)
    return image


if __name__ == '__main__':
    output_dir = './datasets/images/'
    for i in range(50000):
        label = np.random.randint(0, 10)
        image = generate_captcha(str(label))
        image_name = 'image{}_{}.jpg'.format(i+1, label)
        output_path = output_dir + image_name
        cv2.imwrite(output_path, image)

这些图像保存在文件夹 ./datasets/images/ 内,实际执行上述代码时请手动创建该文件夹,或者指定其它文件夹。

二、创建简单的 CNN 模型

        众所周知,机器学习/深度学习的算法都有相同的模式(或流程),一般包括数据预处理、预测、后处理和计算损失这几个过程。所以为了一般化使用,可以定义一个抽象类,以后的模型定义都继承自该类。前面已经注意到了,通过 captcha 生成的 50000 张训练图像使用肉眼已经较难区分,因此需要搭建一个比较深的网络,下面创建的网络包括 6 个卷基层和 3 个全连接层,准确率已经可以达到 99% 以上了。

        TensorFlow 建立神经网络通过底层的 tf.nn 模块实现,如卷积操作通过函数 tf.nn.conv2d来实现,池化操作通过函数 tf.nn.max_pool 来实现,而全连接层没有封装的现成函数,需要通过矩阵乘法 tf.matmul 和加法 tf.add 自己实现(以后会介绍创建神经网络更方便的模块 tf.contrib.slim)。

        话不多说,直接上代码(命名为 model.py):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 16:54:02 2018

@author: shirhe-lyh
"""

import tensorflow as tf

from abc import ABCMeta
from abc import abstractmethod


class BaseModel(object):
    """Abstract base class for any model."""
    __metaclass__ = ABCMeta

    def __init__(self, num_classes):
        """Constructor.

        Args:
            num_classes: Number of classes.
        """
        self._num_classes = num_classes

    @property
    def num_classes(self):
        return self._num_classes

    @abstractmethod
    def preprocess(self, inputs):
        """Input preprocessing. To be override by implementations.

        Args:
            inputs: A float32 tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.

        Returns:
            preprocessed_inputs: A float32 tensor with shape [batch_size, 
                height, widht, num_channels] representing a batch of images.
        """
        pass

    @abstractmethod
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.

        Outputs of this function can be passed to loss or postprocess functions.

        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.

        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        pass

    @abstractmethod
    def postprocess(self, prediction_dict, *
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值