Tensorflow笔记__使用mnist数据集并测试自己的手写图片

内容源于曹建老师的tensorflow笔记课程

源码链接:https://github.com/cj0012/AI-Practice-Tensorflow-Notes

测试图片下载:https://github.com/cj0012/AI-Practice-Tensorflow-Notes/blob/master/num.zip

主要包含四个文件,主要是mnist_forward.py,mnist_backward.py,mnist_test.py,mnist_app.py

定义前向传播过程 mnist_forward.py:

import tensorflow as tf

INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER_NODE = 500

# 定义神经网络参数,传入两个参数,一个是shape一个是正则化参数大小
def get_weight(shape,regularizer):
# tf.truncated_normal截断的正态分布函数,超过标准差的重新生成
w = tf.Variable(tf.truncated_normal(shape,stddev=0.1))
if regularizer != None:
# 将正则化结果存入losses
tf.add_to_collection(“losses”,tf.contrib.layers.l2_regularizer(regularizer)(w))
return w

# 定义偏置b,传入shape参数
def get_bias(shape):
# 初始化为0
b = tf.Variable(tf.zeros(shape))
return b

# 定义前向传播过程,两个参数,一个是输入数据,一个是正则化参数
def forward(x,regularizer):
# w1的维度就是[输入神经元大小,第一层隐含层神经元大小]
w1 = get_weight([INPUT_NODE,LAYER_NODE],regularizer)
# 偏置b参数,w的后一个参数相同
b1 = get_bias(LAYER_NODE)
# 激活函数
y1 = tf.nn.relu(tf.matmul(x,w1)+b1)

w2 = get_weight([LAYER_NODE,OUTPUT_NODE],regularizer)
b2 = get_bias(OUTPUT_NODE)
y = tf.matmul(y1,w2)+b2

return y


定义反向传播过程 mnist_backward.py:

#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

BATCH_SIZE = 200
#学习率衰减的原始值
LEARNING_RATE_BASE = 0.1
# 学习率衰减率
LEARNING_RATE_DECAY = 0.99
# 正则化参数
REGULARIZER = 0.0001
# 训练轮数
STEPS = 50000
#这个使用滑动平均的衰减率
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = “./model/”
MODEL_NAME = “mnist_model”

def backward(mnist):
#一共有多少个特征,784,一列
x = tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])
y_ = tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])
# 给前向传播传入参数x和正则化参数计算出y的值
y = mnist_forward.forward(x,REGULARIZER)
# 初始化global—step,它会随着训练轮数增加
global_step = tf.Variable(0,trainable=False)

# softmax和交叉商一起运算的函数,logits传入是x*w,也就是y
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))
cem = tf.reduce_mean(ce)
loss = cem + tf.add_n(tf.get_collection(“losses”))

learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples/BATCH_SIZE,
LEARNING_RATE_DECAY,
staircase = True)

train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step = global_step)

<span style="color:#808080;"># </span><span style="color:#808080;font-family:'AR PL UKai CN';">滑动平均处理</span><span style="color:#808080;">,</span><span style="color:#808080;font-family:'AR PL UKai CN';">可以提高泛华能力

ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)
ema_op = ema.apply(tf.trainable_variables())
# train_step和滑动平均计算ema_op放在同一个节点
with tf.control_dependencies([train_step,ema_op]):
train_op = tf.no_op(name=“train”)

saver = tf.train.Saver()

with tf.Session() as sess:

  init_op = tf.global_variables_initializer()
  sess.run(init_op)

  <span style="color:#cc7832;">for </span>i <span style="color:#cc7832;">in </span><span style="color:#8888c6;">range</span>(STEPS):
     <span style="color:#808080;"># mnist.train.next_batch()</span><span style="color:#808080;font-family:'AR PL UKai CN';">函数包含一个参数</span><span style="color:#808080;">BATCH_SIZE,</span><span style="color:#808080;font-family:'AR PL UKai CN';">表示随机从训练集中抽取</span><span style="color:#808080;">BATCH_SIZE</span><span style="color:#808080;font-family:'AR PL UKai CN';">个样本输入到神经网络

# next_batch函数返回的是image的像素和标签label
xs,ys = mnist.train.next_batch(BATCH_SIZE)
# ,表示后面不使用这个变量
,loss_value,step = sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})

     <span style="color:#cc7832;">if </span>i % <span style="color:#6897bb;">1000 </span>== <span style="color:#6897bb;">0</span>:
        <span style="color:#8888c6;">print</span>(<span style="color:#6a8759;">"Ater {} training step(s),loss on training batch is {} "</span>.format(step<span style="color:#cc7832;">,</span>loss_value))
        saver.save(sess<span style="color:#cc7832;">,</span>os.path.join(MODEL_SAVE_PATH<span style="color:#cc7832;">,</span>MODEL_NAME)<span style="color:#cc7832;">,</span><span style="color:#aa4926;">global_step</span>=global_step)

def main():

mnist = input_data.read_data_sets("./data",one_hot = True)
backward(mnist)

if name == main:
main()

定义测试部分 mnist_test.py:

#coding:utf-8
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward

TEST_INTERVAL_SECS = 5

def test(mnist):
with tf.Graph().as_default() as g:
# 占位符,第一个参数是tf.float32数据类型,第二个参数是shape,shape[0]=None表示输入维度任意,shpe[1]表示输入数据特征数
x = tf.placeholder(tf.float32,shape = [None,mnist_forward.INPUT_NODE])
y_ = tf.placeholder(tf.float32,shape = [None,mnist_forward.OUTPUT_NODE])
“”“注意这里没有传入正则化参数,需要明确的是,在测试的时候不要正则化,不要dropout”""
y = mnist_forward.forward(x,None)

    <span style="color:#808080;"># </span><span style="color:#808080;font-family:'AR PL UKai CN';">实例化可还原的滑动平均模型

ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)

    <span style="color:#808080;"># y</span><span style="color:#808080;font-family:'AR PL UKai CN';">计算的过程:</span><span style="color:#808080;">x</span><span style="color:#808080;font-family:'AR PL UKai CN';">是</span><span style="color:#808080;">mnist.test.images</span><span style="color:#808080;font-family:'AR PL UKai CN';">是</span><span style="color:#808080;"><a href="https://www.baidu.com/s?wd=10000&amp;tn=24004469_oem_dg&amp;rsv_dl=gh_pl_sl_csd" target="_blank">10000</a>×784</span><span style="color:#808080;font-family:'AR PL UKai CN';">的,最后输出的</span><span style="color:#808080;">y</span><span style="color:#808080;font-family:'AR PL UKai CN';">仕</span><span style="color:#808080;">10000×10</span><span style="color:#808080;font-family:'AR PL UKai CN';">的,</span><span style="color:#808080;">y_:mnist.test.labels</span><span style="color:#808080;font-family:'AR PL UKai CN';">也是</span><span style="color:#808080;">10000×10</span><span style="color:#808080;font-family:'AR PL UKai CN';">的

correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
# tf.cast可以师兄数据类型的转换,tf.equal返回的只有TrueFalse
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

    <span style="color:#cc7832;">while True</span>:
        <span style="color:#cc7832;">with </span>tf.Session() <span style="color:#cc7832;">as </span>sess:
            <span style="color:#808080;"># </span><span style="color:#808080;font-family:'AR PL UKai CN';">加载训练好的模型

ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
# 恢复模型到当前会话
saver.restore(sess,ckpt.model_checkpoint_path)
# 恢复轮数
global_step = ckpt.model_checkpoint_path.split("/")[-1].split("-")[-1]
# 计算准确率
accuracy_score = sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
print("After {} training step(s),test accuracy is: {} ".format(global_step,accuracy_score))
else:
print(“No chekpoint file found”)
print(sess.run(y,feed_dict={x:mnist.test.images}))
return
time.sleep(TEST_INTERVAL_SECS)

def main():

mnist = input_data.read_data_sets(<span style="color:#6a8759;">"./data"</span><span style="color:#cc7832;">,</span><span style="color:#aa4926;">one_hot</span>=<span style="color:#cc7832;">True</span>)
test(mnist)

if name== main:
main()

定义使用手写图片部分mnist_app.py:

import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_forward
import mnist_backward

# 定义加载使用模型进行预测的函数
def restore_model(testPicArr):

<span style="color:#cc7832;">with </span>tf.Graph().as_default() <span style="color:#cc7832;">as </span>tg:
    
    x = tf.placeholder(tf.float32<span style="color:#cc7832;">,</span>[<span style="color:#cc7832;">None,</span>mnist_forward.INPUT_NODE])
    y = mnist_forward.forward(x<span style="color:#cc7832;">,None</span>)
    preValue = tf.argmax(y<span style="color:#cc7832;">,</span><span style="color:#6897bb;">1</span>)
    <span style="color:#808080;"># </span><span style="color:#808080;font-family:'AR PL UKai CN';">加载滑动平均模型

variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)

    <span style="color:#cc7832;">with </span>tf.Session() <span style="color:#cc7832;">as </span>sess:
        
        ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
        <span style="color:#cc7832;">if </span>ckpt <span style="color:#cc7832;">and </span>ckpt.model_checkpoint_path:
            <span style="color:#808080;"># </span><span style="color:#808080;font-family:'AR PL UKai CN';">恢复当前会话</span><span style="color:#808080;">,</span><span style="color:#808080;font-family:'AR PL UKai CN';">将</span><span style="color:#808080;">ckpt</span><span style="color:#808080;font-family:'AR PL UKai CN';">中的值赋值给</span><span style="color:#808080;">w</span><span style="color:#808080;font-family:'AR PL UKai CN';">和</span><span style="color:#808080;">b

saver.restore(sess,ckpt.model_checkpoint_path)
# 执行图计算
preValue = sess.run(preValue,feed_dict={x:testPicArr})
return preValue
else:
print(“No checkpoint file found”)
return -1
# 图片预处理函数
def pre_pic(picName):
# 先打开传入的原始图片
img = Image.open(picName)
# 使用消除锯齿的方法resize图片
reIm = img.resize((28,28),Image.ANTIALIAS)
# 变成灰度图,转换成矩阵
im_arr = np.array(reIm.convert(“L”))
threshold = 50#对图像进行二值化处理,设置合理的阈值,可以过滤掉噪声,让他只有纯白色的点和纯黑色点
for i in range(28):
for j in range(28):
im_arr[i][j] = 255-im_arr[i][j]
if (im_arr[i][j]<threshold):
im_arr[i][j] = 0
else:
im_arr[i][j] = 255
# 将图像矩阵拉成1784列,并将值变成浮点型(像素要求的仕0-1的浮点型输入)
nm_arr = im_arr.reshape([1,784])
nm_arr = nm_arr.astype(np.float32)
img_ready = np.multiply(nm_arr,1.0/255.0)

<span style="color:#cc7832;">return </span>img_ready

def application():
# input函数可以从控制台接受数字
testNum = int(input(“input the number of test images:”))
# 使用循环来历遍需要测试的图片才结束
for i in range(testNum):
# input可以实现从控制台接收字符格式,图片存储路径
testPic = input(“the path of test picture:”)
# 将图片路径传入图像预处理函数中
testPicArr = pre_pic(testPic)
# 将处理后的结果输入到预测函数最后返回预测结果
preValue = restore_model(testPicArr)
print(“The prediction number is :”,preValue)

def main():
application()

if name == main:
main()

output:


The end.

展开阅读全文

没有更多推荐了,返回首页