TensorFlow深度学习4

该博客介绍了如何运用TensorFlow构建全连接网络进行MNIST手写体数字识别,通过L2正则化、滑动平均模型和学习率衰减等优化手段,使模型在识别任务中达到98%的高准确率。
摘要由CSDN通过智能技术生成

1全连接网络的经典设计,基于MNIST手写体数字识别数据集。

1.1 下面这个网络用到了很多的优化方法,如L2正则化、滑动平均模型以及学习率衰减,这样可以提高模型准确率。准确率保持在98%的水平。

import tensorflow as tf

#数据的读入
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/home/jiangziyang/MNIST_data",one_hot=True)

batch_size = 100                #设置每一轮训练的batch大小
learning_rate = 0.8             #学习率
learning_rate_decay = 0.999     #学习率的衰减
max_steps = 30000               #最大训练步数

#定义存储训练轮数的变量,在使用Tensorflow训练神经网络时,
#一般会将代表训练轮数的变量通过trainable参数设置为不可训练的
training_step = tf.Variable(0,trainable=False)

#定义得到隐藏层和输出层的前向传播计算方式,激活函数使用relu()
def hidden_layer(input_tensor,weights1,biases1,weights2,biases2,layer_name):
    layer1=tf.nn.relu(tf.matmul(input_tensor,weights1)+biases1)
    return tf.matmul(layer1,weights2)+biases2

x = tf.placeholder(tf.float32,[None,784],name="x-input")   #INPUT_NODE=784
y_ = tf.placeholder(tf.float32,[None,10],name="y-output")   #OUT_PUT=10
#生成隐藏层参数,其中weights包含784x500=392000个参数
weights1=tf.Variable(tf.truncated_normal([784,500],stddev=0.1))
biases1 = tf.Variable(tf.constant(0.1,shape=[500]))
#生成输出层参数,其中weights2包含500x10=5000个参数
weights2 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
biases2 = tf.Variable(tf.constant(0.1, shape=[10]))

#计算经过神经网络前向传播后得到的y的值,这里没有使用滑动平均
y = hidden_layer(x,weights1,biases1,weights2,biases2,'y')


#初始化一个滑动平均类,衰减率为0.99
#为了使模型在训练前期可以更新地更快,这里提供了num_updates参数
#并设置为当前网络的训练轮数
averages_class = tf.train.ExponentialMovingAverage(0.99,training_step)
#定义一个更新变量滑动平均值的操作需要向滑动平均类的apply()函数提供一个参数列表
#train_variables()函数返回集合图上Graph.TRAINABLE_VARIABLES中的元素,
#这个集合的元素就是所有没有指定trainable_variables=False的参数
averages_op = averages_class.apply(tf.trainable_variables())
#再次计算经过神经网络前向传播后得到的y的值,这里使用了滑动平均,但要牢记滑动平均值只是一个影子变量
average_y 
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值