一.唠嗑
绊脚石1:
大家应该都知道jupyter这个神奇的软件,很多程序员喜欢用这种编程风格编写代码,当然在Pycharm里面是可以嵌入jupyter的,所以当我成功嵌入后,编写好代码发现无法import同文件夹下的 .ipynb文件,我度娘了一下,都说是必须得把需要导入的文件转成一个模块才可以被调用,这个...好麻烦,我果断放弃,改成 .py文件。
绊脚石2:
当我成功编写好 .py文件后,运行文件,报错信息是无法下载mnist数据,这个我度娘了一下,发现是程序请求的资源在网站上无法下载了。然后我果断下载训练集到本地,又开始报错了....这次的理由是不然是tensorflow里面的一个模块
tensorflow.examples.tutorials.mnist
不导入这个模块我就无法加载mnist训练集啊,更不用说接下来的神经网络的训练和测试。
二.项目介绍
具体代码:
前向传播文件:
# 前向传播描述神经网络结构
# coding:utf-8
import tensorflow as tf
#有输入层,隐藏层(1层),输出层
INPUT_NODE = 784 #输入每一场图片28*28,总共784个像素
OUTPUT_NODE = 10 #输出10个数字:0-9
LAYER1_NODE = 500 #隐藏层设定500个节点
def get_weight(shape, regularizer):
# 正则化:在损失函数中给每个参数w加上权重,引入模型复杂度指标,抑制模型噪声,减小过拟合
w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
# 加入正则化,L2正则化
if regularizer != None: tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
return w
def get_bias(shape):
b = tf.Variable(tf.zeros(shape))
return b
def forward(x, regularizer):
w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer) #784行500列
b1 = get_bias([LAYER1_NODE]) #500列
y1 = tf.nn.relu(tf.matmul(x, w1) + b1) #先矩阵相乘,然后过激活函数relu()线性输出,N行500列
w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer) #500行10列
b2 = get_bias([OUTPUT_NODE]) #10列
y = tf.matmul(y1, w2) + b2 #输出层不过激活函数 #N行10列
return y
反向传播文件:
# 反向传播描述模型优化参数的方法
# coding:utf-8
from tensorflow.examples.tutorials.mnist import input_data #用来导入mnist数据集
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
from .mnist_forward import * #导入前向传播:神经网络结构
import os # 把参数字符串按照路径命名规则拼接
import tensorflow as tf
BATCH_SIZE = 200
LEARNING_RATE_BASE = 0.1 #初始学习率
LEARNING_RATE_DECAY = 0.99 #学习率的衰减率
REGULARIZER = 0.0001 #正则化权重
STEPS = 50000 #训练50000轮
MOVING_AVERAGE_DECAY = 0.99 #滑动平均衰减率
MODEL_SAVE_PATH = "./model/" #模型保存路径
MODEL_NAME = "mnist_model" #模型保存的文件名
def backward(mnist): #导入数据
x = tf.placeholder(tf.float32, [None, INPUT_NODE]) #N行784列
y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE]) #N行10列
y = forward(x, REGULARIZER) #调用mnist_forward文件的forward函数,获取y(N行10列),与y_(正确答案)最比较求损失值
# 运行了几轮BATCH_SIZE的计数器,初值为0,设为不被训练
global_step = tf.Variable(0, trainable=False)
# 损失函数为交叉熵
# 在Tensorflow中,一般让模型的过输出经过sofemax函数,以获得输出分类的概率分布
# 再与标准答案对比,求出交叉熵,得到损失函数
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cem = tf.reduce_mean(ce)
loss = cem + tf.add_n(tf.get_collection('losses')) #交叉熵误差+每一个正则化w的损失,含正则化
# 定义指数下降学习率,学习率每经过一轮BATCH_SIZE,更新一次
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples / BATCH_SIZE,
LEARNING_RATE_DECAY,
staircase=True)
#反向传播方法为梯度下降,损失值 = 交叉熵误差+每一个正则化w的损失
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
# 实例化滑动平均类计算公式,滑动平均衰减率为0.99,当前轮数global_step
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
# ema.apply后的括号里是更新列表,每次运行sess.run(ema_op)时,对更新列表中的参数求滑动平均值
# 在实际应用中会使用tf.trainable_variables()自动将所有待训练的参数汇总为列表
# eme_op = ema.apply([w1])
ema_op = ema.apply(tf.trainable_variables())
#该函数实现将滑动平均和训练过程同步运行
#查看模型中参数的平均值,可以用 ema.average()函数
with tf.control_dependencies([train_step, ema_op]):
train_op = tf.no_op(name='train')
saver = tf.train.Saver() #实例化saver
with tf.Session() as sess:
#初始化模型参数
init_op = tf.global_variables_initializer()
sess.run(init_op)
for i in range(STEPS):
#BATCH_SIZE的数据集和标签喂入神经网络,并训练train_op
xs, ys = mnist.train.next_batch(BATCH_SIZE)
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 1000 == 0:
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
#在反向传播过程中,一般会间隔一定轮数保存一次神经网络模型, 并产生三个文件
#(保存当前图结构的a.meta文件 、 保存当前参数名的x.index文件 、 保存当前参数的a.data文件)
#tf.train.Saver()用来实例化saver对象
# 下述代码表示,神经网络每循环规定的轮数,将神经网络模型中所有的参数等信息保存到指定的路径中
# 并在存放网络模型的文件夹名称中注明保存模型时的训练轮数
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
def main():
# 加载mnist数据集
# 在read_data_sets()函数中有两个参数,第一个参数表示数据集存放路径,
# 第二个参数表示数据集的存取形式。当第二个参数为 Ture 时,表示以独热码形式存取数据集
mnist = input_data.read_data_sets("./data/", one_hot=True)
#返回训练集train样本数
print("train data size:", mnist.train.mun_examples)
#返回验证集validation样本数
print("validation data size:", mnist.validation.mun_examples)
#返回测试集test样本数
print("test data size:", mnist.test.mun_examples)
backward(mnist)
if __name__ == '__main__':
main()
验证模型准确率的测试文件:
# coding:utf-8
import time #为了延迟
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #用来导入mnist数据集
from 全链接网络基础 import mnist_forward
from 全链接网络基础 import mnist_backward
TEST_INTERVAL_SECS = 5 #程序循环时间5秒
def test(mnist):
#tf.Graph().as_default()表示将当前图设置成为默认图,并返回一个上下文管理器
#表示将在 Graph()内定义的节点加入到计算图 g 中
with tf.Graph().as_default() as g:
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE]) #N行784列
y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE]) #N行10列,正确答案
y = mnist_forward.forward(x, None) #N行10列,预测答案
#加载模型中参数的滑动平均值
#在保存模型时,若模型中采用滑动平均,则参数的滑动平均值会保存在相应文件中
#通过实例化saver对象,实现参数滑动平均值的加载
ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)
#神经网络模型准确率评估方法
#得到神经网络模型在本组数据上的准确率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
while True:
#在测试网络效果时,需要将训练好的神经网络模型加载
with tf.Session() as sess:
#在with 结构中进行加载保存的神经网络模型
# 若 ckpt 和保存的模型在指定路径中存在,则将保存的神经网络模型加载到当前会话中
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path) #重新加载模型
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] #恢复global_step值
accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}) #执行准确率计算
print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score)) #有模型和参数
else:
print('No checkpoint file found') #无模型和参数
return
time.sleep(TEST_INTERVAL_SECS)
def main():
#加载mnist数据集
#在read_data_sets()函数中有两个参数,第一个参数表示数据集存放路径,
#第二个参数表示数据集的存取形式。当第二个参数为 Ture 时,表示以独热码形式存取数据集
mnist = input_data.read_data_sets("./data/", one_hot=True)
test(mnist)
if __name__ == '__main__':
main()