Tensorflow笔记(4):全连接网络基础

5.1 MNIST 数据集

MNIST数据集: 
6W张28*28的0~9手写数字图片和标签,用于训练 
1W张28*28的0~9手写数字图片和标签,用于测试 
每张图片的784个像素点(28*28)组成长度为784的一维数组,作为输入特征 
图片的标签(0~9)以一维数组的形式给出,每个元素表示对应分类出现的概率 
TF 提供 input_data 模块自动读取数据集

from tensorflow.examples.tutorials.minist import input_data
minist = input_data.read_data_set('./data/', one_hot=True)

返回各子集样本数

mnist.train.num_examples    #返回训练集样本数
mnist.validation.num_examples #返回验证集样本数
mnist.test.num_examples #返回测试集样本数

返回标签和数据

mnist.train.labels[0]   #返回标签
mnist.train.images[0]   #返回数据

取一小撮数据,准备喂入神经网络

BATCH_SIZE = 200    #定义batch size
xs, ys = mnist.train.next_batch(BATCH_SIZE)

一些常用的函数

tf.get_collection("")       #从集合中取全部变量,生成一个列表
tf.add_n([])                    #列表内对应元素相加
tf.cast(x, dtype)           #把x转换为dtype类型
tf.argmax(x, axis)      #返回最大值所在索引号 如: tf.argmax([1,0,0], 1) 返回0
import os
os.path.join("home", "name")    #f返回home/name
字符串.split()             #按照指定的拆分符对字符串切片,返回分割后的列表
#如:'./model/mnist_model-1001'.split('-')[-1] 返回1001
with tf.Graph().as_default() as g:      #其内定义的节点在计算图g中

保存模型

saver = tf.train.Saver()            #实例化saver对象
with tf.Session() as sess:          #在with结构for循环中一定轮数时保存模型到当前会话
    for i in ranges(STEPS):         #拼接成./MODEL_SAVE_PATH/MODEL_NAME-global_step
        if i  % 轮数 == 0:
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step = global_step)

加载模型

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(存储路径)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess,ckpt.model_checkpoint_path)

实例化可还原滑动平均值的saver

ema = tf.train.ExponentialMovingAverage(滑动平均基数)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)

准确率计算方法

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

5.2 模块化搭建神经网络八股

forward.py

def forward(x, regularizer):
    w = 
    b = 
    y = 
    return y
def get_weight(shape, regularizer):
    pass
def get_bias(shape):
    pass

backward.py

def backward(mnist):
    x =
    y_ =
    y =         #复现前向传播,计算出y
    global_step = 
    loss =
    <正则化,指数衰减学习率,滑动平均>
    train_step = 
    实例化Saver
    with tf.Session() as sess:
        初始化
        for i in range(STEPS):
            sess.run(train_step,feed_dict={x:, y_:})
            if i%轮数 ==0:
                print
                saver.save()

损失函数loss含正则化regularization 
backward.py中加入

ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
cem = tf.reduce_mean(ce)
loss = cem+tf.add_n(tf.get_collection('losses'))

forward.py中加入

if regularizer != None:tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))

学习率learning_rate 
backward.py中加入

learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
LEARNING_RATE_STEP,
LEARNING_RATE_DECAY,
staircase = True)

滑动平均ema

ema = tf.train.ExponentialMovingAverage(衰减率MOVING_AVERAGE_DECAY,当前轮数global_step)
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step,ema_op]):
    train_op = tf.no_op(name='train')

test.py

def test(mnist):
    with tf.Graph()as_default()as g:
        x = 
        y_ = 
        y = 
        实例化可还原滑动平均值的saver
        计算正确率
        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(存储路径)    #加载ckpt模型
                if ckpt and ckpt.model_checkpoint_path:         #如果已经有ckpt模型则恢复
                    saver.restore(sess,ckpt.model_checkpoint_path) #恢复会话
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] #恢复轮数
                    accuracy_score = sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels}) #计算准确率
                    print("After %s training steps, test accuracy = %g" % (global_step, accuracy_score))
                else: #如果没有模型
                    print("No checkpoint file found!")  #给出提示
    return

def main():
    mnist = input_data.read_data_sets("./data/", one_hot=True)
    test(mnist)

if __name__=='__main__':
    main()    

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值