在完成了MNIST手写数字模型的训练之后,我们就可以使用训练好的模型进行预测手写数字了。这里还是使用MNIST数据集中所提供的测试数据。
可以仅仅对测试集的数据进行预测,并直接打印出来结果即可。但是为了和原图像进行对比,这里定义了一个可视化的函数,将原图像以及预测结果值进行显示,可以使结果更加直观。
在上述基础上加上下面代码就可以了。
# 对测试集的数据进行预测
prediction_result = sess.run(tf.argmax(pred, 1), feed_dict={x: mnist.test.images})
# 定义可视化函数
def plot_images_labels_prediction(images, # 图像列表
labels, # 标签列表
prediction, # 预测值列表
index, # 从index个开始显示
num=10): # 缺省一次显示10幅
fig = plt.gcf() # 获取当前图表
fig.set_size_inches(10, 12) # 显示成英寸(1英寸等于2.54cm)
if num > 25:
num = 25 # 最多显示25幅图片
for i in range(0, num):
ax = plt.subplot(5, 5, i+1) # 画多个子图(5*5)
ax.imshow(np.reshape(images[index], (28, 28)), cmap='binary') # 显示第index张图像
title = "label=" + str(np.argmax(labels[index])) # 构建图片上要显示的title
if len(prediction) > 0:
title += ", predict=" + str(prediction[index])
ax.set_title(title, fontsize=10)
ax.set_xticks([]) # 不显示坐标轴
ax.set_yticks([])
index += 1
plt.show()
# 从第11张照片开始显示,显示25张
plot_images_labels_prediction(mnist.test.images, mnist.test.labels, prediction_result, 10, 25)
可视化结果:
从上面的预测结果我们可以看出,只有个别图片预测错误,大部分的预测数值都是正确的。