1. 变量管理
当神经网络的结构复杂、参数更多时,需要更好的方式来传递和管理神经网络的参数。
1.1 通过变量名使用变量
TensorFlow提供了通过变量名称来创建或获取一个变量的机制,在不同的函数中可以直接通过变量的名字来使用变量,不需要通过参数的形式到处传递。
TensorFlow 通过变量名获取变量的机制主要通过
tf.get_variable
和tf.variable_scope
函数实现。
1.1.1 tf.get_variable
函数
1、当 tf.get_variable
函数用于创建变量时,它和 tf.Variable
的功能基本等价。
# 下面这两个定义是等价的。
v = tf.get_variable("v",shape=[1],initializer=tf.constant_initilizer(1.0))
v = tf.Variable(tf.constant(1.0,shape=[1]),name="v")
TensorFlow提供了7种不同的初始化函数:
初始化函数 | 功能 | 主要参数 |
---|---|---|
tf.constant_initializer | 将变量初始化为给定常量 | 常量的取值 |
tf.random_normal_initializer | 将变量初始化为满足正太分布的随机值 | 正太分布的均值和标准差 |
tf.truncated_normal_initializer | 将变量初始化为满足正太分布的随机值,但若随机出来的值偏离平均值超过两个标准差,那么这个数将会被重新随机 | 正太分布的均值和标准差 |
tf.random_uniform_initializer | 将变量初始化为满足平均分布的随机值 | 最大,最小值 |
tf.uniform_unit_scaling_initializer | 将变量初始化为满足平均分布但不影响输出数量级的随机值 | factor(产生随机值时乘以的系数) |
tf.zeros_initializer | 将变量设置为全为0 | 变量维度 |
tf.ones_initializer | 将变量设置为全为1 | 变量维度 |
2、tf.get_variable
与 tf.Variable
的区别:
tf.Variable
函数:变量名称是一个可选的参数,通过name="v"
的形式给出。tf.get_variable
函数:变量名是一个必填的参数。tf.get_variable
会根据这个名字去创建或获取变量。
例如上面的样例程序中:
tf.get_variable
首先会试图去创建一个名字为v
的参数,如果创建失败(比如已经有同名的参数),那么这个程序就会报错。这是为了避免无意识的变量复用造成的错误。比如在定义神经网络参数时,第一层网络的权重已经叫weights了,那么在创建第二层神经网络时,如果参数名仍然叫weights,就会触发变量重用的错误。否则两层神经网络共用一个权重会出现一些比较难以发现的错误。
1.1.2 tf.variable_scope
函数
tf.variable_scope
函数调用时提供的维度(shape)信息及初始化方法(initializer)的参数和 tf.Variable
函数调用时提供的初始化过程的参数也类似。
如果需要通过tf.get_variable
获取一个已经创建的变量,则需要通过tf.variable_scope
函数来生成一个上下文管理器,并明确指定在这个上下文管理器中,tf.get_variable
将直接获取已经创建过的变量。
示例:
#在名字为foo的命名空间内创建名字为v的变量
with tf.variable_scope("foo"):
v = tf.get_variable("v",[1],initializer=tf.constant_initializer(1.0))
#因为在命名空间foo已经存在名字为v的变量,所以下面的代码将会报错:
with tf.variable_scope("foo"):
v = tf.get_variable("v",[1])
#在生成上下文管理器时,将参数reuse设置为True。这样tf.get_variable函数将直接获取已经生成的变量
with tf.variable_scope("foo",reuse=True):
v1 = tf.get_variable("v",[1])
print v == v1 #输出为True,代表v,v1是相同的Tensorflow中的变量
#将参数reuse设置为True时,tf.variable_scope将只能获取已经创建的变量,
# 因为在命名空间bar中还没有创建变量v,所以下面的代码将会报错:
with tf.variable_scope("bar",reuse=True):
v = tf.get_variable("v",[1])
功能
1、通过 tf.variable_scope
控制 tf.get_variable
的语义:
- 如果
tf.variable_scope
函数使用参数reuse=True
生成上下文管理器时,该上下文管理器中的所有tf.get_variable
函数会直接获取已经创建的变量。如果变量不存在,tf.get_variable
函数会报错。 - 如果
tf.variable_scope
函数使用参数reuse=None
或者reuse=False
创建上下文管理器,tf.get_variable
操作将创建新的变量。如果同名的变量已经存在,则tf.get_variable
函数将报错。
2、Tensorflow 中tf.variable_scope
函数可以嵌套。
# 当 tf.variable_scope 函数嵌套时, reuse 参数的取值
with tf.variable_scope("root"):
# 可以通过 tf.get_variable_scope().reuse函数来获取当前上下文管理器中的reuse参数的取值
print(tf.get_variable_scope().reuse) # 若输出False则最外层reuse是False
with tf.variable_scope("foo",reuse=True): # 新建一个嵌套的上下文管理器,并指定reuse=True
print(tf.get_variable_scope().reuse) # 输出True
with tf.variable_scope(""bar) # 新建一个嵌套的上下文管理器,不指定reuse,
# 此时reuse回合外面一层保持一致
print(tf.get_variable_scope().reuse) # 输出True
print(tf.get_variable_scope().reuse) # 输出False
# 退出reuse设置为True设置的上
# 下文之后,
# reuse的值又回到了False
3、 tf.variable_scope
函数提供了一个管理变量命名空间的方式
v1 = tf.get_variable("v", [1])
print(v1.name) # 输出v:0,v是变量名称,“:0”表示这个变量是生成变量这个运算的第一个结果
with tf.variable_scope("foo"):
v2 = tf.get_variable("v", [1])
print(v2.name) # 输出foo/v:0,在tf.variable_scope 中创建的变量,名称前面会加入命名空间的名称,
# 并通过/来分割命名空间的名称和变量的名称
with tf.variable_scope("foo"):
with tf.variable_scope("bar"):
v3 = tf.get_variable("v", [1])
print(v3.name) # 输出foo/bar/v:0,命名空间可以嵌套,同时变量的名称也会加入所以命名空间的名称作为前缀。
v4 = tf.get_variable("v1", [1])
print(v4.name) # 输出foo/v1:0,当命名空间退出后,变量名称也就不会再被加入其前缀了。
# 创建一个名称为空的命名空间,并设置reuse=True,我们可以通过变量的名称来获取变量
with tf.variable_scope("",reuse=True):
v5 = tf.get_variable("foo/bar/v", [1])
# 可以直接通过带命名空间名称的变量名来获取其他命名空间下的变量
# 这里是通过指定名称foo/bar/v来获取在命名空间foo/bar/中创建的变量
print(v5 == v3) # 输出:True
v6 = tf.get_variable("v1", [1])
print(v6 == v4) # 输出:True
1.2 优化程序
优化之前博客:MNIST 手写体数字识别完整 TensorFlow 程序
中计算前向传播的辅助函数。
def inference(input_tensor, reuse=False):
# 定义第一层神经网络的变量和前向传播过程
with tf.variable_scope('layer1', reuse=reuse):
# 根据传进来的 reuse 来判断是创建新变量还是使用已创建好的。
# 在第一次构造网络时需要创建新的变量,以后每次调用这个函数都直接使用
# reuse=True 就不需要每次将变量传进来了。
weights = tf.get_variable("weights", [INPUT_NODE, LAYER1_NODE],
initializer=tf.truncated_normal_initializer(stddev=0.1))
biases = tf.get_variable("biases", [LAYER1_NODE],
initializer=tf.constant_initializer(0.0))
layer1 = tf.nn.relu(tf.matmul(input_tensor,weights) + biases)
# 类似地定义第二层神经网络的变量和前向传播过程。
with tf.variable_scope('layer2', reuse=reuse):
weights = tf.get_variable("weights", [LAYER1_NODE, OUTPUT_NODE],
initializer=tf.truncated_normal_initializer(stddev=0.1))
biases = tf.get_variable("biases", [OUTPUT_NODE],
initializer=tf.constant_initializer(0.0))
layer2 = tf.nn.relu(tf.matmul(layer1,weights) + biases)
# 返回最后的前向传播结果。
return layer2
2. 模型持久化
为了将训练得到的模型保存下来方便下次使用,即训练结果可以复用,需要将训练得到的神经网络模型持久化。
2.1 持久化代码实现
2.1.1 tf.train.Saver
类
TensorFlow 提供了一个非常简单的 API 来保存和还原神经网络模型,该 API 就是tf.train.Saver
类。
1. 保存 TensorFlow 神经网络模型
import tensorflow as tf
# 声明两个变量并计算其和
v1=tf.Variable(tf.constant(1.0,shape=[1]),name="v1")
v2=tf.Variable(tf.constant(2.0,shape=[1]),name="v2")
init_op=tf.global_variables_initializer()
#声明 tf.train.Saver 类用于保存模型
saver=tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
#将模型保存到文件
saver.save(sess,"./model/model.ckpt")
上面代码会生成 3 个文件:
- 第一个文件为
model.ckpt.meta
,它保存了 Tensorflow 计算图的结构,即神经网络的网络结构; - 第二个文件为
model.ckpt
,这个文件保存了 Tensorflow 程序中每一个变量的取值; - 第三个文件为
checkpoint
文件,这个文件保存了一个目录下所有的模型文件列表。
2. 加载已经保存的 TensorFlow 模型
1、加载已经创建的模型
import tensorflow as tf
# 使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
# 加载已经保存的模型,并通过已经保存的模型中的变量的值来计算加法
saver.restore(sess, "./model/model.ckpt")
print(sess.run(result))
加载模型的代码中,没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。
2、直接加载已经持久化的模型
不需要重复定义图上的运算。
- 默认保存和加载 Tensorflow 计算图上定义的全部变量。
import tensorflow as tf
# 直接加载
saver=tf.train.import_meta_graph("./model/model.ckpt.meta")
with tf.Session() as sess:
saver.restore(sess,"./model/model.ckpt")
# 通过张量名称来获取张量
print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
# 输出 [ 3.]
- 保存和加载部分变量。
使用场景: 有一个之前训练好的五层神经网络模型,但现在想尝试一个六层的神经网络,那么可以将前面五层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。
方法: 在声明tf.train.Saver
类时可以提供一个列表来指定需要保存或者加载的变量。
例子:在加载模型的代码中使用
saver=tf.train.Saver([v1])
命令来构建tf.train.Saver
类,那么只有变量v1
会被加载进来,如果运行修改后只加载v1
的代码会得到变量未初始化的错误:
tensorflow.python.framework.errors.FailedPreconditionError:Attempting to use uninitialized value v2
因为v2
没有被加载,所以v2
在运行初始化之前是没有值的。除了可以选取需要被加载的变量。
- 在保存或者加载时给变量重命名。
# 这里声明的变量名称和已经保存的模型中变量的名称不同
v1 = tf.Variable(tf.constant(1.0,shape=[1]),name="other-v1")
v2 = tf.Variable(tf.constant(2.0,shape=[1]),name="other-v2")
# 如果直接使用 tf.train.Saver 类来加载模型会报变量找不到的错误。
# 使用一个字典(dictionary)来重命名变量就可以加载原来的模型了。这个字典指定了
# 原来名称为 v1 的变量现在加载到变量 v1 中(名称为other-v1),名称为 v2 的变量
# 加载到变量 v2 中(名称为other-v2)
saver = tf.train.Saver({"v1": v1, "v2": v2})
Tensorflow 通过字典(dictionary)将模型保存时的变量名和需要加载的变量联系起来。
目的: 方便使用变量的滑动平均值。
在Tensorflow中,每一个变量的滑动平均值是通过影子变量维护的,所以要获取变量的滑动平均值实际上就是获取这个影子变量的取值。
如果在加载模型时直接将影子变量映射到变量自身,那么在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。
# 此代码给出了一个保存滑动平均模型的样例:
import tensorflow as tf
# 1. 使用滑动平均
v = tf.Variables(0, dtype=tf.float32, name="v")
for variables in tf.global_variables():
print(variables.name)
# 在没有申明滑动平均模型时只有一个变量 v,所以上面的语句只会输出 "v:0"
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
print(variables.name)
# 在申明滑动平均模型之后,TensorFlow 会自动生成一个影子变量
# "v:0" 和 "v/ExponentialMovingAverage"
# 所以上面语句会输出:
# "v:0" 和 "v/ExponentialMovingAverage:0"
# 2. 保存滑动平均模型
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
sess.run(tf.assign(v, 10))
sess.run(maintain_averages_op)
# 保存的时候会将 v:0 和 v/ExponentialMovingAverage:0这两个变量都存下来。
saver.save(sess, "./model/model.ckpt")
print(sess.run([v, ema.average(v)]))
# 输出 [10.0, 0.099999905]
# 3. 加载滑动平均模型
v = tf.Variable(0, dtype=tf.float32, name="v")
# 通过变量重命名将原来变量v的滑动平均值直接赋值给v。
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
saver.restore(sess, "Saved_model/model2.ckpt")
print(sess.run(v))
# 输出:0.0999999
为了方便加载时重命名滑动平均变量,tf.train.ExponentialMovingAverage
类提供了 variables_to_restore
函数来生成tf.train.Saver
类所需要的变量重命名字典。
import tensorflow as tf
v = tf.Varaibles(0, dtype=tf.float32, name="v")
ems = tf.train.ExponentialMovingAverage(0.99)
# 通过使用 variables_to_restore 函数可以直接生成上面代码中提供的字典
# {"v/ExponentialMovingAverage": v}
print(ems.variables_to_restore())
# 上面代码会输出:
# {'v/ExponentialMovingAverage': <tensorflow.python.ops.variables.Variable Object at 0x7ff6454ddc10>}
# 其中后面的 Variable 类就代表了变量 v
saver = tf.train.Saver(ems.variables_to_restore())
with tf.Session() as sess:
saver.restore(sess, "Saved_model/model2.ckpt")
print(sess.run(v))
# 输出:0.0999999,即原来模型中变量 v 的滑动平均值。
2.1.2 convert_variables_to_constants
函数
问题:
- 使用
tf.train.Saver
会保存运行 Tensorflow 程序所需要的全部信息,然后有时并不需要某些信息。
比如在测试或者离线预测时,只需要知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要类似于变量初始化,模型保存等辅助节点的信息。
- 将变量取值和计算图结构分成不同的文件存储有时候并不方便。
目的:
通过此函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个 Tensorflow 计算图可以统一存放在一个文件中。
import tensorflow as tf
from tensorflow.python.framework import graph_util
v1 = tf.Variable(tf.constant(1.0,shape=[1]),name="v1")
v2 = tf.Variable(tf.constant(2.0,shape=[1]),name="v2")
result = v1 + v2
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
# 导出当前计算图的 GraphDef 部分,只需要这部分就可以完成从输入层到输出层的计算过程
graph_def = tf.get_default_graph().as_graph_def()
# 将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉
# 下面这行代码中,最后一个参数['add']给出了需要保存的节点名称。add 节点是上面定义的两个变量相加的操作。
# 注意这里给出的是计算节点的名称,所以没有后面的:0。
output_graph_def = graph_util.convert_variables_to_constants(
sess,graph_def,['add'])
# 将导出的模型存入文件。
with tf.gfile.GFile("./model/combined_model.pb","wb") as f:f.write(output_graph_def.SerializeToString())
通过下面的程序可以直接计算定义的加法运算的结果,当只需要得到计算图中某个节点的取值时,该方法更为简便。
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename = "/path/to/model/combined_model.pb"
#读取保存的模型文件,并将文件解析成对应的GraphDef Protocol Buffer
with gfile.FastGFile(model_filename,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
#将graph_def中保存的图加载到当前的图中.return_elements=["add:0"]给出了返回的张量的名称。
#在保存能的时候给出的时计算节点的名称,所以为"add",在加载的时候给出的是张量的名称,所以时add:0
result = tf.import_graph_def(graph_def,return_elements=["add:0"])
print(sess.run(result))
2.2 持久化原理及数据格式
- Tensorflow 是一个通过图的形式来表达计算的编程系统,Tensorflow 程序中的所有计算都会表达为计算图上的节点。
- Tensorflow 通过元图(MetGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。
- Tensorflow 中元图是由 MetaGraphDef Protocol BUffer 定义的,MetaGraphDef 中的内容就构成了 Tensorflow 持久化时的第一个文件。
# 以下代码给出了 MetaGraphDef 类型的定义
message MetaGraphDef{
MetaInfoDef meta_info_def = 1;
GraphDef graph_def = 2;
saverDef saver_def = 3;
map<string, CollectionDef> collection_def = 4;
map<string, SignatureDef> signature_def = 5;
}
为了方便调试,TensorFlow 提供了 export_meta_graph
函数以 json 格式导出 MetaGraphDef Protocol BUffer。
import tensorflow as tf
# 定义变量相加的计算。
v1 = tf.Variable(tf.constant(1.0, shape=[1]),name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]),name="v2")
result1 = v1 + v2
saver = tf.train.Saver()
# 通过 export_meta_graph 函数导出 TensorFlow 计算图的元图,并保存为 json 格式
saver.export_meta_graph("./model/model.ckpt.meta.json",as_text=true)
下面分别介绍元图存储的信息。
meta_info_def
属性:记录了计算图中的元数据(计算图版本号、标签等)及程序中所有用到的运算方法信息。graph_def
属性:记录了计算图上的节点信息,因为在meta_info_def
属性已经包含了所有运算的信息,所以graph_def
只关注运算的连接结构。saver_def
属性:记录了持久化模型时需要使用的一些参数,如保存到文件的文件名、保存操作和加载操作的名称,以及保存频率等。collection_def
属性:计算图中维护集合的底层实现,该属性是一个从集合名称到集合内容的映射。
2.3 最佳实践样例程序
优化方向:
- 将不同功能模块分开:将神经网络的训练和测试分成两个独立的程序,这样可以使得每一个组件更加灵活。
例如训练神经网络的程序可以持续输出训练好的模型,而测试程序可以每隔一段时间检验最新模型的正确率,如果模型效果更好,则将这个模型提供给产品使用。 - 将前向传播的过程抽象成一个单独的库函数。
重构前面写过的 MNIST 手写数字识别程序,将会拆成三个程序:
1. mnist_inference.py
定义神经网络前向传播的过程及神经网络中的参数。
import tensorflow as tf
# 1. 定义神经网络结构相关的参数
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500
# 2. 通过 tf.get_variable 函数来获取变量。
def get_weight_variable(shape, regularizer):
weights = tf.get_variable(
"weights", shape,
initializer=tf.truncated_normal_initializer(stddev=0.1))
# 当给出了正则化生成函数,将当前变量的正则化损失加入名字为 losses 的集合。
if regularizer != None:
tf.add_to_collection("losses",regularizer(weights))
return weights
# 3. 定义神经网络的前向传播过程
def inference(input_tensor, regularizer):
with tf.variable_scope('layer1'):
weights = get_weight_variable(
[INPUT_NODE, LAYER1_NODE], regularizer)
biases = tf.get_variable(
"biases", [LAYER1_NODE],
initializer=tf.constant_initializer(0.0))
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
with tf.variable_scope('layer2'):
weights = get_weight_variable(
[LAYER1_NODE, OUTPUT_NODE], regularizer)
biases = tf.get_variable(
"biases", [OUTPUT_NODE],
initializer=tf.constant_initializer(0.0))
layer2 = tf.nn.relu(tf.matmul(layer1, weights) + biases)
# 返回最后前向传播的结果
return layer2
2. mnist_train.py
神经网络的训练程序。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 加载 mnist_inference.py 中定义的常量和前向传播的函数
import mnist_inference
import os
# 1. 配置神经网络的参数
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./MNIST_model/" # 在当前目录下存在MNIST_model子文件夹
MODEL_NAME = "mnist_model"
# 2. 定义训练过程
def train(mnist):
# 定义输入输出 placeholder
x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
y = mnist_inference.inference(x, regularizer)
global_step = tf.Variable(0, trainable=False)
# 定义损失函数、学习率、滑动平均操作以及训练过程。
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cross_entropy_mean = tf.reduce_mean(cross_entropy)
loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,
staircase=True)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
with tf.control_dependencies([train_step, variables_averages_op]):
train_op = tf.no_op(name='train')
# 初始化TensorFlow持久化类。
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(TRAINING_STEPS):
xs, ys = mnist.train.next_batch(BATCH_SIZE)
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 1000 == 0:
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
# 3. 主程序入口
def main(argv=None):
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
train(mnist)
if __name__ == '__main__':
main()
3. mnist_eval.py
测试程序。
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train
# 1. 每10秒加载一次最新的模型
# 加载的时间间隔。
EVAL_INTERVAL_SECS = 10
def evaluate(mnist):
with tf.Graph().as_default() as g:
x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
# 测试时不再关注正则化损失的值,所以这里用于计算正则化损失的函数被设置为 None
y = mnist_inference.inference(x, None)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
while True:
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
print("After %s training step(s), validation accuracy = %g" % (global_step, accuracy_score))
else:
print('No checkpoint file found')
return
time.sleep(EVAL_INTERVAL_SECS)
# 主程序
def main(argv=None):
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
evaluate(mnist)
if __name__ == '__main__':
main()