Tensorflow MNIST原始图片TFRecord方式识别---3. 从TFRecord文件中,提取图片数据,进行训练

Tensorflow MNIST原始图片TFRecord方式识别---3. 从TFRecord文件中,提取图片数据,进行训练


这篇文章是续《Tensorflow MNIST原始图片TFRecord方式识别》的“提取图片数据,进行训练”章节。

1. 提取图片数据

提取图片数据,分两步:
1)从tfrecord读取原始数据;
2)随机抽取这些数据以batch_size大小批处理。每次批处理的训练,最好存在多个样本一起,可以加快训练速度、且使梯度下降更加平稳。如果每次批处理训练,重样样本占50%以上,训练模型基本作废,识别准确率很差。

1.1 从tfrecord读取原始数据

def read_data_sets(tfrecord_file):
    """从TFRecord文件中提取训练图像数据"""
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([tfrecord_file])
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })

    image = tf.decode_raw(features['image_raw'], tf.uint8)
    # inference.IMAGE_SIZE是在inference py文件中定义的常量 28
    image = tf.reshape(image, [inference.IMAGE_SIZE, inference.IMAGE_SIZE, 1])
    image = tf.cast(image, tf.float32) * (1. / 255)

    label = tf.cast(features['label'], tf.int32)
    return image, label

1.2 以batch_size大小批处理随机抽取

def input(tfrecord_file, batch_size):
    """从TFRecord文件中批处理随机读取batchsize数据"""
    if not os.path.isfile(tfrecord_file):
        # HAND_SHUFFLE_DATA_PATH 是手动打乱0-9数据的文件路径。
        # gen_hand_shuffle_tfrecord 生成tfrecord文件的函数。在文章《Tensorflow MNIST原始图片TFRecord方式识别--
        # -1. 原始图片生成TFRecord文件》有定义。
        gen_hand_shuffle_tfrecord(HAND_SHUFFLE_DATA_PATH, tfrecord_file)
    image, label = read_data_sets(tfrecord_file)
    # 开始随机读取batch_size个数图像数据
    images, labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=1,
        capacity=200,
        min_after_dequeue=100)
    '''
    # 这样运行打印就可以.测试代码
    with tf.Session() as sess:
        # 启动多线程处理输入数据。
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(10):
            label_= sess.run(label)
            print("前十个label i = %d, label = %d" %(i, label_))
    '''

    return images, labels

2. 训练模型

import tensorflow as tf
import inference
import os
import cv2
import numpy as np
import input_data

#神经网络相关参数
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 6000
MOVING_AVERAGE_DECAY = 0.99
ROOT_PATH = "../src/LeNet_Mnist_Origin/"

#训练过程
def train():
    # input_data.HAND_SHUFFLE_DATA_TFRECORD这个常量在input_data文件中定义,值为
    # '../src/LeNet_Mnist_Origin/tfrecord_data/hand_shuffle_mnist.tfrecord'
    xs, ys = input_data.input(input_data.HAND_SHUFFLE_DATA_TFRECORD, BATCH_SIZE)
    # inference.IMAGE_SIZE 在inference文件中定义的常量图像大小 28
    xs_reshpae = tf.reshape(xs,[
        BATCH_SIZE,
        inference.IMAGE_SIZE,
        inference.IMAGE_SIZE,
        1])
    y_ = tf.one_hot(ys, 10, on_value=1.0, off_value=0.0)
    ckpt_dir = os.path.join(ROOT_PATH, 'ckpt_dir')
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    # L2正则化
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    # inference.inference(...)接口在inference文件中定义。就是前面讲的模型。 
    y = inference.inference(xs_reshpae,False,regularizer)
    global_step = tf.Variable(0, trainable=False)
    saver = tf.train.Saver()

    # 定义损失函数、学习率、滑动平均操作以及训练过程。
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # sparse softmax交叉熵
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    # 交叉熵矩阵的所有元素均值
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # get_collection返回一个列表,这个列表是所有这个集合中的元素. 通过add_n将列表的元素值相加.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
    # 指数衰减法的学习率。初始学习率LEARNING_RATE_BASE=0.01
    # 迭代当前轮 global_step, 
    # 衰减速度mnist.train.num_examples/BATCH_SIZE=minist样本数/100
    # 衰减系数LEARNING_RATE_DECAY=0.99
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        60000 / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')
        
    # 初始化TensorFlow持久化类。
    saver = tf.train.Saver()
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path) # restore all variables
        else:
            tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            yy = sess.run(y_)
            if i % 50 == 0:
                print(yy)
            _, loss_value, step = sess.run([train_op, loss, global_step])
            if i % 10 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
        saver.save(sess, ckpt_dir + "/hand_shuffle_model.ckpt")
        coord.request_stop()
        coord.join(threads)

#主程序入口
def main(argv=None):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    # if not os.path.isfile(input_data.SMALL_TFRECORD_FILE):
    #     print("tfrecord file not exist, need to gen it")
    #     num_examples = input_data.gen_tfrecord(input_data.ORIGIN_DATA_PATH, input_data.SMALL_TFRECORD_FILE)
    #     print("examples num  ", num_examples)

    # image, label = input_data.input(input_data.SMALL_TFRECORD_FILE, BATCH_SIZE)
    # with tf.Session() as sess:
    #     coord = tf.train.Coordinator()
    #     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    #     for i in range(500):
    #         x, y = sess.run([image, label])
    #         print(y)
    #         print(x.shape)
    #         print("\n")
    #     coord.request_stop()
    #     coord.join(threads)

    train()

if __name__ == '__main__':
    main()

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值