1. Introduction
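Today is my third day of working through the MNIST handwritten-digit dataset with the TensorFlow framework, and this post focuses on testing the trained network. It mainly records my learning path and collects the learning materials I used.
Note: the code is written for Python 2.7 and uses the TensorFlow 1.x API.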
Day 1 (building the LeNet network): https://blog.csdn.net/qq_36627158/article/details/108245969
Day 2 (training the network): https://blog.csdn.net/qq_36627158/article/details/108315239
Day 3 (testing the network): https://blog.csdn.net/qq_36627158/article/details/108321673
Day 4 (testing a single sample): https://blog.csdn.net/qq_36627158/article/details/108397018
2. Code (mnist_test.py)
```python
from __future__ import print_function  # same print() output on Python 2.7 and 3

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_lenet  # LeNet model (defined in the day-1 post)


def test_model(test_dataset):
    num_of_test_data = test_dataset.images.shape[0]

    # Placeholders for the whole test set: NHWC images and one-hot labels
    images_holder = tf.placeholder(
        dtype=tf.float32,
        shape=[num_of_test_data, 28, 28, 1]
    )
    labels_holder = tf.placeholder(
        dtype=tf.float32,
        shape=[num_of_test_data, 10]
    )

    # input_data returns flattened 784-pixel images; reshape them in NumPy
    # instead of adding a tf.reshape op (and an .eval() call) to the graph
    test_images_reshaped = test_dataset.images.reshape(
        [num_of_test_data, 28, 28, 1]
    )
    test_labels = test_dataset.labels

    # Forward pass, then compare true and predicted class indices
    label_predict = mnist_lenet.build_model_and_forward(images_holder)
    correct_prediction = tf.equal(
        tf.argmax(labels_holder, 1),
        tf.argmax(label_predict, 1)
    )
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        test_feed = {
            images_holder: test_images_reshaped,
            labels_holder: test_labels
        }
        # Look up the latest checkpoint written during training
        ckpt = tf.train.get_checkpoint_state("./models")
        if ckpt and ckpt.model_checkpoint_path:
            # Load the trained weights into this session, then evaluate
            saver.restore(sess, ckpt.model_checkpoint_path)
            accuracy_score = sess.run(accuracy, feed_dict=test_feed)
            print("After training the model, the test accuracy =",
                  accuracy_score * 100, "%")
        else:
            print("No checkpoint file found")
            return


if __name__ == '__main__':
    mnist_data = input_data.read_data_sets("MNIST_data/", one_hot=True)
    test_model(mnist_data.test)
```
3. Code Details
1、tf.equal()
https://cloud.tencent.com/developer/article/1406384
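The accuracy in `test_model` is computed by chaining `tf.argmax`, `tf.equal`, `tf.cast` and `tf.reduce_mean`. As a rough illustration, here is a pure-Python sketch (not TensorFlow itself; the helpers `argmax_rows`, `elementwise_equal` and `accuracy` are hypothetical names) that mimics those steps on a tiny 3-class example. It shows that `tf.equal` compares element by element and yields booleans, and that the mean of those booleans cast to 0/1 is the accuracy.

```python
def argmax_rows(rows):
    # Index of the largest entry in each row, like tf.argmax(x, 1)
    return [max(range(len(row)), key=lambda i: row[i]) for row in rows]


def elementwise_equal(a, b):
    # Element-wise comparison returning booleans, like tf.equal(a, b)
    return [x == y for x, y in zip(a, b)]


def accuracy(labels_one_hot, predictions):
    correct = elementwise_equal(argmax_rows(labels_one_hot),
                                argmax_rows(predictions))
    # Cast booleans to 1.0/0.0 and average, like tf.reduce_mean(tf.cast(...))
    return sum(1.0 if c else 0.0 for c in correct) / len(correct)


labels = [[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]]  # true classes 2, 0, 1, 2
predictions = [[0.1, 0.2, 0.7], [0.3, 0.5, 0.2],
               [0.1, 0.8, 0.1], [0.6, 0.1, 0.3]]       # predicted 2, 1, 1, 0
print(elementwise_equal(argmax_rows(labels), argmax_rows(predictions)))
# -> [True, False, True, False]
print(accuracy(labels, predictions))  # -> 0.5
```

In the real model the same logic runs on 10 classes and the whole test set at once, inside the graph.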
2、tf.train.get_checkpoint_state()
- https://zhuanlan.zhihu.com/p/41660278
- https://blog.csdn.net/changeforeve/article/details/80268522
- https://blog.csdn.net/weixin_38314865/article/details/86711288
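`tf.train.get_checkpoint_state("./models")` reads the small text file named `checkpoint` that `tf.train.Saver.save()` keeps in the model directory, and it returns `None` when that file does not exist, which is why `test_model` checks `if ckpt and ckpt.model_checkpoint_path`. The sketch below is a simplified, hypothetical pure-Python reader (`read_checkpoint_state` is not a TensorFlow function) meant only to show what the `model_checkpoint_path` and `all_model_checkpoint_paths` fields contain. It assumes one `key: "value"` pair per line and, as I understand TensorFlow does, joins relative paths with the directory.

```python
import os
import re
import tempfile


def read_checkpoint_state(checkpoint_dir, filename="checkpoint"):
    # Hypothetical simplified reader; returns None if the file is missing
    path = os.path.join(checkpoint_dir, filename)
    if not os.path.isfile(path):
        return None
    state = {"model_checkpoint_path": None, "all_model_checkpoint_paths": []}
    pattern = re.compile(r'^\s*(\w+)\s*:\s*"(.*)"\s*$')  # key: "value"
    with open(path) as f:
        for line in f:
            match = pattern.match(line)
            if not match:
                continue  # skip lines that are not quoted key/value pairs
            key, value = match.groups()
            if not os.path.isabs(value):
                value = os.path.join(checkpoint_dir, value)
            if key == "model_checkpoint_path":
                state["model_checkpoint_path"] = value  # latest checkpoint
            elif key == "all_model_checkpoint_paths":
                state["all_model_checkpoint_paths"].append(value)
    return state


# Demo: a fake "checkpoint" file like the one a Saver might leave behind
demo_dir = tempfile.mkdtemp()
with open(os.path.join(demo_dir, "checkpoint"), "w") as f:
    f.write('model_checkpoint_path: "model.ckpt-2000"\n'
            'all_model_checkpoint_paths: "model.ckpt-1000"\n'
            'all_model_checkpoint_paths: "model.ckpt-2000"\n')
state = read_checkpoint_state(demo_dir)
print(state["model_checkpoint_path"])  # demo_dir joined with model.ckpt-2000
```

The string in `model_checkpoint_path` is a checkpoint prefix (such as `./models/model.ckpt-2000`) rather than a single file, and that prefix is what gets passed to `saver.restore()`.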
3、saver.restore()
https://blog.csdn.net/qq_37285386/article/details/88957558