1. Introduction
今天是尝试用 PyTorch 框架来跑 MNIST 手写数字数据集的第二天,主要学习单例测试。本 blog 主要记录一个学习的路径以及学习资料的汇总。
注意:这是用 Python 2.7 版本写的代码
第一天(LeNet 网络的搭建):https://blog.csdn.net/qq_36627158/article/details/108245969
第二天(训练网络):https://blog.csdn.net/qq_36627158/article/details/108315239
第三天(测试网络):https://blog.csdn.net/qq_36627158/article/details/108321673
第四天(单例测试):https://blog.csdn.net/qq_36627158/article/details/108397018
2. Code(mnist_classify.py)
import tensorflow as tf
import mnist_lenet
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
image_data = Image.open("/home/ubuntu/Downloads/C6/3.jpg")
image_data = ImageOps.invert(image_data)
input = image_data.resize((28, 28)).split()[0]
plt.figure()
plt.imshow(input)
plt.show()
decode_img = tf.image.convert_image_dtype(input, tf.float32)
image = tf.reshape(decode_img, [1, 28, 28, 1])
output = mnist_lenet.build_model_and_forward(image)
probabilities = tf.nn.softmax(output)
correct_prediction = tf.argmax(probabilities, 1)
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state("./models")
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
probabilities, label = sess.run([probabilities, correct_prediction])
print "The num in this image is", label.item(), \
". And the probability is", probabilities[0][label.item()]
else:
print "No checkpoint file found"
这份代码基本上就和测试模型差不多,就没有什么需要查询的啦~