In this new setup, evaluation is no longer run together with training. During training, the loss on the current training batch is printed every 1000 steps to give a rough estimate of training progress, and the model is saved every 1000 steps. A saved checkpoint makes it convenient to evaluate the moving-average version of the model separately. The following code gives the evaluation program mnist_eval.py.
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train

# Load the latest checkpoint and evaluate every 10 seconds.
EVAL_INTERVAL_SECS = 10

def evaluate(mnist):
    with tf.Graph().as_default() as g:
        # Define the input/label placeholders and build the inference graph.
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name="x-input")
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name="y-input")
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
        # No regularizer is needed at evaluation time.
        y = mnist_inference.inference(x, None)

        # Accuracy: the fraction of samples whose predicted class matches the label.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Restore the shadow (moving-average) values into the original variables.
        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        while True:
            with tf.Session() as sess:
                # Find the most recently saved checkpoint in the model directory.
                ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # The checkpoint filename encodes the training step, e.g. model.ckpt-1000.
                    global_step = ckpt.model_checkpoint_path.split("/")[-1].split("-")[-1]
                    accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                    print("After %s training steps, validation accuracy is %g" % (global_step, accuracy_score))
                else:
                    print("No checkpoint file found")
                    return
            time.sleep(EVAL_INTERVAL_SECS)

def main(argv=None):
    mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
    evaluate(mnist)

if __name__ == "__main__":
    tf.app.run()
The program above runs once every 10 seconds; each run loads the most recently saved model and computes its accuracy on the MNIST validation set.
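The key to evaluating the moving-average model is the name mapping returned by variables_to_restore(): each shadow variable, named "<variable_name>/ExponentialMovingAverage", is mapped back to the original variable, so tf.train.Saver loads the averaged value in its place. The following sketch (pure Python, no TensorFlow required; the variable names are hypothetical) illustrates that naming convention:

```python
# Sketch of the name mapping behind variables_to_restore(): TensorFlow stores
# each moving average in a shadow variable named "<name>/ExponentialMovingAverage",
# and restoring with this mapping loads the averaged value into the original variable.
def variables_to_restore_names(variable_names):
    # variable_names: hypothetical list of trainable-variable names.
    return {name + "/ExponentialMovingAverage": name for name in variable_names}

mapping = variables_to_restore_names(["layer1/weights", "layer1/biases"])
print(mapping["layer1/weights/ExponentialMovingAverage"])  # -> layer1/weights
```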
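The `split("/")[-1].split("-")[-1]` expression in the program works because checkpoints saved with a global_step argument get filenames of the form "model.ckpt-<step>". A small standalone sketch of that parsing step (the path below is a hypothetical example):

```python
# Extract the training step from a checkpoint path that follows
# TensorFlow's "name.ckpt-<step>" naming convention.
def parse_global_step(ckpt_path):
    # "/tmp/model/model.ckpt-29001" -> "model.ckpt-29001" -> "29001"
    return ckpt_path.split("/")[-1].split("-")[-1]

print(parse_global_step("/tmp/model/model.ckpt-29001"))  # -> 29001
```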