项目已上传至 GitHub —— best
下载MNIST数据集
以下有两种下载方式,如果链接失效可以搜索网上的资源
下载之后将其放在 mnist/data/ 文件夹下,目录结构如下
mnist/
data/
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
代码重构
为了使代码有更好的可读性和扩展性,需要将之按功能分为不同的模块,并将可重用的代码抽象成库函数
所以可以把以前臃肿的 MNIST 代码分成三个模块
- inference
- train
- eval
具体的文件夹目录如下
mnist/
data/
......
best/
inference.py
train.py
eval.py
完整代码
首先是 inference.py ,这个库函数负责模型训练及测试的前向传播过程
import tensorflow as tf
# 定义神经网络相关参数
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500
# 创建权重变量,并加入正则化损失集合
def get_weight_variable(shape, regularizer):
weights = tf.get_variable(
'weights',
shape,
initializer=tf.truncated_normal_initializer(stddev=0.1))
if regularizer != None:
tf.add_to_collection('losses', regularizer(weights))
return weights
# 前向传播
def inference(input_tensor, regularizer):
# 声明隐藏层的变量并进行前向传播
with tf.variable_scope('layer1'):
weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
biases = tf.get_variable(
'biases', [LAYER1_NODE], initializer=tf.constant_initializer(