Tensorflow MNIST原始图片TFRecord方式识别---3. 从TFRecord文件中,提取图片数据,进行训练
这篇文章是续《Tensorflow MNIST原始图片TFRecord方式识别》的“提取图片数据,进行训练”章节。
1. 提取图片数据
提取图片数据,分两步:
1)从tfrecord读取原始数据;
2)随机抽取这些数据,按 batch_size 大小进行批处理。每个训练批次最好同时包含多个不同样本,这样可以加快训练速度,并使梯度下降更加平稳。如果每个批次中重复样本占 50% 以上,训练出的模型基本作废,识别准确率会很差。
1.1 从tfrecord读取原始数据
def read_data_sets(tfrecord_file):
    """Read one (image, label) training example from a TFRecord file.

    Args:
        tfrecord_file: path to the TFRecord file produced earlier in this
            series.

    Returns:
        image: float32 tensor of shape [IMAGE_SIZE, IMAGE_SIZE, 1] with
            pixel values scaled into [0, 1].
        label: int32 scalar tensor holding the digit class.
    """
    reader = tf.TFRecordReader()
    # File-name queue; its QueueRunner must be started in the session
    # (tf.train.start_queue_runners) before any example can be read.
    filename_queue = tf.train.string_input_producer([tfrecord_file])
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    # The raw bytes were written as uint8 pixels; decode then reshape.
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    # inference.IMAGE_SIZE is a constant (28) defined in the inference module.
    image = tf.reshape(image, [inference.IMAGE_SIZE, inference.IMAGE_SIZE, 1])
    # Normalize uint8 pixel values to floats in [0, 1].
    image = tf.cast(image, tf.float32) * (1. / 255)
    label = tf.cast(features['label'], tf.int32)
    return image, label
1.2 以batch_size大小批处理随机抽取
def input(tfrecord_file, batch_size):
    """Return a shuffled batch of (images, labels) from a TFRecord file.

    Generates the TFRecord file first if it does not exist yet.

    Args:
        tfrecord_file: path to the TFRecord file to read (and create if
            missing).
        batch_size: number of examples per batch.

    Returns:
        images: float32 tensor [batch_size, IMAGE_SIZE, IMAGE_SIZE, 1].
        labels: int32 tensor [batch_size].
    """
    if not os.path.isfile(tfrecord_file):
        # HAND_SHUFFLE_DATA_PATH is the directory holding the manually
        # shuffled 0-9 digit images; gen_hand_shuffle_tfrecord (defined in
        # part 1 of this series) builds the TFRecord file from it.
        gen_hand_shuffle_tfrecord(HAND_SHUFFLE_DATA_PATH, tfrecord_file)
    image, label = read_data_sets(tfrecord_file)
    # Shuffle queue: keeps at least min_after_dequeue buffered examples so
    # consecutive batches are well mixed; too small a buffer produces batches
    # dominated by duplicate samples, which ruins training.
    images, labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=1,
        capacity=200,
        min_after_dequeue=100)
    return images, labels
2. 训练模型
import tensorflow as tf
import inference
import os
import cv2
import numpy as np
import input_data
# Hyper-parameters for the network and the training loop.
BATCH_SIZE = 100  # examples per training batch
LEARNING_RATE_BASE = 0.01  # initial learning rate
LEARNING_RATE_DECAY = 0.99  # exponential learning-rate decay factor
REGULARIZATION_RATE = 0.0001  # L2 regularization strength
TRAINING_STEPS = 6000  # total number of training iterations
MOVING_AVERAGE_DECAY = 0.99  # decay for the moving averages of the weights
ROOT_PATH = "../src/LeNet_Mnist_Origin/"  # base directory for checkpoints
# Training procedure.
def train():
    """Build the LeNet graph and train it on batches read from TFRecord."""
    # input_data.HAND_SHUFFLE_DATA_TFRECORD is defined in input_data as
    # '../src/LeNet_Mnist_Origin/tfrecord_data/hand_shuffle_mnist.tfrecord'.
    xs, ys = input_data.input(input_data.HAND_SHUFFLE_DATA_TFRECORD, BATCH_SIZE)
    # inference.IMAGE_SIZE is the image side length (28).
    xs_reshape = tf.reshape(xs, [
        BATCH_SIZE,
        inference.IMAGE_SIZE,
        inference.IMAGE_SIZE,
        1])
    # One-hot labels; the sparse loss below recovers the integer label via
    # argmax, so this also keeps y_ available for debugging.
    y_ = tf.one_hot(ys, 10, on_value=1.0, off_value=0.0)
    ckpt_dir = os.path.join(ROOT_PATH, 'ckpt_dir')
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    # L2 regularizer passed to the model; inference adds the weight penalties
    # to the 'losses' collection.
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    # Forward pass. inference.inference(...) is the model described earlier.
    # NOTE(review): the second argument False presumably disables
    # training-only behavior (e.g. dropout) — confirm against the inference
    # module.
    y = inference.inference(xs_reshape, False, regularizer)
    global_step = tf.Variable(0, trainable=False)
    # Loss, learning rate, moving averages and the training op.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # Sparse softmax cross-entropy; tf.argmax(y_, 1) maps the one-hot labels
    # back to integer class indices.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    # Mean cross-entropy over the batch.
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # Total loss = data loss + the L2 terms collected in 'losses';
    # get_collection returns the list, add_n sums its elements.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    # Exponentially decayed learning rate:
    #   base rate LEARNING_RATE_BASE = 0.01,
    #   current step global_step,
    #   decay period 60000 / BATCH_SIZE (one MNIST epoch of batches),
    #   decay factor LEARNING_RATE_DECAY = 0.99.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        60000 / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # Group the gradient update and the moving-average update into one op.
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')
    # Persistence helper. Created only here, after every variable (including
    # the moving-average shadow variables) exists; the earlier duplicate
    # Saver in the original was dead code and missed the shadow variables.
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Start the queue runners that feed the input pipeline.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)  # restore all variables
        else:
            tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            # NOTE(review): the original also ran sess.run(y_) here to print
            # labels for debugging; every such call dequeues an extra batch
            # from the shuffle queue, wasting data and desynchronizing the
            # pipeline, so it was removed.
            _, loss_value, step = sess.run([train_op, loss, global_step])
            if i % 10 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, ckpt_dir + "/hand_shuffle_model.ckpt")
        coord.request_stop()
        coord.join(threads)
# Program entry point.
def main(argv=None):
    """Silence TensorFlow info/warning logs and run the training loop.

    Args:
        argv: unused; kept for compatibility with tf.app.run-style entry
            points.
    """
    # '2' suppresses INFO and WARNING messages from the TF C++ backend.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    # (A commented-out input-pipeline smoke test lived here in the original;
    # removed as dead code.)
    train()


if __name__ == '__main__':
    main()