TensorFlow sample code: building a neural network to recognize the MNIST dataset

I have recently been studying 郑泽宇's book 《TensorFlow实战Google深度学习框架》. The sample code in the book is concise, easy to follow, and logically organized, so I am recording it here.

The sample code below uses the TensorFlow framework to build a two-layer fully connected neural network that recognizes the MNIST handwritten-digit dataset. It applies two optimization techniques: an exponential moving average over the trainable parameters, and exponential decay of the learning rate. Early in training the parameters change by large amounts, so the model moves quickly toward the optimum; as training approaches the optimum, the rate of change drops and the model converges gradually toward the optimal solution.
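Before walking through the full program, here is a minimal standalone sketch of the moving-average mechanism (my own illustration, assuming TensorFlow 1.x; the variable v and the constants are made up for demonstration):

import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32)    # the variable being shadowed
step = tf.Variable(0, trainable=False)  # stands in for the training step counter
ema = tf.train.ExponentialMovingAverage(decay=0.99, num_updates=step)
maintain_avg_op = ema.apply([v])        # creates and updates the shadow variable

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.assign(v, 5))
    sess.run(maintain_avg_op)
    # effective decay = min(0.99, (1 + 0) / (10 + 0)) = 0.1
    # shadow = 0.1 * 0 + 0.9 * 5 = 4.5
    print(sess.run([v, ema.average(v)]))  # [5.0, 4.5]

The num_updates argument is what lets the shadow value track the raw variable closely at the start of training and stabilize later; the full code below wires global_step into this slot.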

#coding:utf-8
"""
    Created by cheng star at 2018/9/2 15:57
    @email : xxcheng0708@163.com
"""

import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

module_path = os.getcwd()
mnist_data = module_path + "/../mnist_data/"

# MNIST dataset parameters
INPUT_NODE = 784    # number of input-layer neurons
OUTPUT_NODE = 10    # number of output-layer neurons

# Network hyperparameters
LAYER1_NODE = 500   # number of hidden-layer neurons
BATCH_SIZE = 100    # number of samples per training batch

LEARNING_RATE_BASE = 0.8        # base learning rate
LEARNING_RATE_DECAY = 0.99      # learning-rate decay rate
REGULARIZATION_RATE = 0.0001    # regularization penalty factor
TRAINING_STEPS = 30000          # total number of training steps
MOVING_AVERAGE_DECAY = 0.99     # moving-average decay rate


def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    """
        Build the forward pass of the two-layer network.
    :param input_tensor: input feature vectors
    :param avg_class: moving-average class, or None to use the raw variables
    :param weights1: hidden-layer weights
    :param biases1: hidden-layer biases
    :param weights2: output-layer weights
    :param biases2: output-layer biases
    :return: the output-layer logits
    """
    # Without the moving-average model
    if avg_class is None:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        return tf.matmul(layer1, weights2) + biases2
    else:  # with the moving-average model
        layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)
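# Note: inference() returns raw logits without a softmax; the softmax is applied
# inside tf.nn.sparse_softmax_cross_entropy_with_logits when the loss is built.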

def train(mnist):
    """
        Train the neural network model.
    :param mnist: the MNIST handwritten-digit dataset object
    :return:
    """
    x = tf.placeholder(shape=[None, INPUT_NODE], dtype=tf.float32, name="x-input")
    y_ = tf.placeholder(shape=[None, OUTPUT_NODE], dtype=tf.float32, name="y-input")

    # Create the hidden-layer parameters
    with tf.variable_scope("layer1"):
        weights1 = tf.get_variable(name="weights1",
                                   shape=[INPUT_NODE, LAYER1_NODE],
                                   initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases1 = tf.get_variable(name="biases1",
                                  shape=[LAYER1_NODE],
                                  initializer=tf.constant_initializer(value=0.1))
    # Create the output-layer parameters
    with tf.variable_scope("output-layer"):
        weights2 = tf.get_variable(name="weights2",
                                   shape=[LAYER1_NODE, OUTPUT_NODE],
                                   initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases2 = tf.get_variable(name="biases2",
                                  shape=[OUTPUT_NODE],
                                  initializer=tf.constant_initializer(value=0.1))

    # weights1 = tf.Variable(initial_value=tf.truncated_normal(shape=[INPUT_NODE, LAYER1_NODE], stddev=0.1))
    # biases1 = tf.Variable(initial_value=tf.constant(value=0.1, shape=[LAYER1_NODE]))
    #
    # weights2 = tf.Variable(initial_value=tf.truncated_normal(shape=[LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    # biases2 = tf.Variable(initial_value=tf.constant(value=0.1, shape=[OUTPUT_NODE]))

    # Build the network without the moving-average model
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # Define global_step to record the number of training steps
    global_step = tf.Variable(name="global_step", initial_value=0, trainable=False)

    # Define the moving-average model
    variable_average = tf.train.ExponentialMovingAverage(decay=MOVING_AVERAGE_DECAY, num_updates=global_step)
    variable_average_op = variable_average.apply(tf.trainable_variables())
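    # apply() keeps a shadow copy of every trainable variable, updated as
    #   shadow = decay * shadow + (1 - decay) * variable,
    # where decay = min(MOVING_AVERAGE_DECAY, (1 + num_updates) / (10 + num_updates)),
    # so the shadow tracks the raw variable closely early in training.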

    # Build the network again, this time reading the moving-average shadow variables
    average_y = inference(x, variable_average, weights1, biases1, weights2, biases2)

    # Build the cross-entropy loss
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    # Average the cross entropy over all samples in the batch
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
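    # sparse_softmax_cross_entropy_with_logits expects integer class indices,
    # so tf.argmax(y_, 1) converts the one-hot labels back to digit indices.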

    # Define the L2 regularization loss
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)

    # Define the total loss
    loss = cross_entropy_mean + regularization
    # Set up exponential decay of the learning rate
    learning_rate = tf.train.exponential_decay(
        learning_rate=LEARNING_RATE_BASE,   # base learning rate
        global_step=global_step,            # current training step
        decay_steps=mnist.train.num_examples / BATCH_SIZE,  # steps needed to see all training data once
        decay_rate=LEARNING_RATE_DECAY      # learning-rate decay rate
    )
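    # The decayed rate is LEARNING_RATE_BASE * LEARNING_RATE_DECAY^(global_step / decay_steps);
    # without staircase=True the exponent is continuous, so the rate shrinks every step.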
    # One training step of gradient descent
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss=loss, global_step=global_step)

    with tf.control_dependencies([train_step, variable_average_op]):
        train_op = tf.no_op(name="train")
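    # Each run of train_op applies both the gradient update and the moving-average
    # update; tf.group(train_step, variable_average_op) would be equivalent.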

    # Compute the model accuracy
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    # correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32))
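    # Accuracy is evaluated on average_y, i.e. with the shadow variables; the
    # commented-out line above would evaluate the raw (non-averaged) network instead.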

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Validation dataset
        validate_feed = {
            x: mnist.validation.images,
            y_: mnist.validation.labels
        }
        # Test dataset
        test_feed = {
            x: mnist.test.images,
            y_: mnist.test.labels
        }

        for i in range(TRAINING_STEPS):
            if i % 1000 == 0:
                validate_acc, loss_value, global_step_value = sess.run([accuracy, loss, global_step], feed_dict=validate_feed)
                print("After {0} rounds of training, the global step is {1}, "
                      "the loss is {2}, the accuracy on the validation dataset is {3}.".format(i, global_step_value, loss_value, validate_acc))

            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        test_acc, loss_value = sess.run([accuracy, loss], feed_dict=test_feed)
        print("After {0} rounds of training, the loss is {1}, the accuracy on the test dataset is {2}.".format(TRAINING_STEPS, loss_value, test_acc))

def main(argv=None):
    mnist = input_data.read_data_sets(train_dir=mnist_data, one_hot=True)
    train(mnist)

if __name__ == "__main__":
    tf.app.run()
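On the first run, read_data_sets downloads the MNIST archives into the mnist_data directory before extracting them; with the TensorFlow 1.x version used here, the tutorial loader also prints the deprecation warnings visible at the top of the output.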

The model output is as follows:

C:\ProgramData\Anaconda3\python.exe E:/程序/python代码/LearningAI/learning_tensorflow/test_mnist_zzy_demo.py
WARNING:tensorflow:From E:/程序/python代码/LearningAI/learning_tensorflow/test_mnist_zzy_demo.py:150: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting E:\程序\python代码\LearningAI\learning_tensorflow/../mnist_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting E:\程序\python代码\LearningAI\learning_tensorflow/../mnist_data/train-labels-idx1-ubyte.gz
Extracting E:\程序\python代码\LearningAI\learning_tensorflow/../mnist_data/t10k-images-idx3-ubyte.gz
Extracting E:\程序\python代码\LearningAI\learning_tensorflow/../mnist_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
After 0 rounds of training, the global step is 0, the loss is 3.164243459701538, the accuracy on the validation dataset is 0.0957999974489212.
After 1000 rounds of training, the global step is 1000, the loss is 0.23361584544181824, the accuracy on the validation dataset is 0.9760000109672546.
After 2000 rounds of training, the global step is 2000, the loss is 0.2088349461555481, the accuracy on the validation dataset is 0.9814000129699707.
After 3000 rounds of training, the global step is 3000, the loss is 0.19048267602920532, the accuracy on the validation dataset is 0.9833999872207642.
After 4000 rounds of training, the global step is 4000, the loss is 0.1719156801700592, the accuracy on the validation dataset is 0.9833999872207642.
After 5000 rounds of training, the global step is 5000, the loss is 0.16165196895599365, the accuracy on the validation dataset is 0.9855999946594238.
After 6000 rounds of training, the global step is 6000, the loss is 0.1522310972213745, the accuracy on the validation dataset is 0.9850000143051147.
After 7000 rounds of training, the global step is 7000, the loss is 0.14527904987335205, the accuracy on the validation dataset is 0.9847999811172485.
After 8000 rounds of training, the global step is 8000, the loss is 0.13154302537441254, the accuracy on the validation dataset is 0.9850000143051147.
After 9000 rounds of training, the global step is 9000, the loss is 0.12546882033348083, the accuracy on the validation dataset is 0.9851999878883362.
After 10000 rounds of training, the global step is 10000, the loss is 0.1196407675743103, the accuracy on the validation dataset is 0.9854000210762024.
After 11000 rounds of training, the global step is 11000, the loss is 0.11399652063846588, the accuracy on the validation dataset is 0.9864000082015991.
After 12000 rounds of training, the global step is 12000, the loss is 0.109804168343544, the accuracy on the validation dataset is 0.9850000143051147.
After 13000 rounds of training, the global step is 13000, the loss is 0.11330102384090424, the accuracy on the validation dataset is 0.9854000210762024.
After 14000 rounds of training, the global step is 14000, the loss is 0.10213477909564972, the accuracy on the validation dataset is 0.9843999743461609.
After 15000 rounds of training, the global step is 15000, the loss is 0.09962807595729828, the accuracy on the validation dataset is 0.9854000210762024.
After 16000 rounds of training, the global step is 16000, the loss is 0.09647612273693085, the accuracy on the validation dataset is 0.9851999878883362.
After 17000 rounds of training, the global step is 17000, the loss is 0.0948617160320282, the accuracy on the validation dataset is 0.9854000210762024.
After 18000 rounds of training, the global step is 18000, the loss is 0.09350297600030899, the accuracy on the validation dataset is 0.98580002784729.
After 19000 rounds of training, the global step is 19000, the loss is 0.09059648215770721, the accuracy on the validation dataset is 0.9855999946594238.
After 20000 rounds of training, the global step is 20000, the loss is 0.08989834785461426, the accuracy on the validation dataset is 0.9851999878883362.
After 21000 rounds of training, the global step is 21000, the loss is 0.08760837465524673, the accuracy on the validation dataset is 0.98580002784729.
After 22000 rounds of training, the global step is 22000, the loss is 0.08716955780982971, the accuracy on the validation dataset is 0.9854000210762024.
After 23000 rounds of training, the global step is 23000, the loss is 0.08485446870326996, the accuracy on the validation dataset is 0.98580002784729.
After 24000 rounds of training, the global step is 24000, the loss is 0.08533652126789093, the accuracy on the validation dataset is 0.9855999946594238.
After 25000 rounds of training, the global step is 25000, the loss is 0.08394122123718262, the accuracy on the validation dataset is 0.9851999878883362.
After 26000 rounds of training, the global step is 26000, the loss is 0.08382612466812134, the accuracy on the validation dataset is 0.9851999878883362.
After 27000 rounds of training, the global step is 27000, the loss is 0.08189772069454193, the accuracy on the validation dataset is 0.9861999750137329.
After 28000 rounds of training, the global step is 28000, the loss is 0.08306366205215454, the accuracy on the validation dataset is 0.98580002784729.
After 29000 rounds of training, the global step is 29000, the loss is 0.08183949440717697, the accuracy on the validation dataset is 0.98580002784729.
After 30000 rounds of training, the loss is 0.08004538714885712, the accuracy on the test dataset is 0.9836000204086304.

Process finished with exit code 0

Reference: 《TensorFlow实战Google深度学习框架》, by 郑泽宇, 顾思宇, et al.
