TensorFlow 训练 CNN 分类器
前面两篇文章分别介绍了怎么安装 TensorFlow 和怎么使用 TensorFlow 自带的目标检测 API。从这边文章开始介绍怎么使用 TensorFlow 来搭建自己的网络,怎么保存训练好的模型,怎么导入保存的模型来推断,怎么使用更方便 tf.contrib.slim 来训练神经网络等。
本文通过一个简单的分类任务来说明怎么使用 TensorFlow 训练 CNN 模型。
一、简单的 10 分类任务
现在有一个任务,需要训练一个 10 分类器,区分图一中的图像。这些图像都是通过 Python 的一个自动生成验证码的第三方库 captcha (使用 sudo pip/pip3 install captcha 安装)随机生成的,每张图像都包含 0-9 这 10 个数字中的一个,可以看到图像带有很强的背景噪声。图像的大小为 28 x 28 像素,命名规则为 image序号_类标号.jpg
。从图一可以发现,通过肉眼只有第 1 行第 4 张,第 2 行第 4 张,第 5 行第 4 张和最后一行第 4 张图像稍微容易辨认一点。
为了能够训练出一个准确率比较高的分类器,需要准备大量的训练数据,使用如下代码生成 50000 张训练图像:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 22 13:43:34 2018
@author: shirhe-lyh
"""
import cv2
import numpy as np
from captcha.image import ImageCaptcha
def generate_captcha(text='1'):
"""Generate a digit image."""
capt = ImageCaptcha(width=28, height=28, font_sizes=[24])
image = capt.generate_image(text)
image = np.array(image, dtype=np.uint8)
return image
if __name__ == '__main__':
output_dir = './datasets/images/'
for i in range(50000):
label = np.random.randint(0, 10)
image = generate_captcha(str(label))
image_name = 'image{}_{}.jpg'.format(i+1, label)
output_path = output_dir + image_name
cv2.imwrite(output_path, image)
这些图像保存在文件夹 ./datasets/images/ 内,实际执行上述代码时请手动创建该文件夹,或者指定其它文件夹。
二、创建简单的 CNN 模型
众所周知,机器学习/深度学习的算法都有相同的模式(或流程),一般包括数据预处理、预测、后处理和计算损失这几个过程。所以为了一般化使用,可以定义一个抽象类,以后的模型定义都继承自该类。前面已经注意到了,通过 captcha 生成的 50000 张训练图像使用肉眼已经较难区分,因此需要搭建一个比较深的网络,下面创建的网络包括 6 个卷基层和 3 个全连接层,准确率已经可以达到 99% 以上了。
TensorFlow 建立神经网络通过底层的 tf.nn
模块实现,如卷积操作通过函数 tf.nn.conv2d
来实现,池化操作通过函数 tf.nn.max_pool
来实现,而全连接层没有封装的现成函数,需要通过矩阵乘法 tf.matmul
和加法 tf.add
自己实现(以后会介绍创建神经网络更方便的模块 tf.contrib.slim
)。
话不多说,直接上代码(命名为 model.py):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 16:54:02 2018
@author: shirhe-lyh
"""
import tensorflow as tf
from abc import ABCMeta
from abc import abstractmethod
class BaseModel(object):
"""Abstract base class for any model."""
__metaclass__ = ABCMeta
def __init__(self, num_classes):
"""Constructor.
Args:
num_classes: Number of classes.
"""
self._num_classes = num_classes
@property
def num_classes(self):
return self._num_classes
@abstractmethod
def preprocess(self, inputs):
"""Input preprocessing. To be override by implementations.
Args:
inputs: A float32 tensor with shape [batch_size, height, width,
num_channels] representing a batch of images.
Returns:
preprocessed_inputs: A float32 tensor with shape [batch_size,
height, widht, num_channels] representing a batch of images.
"""
pass
@abstractmethod
def predict(self, preprocessed_inputs):
"""Predict prediction tensors from inputs tensor.
Outputs of this function can be passed to loss or postprocess functions.
Args:
preprocessed_inputs: A float32 tensor with shape [batch_size,
height, width, num_channels] representing a batch of images.
Returns:
prediction_dict: A dictionary holding prediction tensors to be
passed to the Loss or Postprocess functions.
"""
pass
@abstractmethod
def postprocess(self, prediction_dict, *