import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import math
# from tf.keras.models import Sequential # This does not work!
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import InputLayer, Input
from tensorflow.python.keras.layers import Reshape, MaxPooling2D
from tensorflow.python.keras.layers import Conv2D, Dense, Flatten
from mnist import MNIST
# Create the MNIST data-set object from the project's `mnist` helper module,
# pointing it at the data/MNIST/ directory.
data = MNIST(data_dir="data/MNIST/")

# Report how many examples each split contains (one line per split).
print(
    "Size of:",
    "- Training-set:\t\t{}".format(data.num_train),
    "- Validation-set:\t{}".format(data.num_val),
    "- Test-set:\t\t{}".format(data.num_test),
    sep="\n",
)

# Image geometry, copied from the data-set object for convenient access.
img_size = data.img_size              # Pixels in each dimension of an image.
img_size_flat = data.img_size_flat    # Length of the flat 1-D image arrays.
img_shape = data.img_shape            # (height, width) tuple for reshaping.
img_shape_full = data.img_shape_full  # (height, width, depth), used by Keras.

# Class and channel counts.
num_classes = data.num_classes    # One class for each of the 10 digits.
num_channels = data.num_channels  # 1 colour channel for gray-scale images.
def plot_images(images, cls_true, cls_pred=None):
    """Show up to nine images in a 3x3 grid, labelled with their classes.

    Each image is reshaped to the module-level ``img_shape`` before it is
    drawn. The x-axis label of every sub-plot shows the true class and,
    when ``cls_pred`` is given, the predicted class as well. Grid cells
    that have no image are hidden, so fewer than nine images may be passed.

    Args:
        images: Sequence of at most 9 images; every item must support
            ``.reshape(img_shape)``.
        cls_true: True class of each image; same length as ``images``.
        cls_pred: Optional predicted class of each image, indexed in step
            with ``images``. When omitted, only the true class is shown.

    Raises:
        ValueError: If ``images`` and ``cls_true`` differ in length, or if
            more than 9 images are passed.
    """
    max_images = 9  # The figure is a fixed 3x3 grid.
    num_images = len(images)

    # Validate with explicit exceptions instead of `assert`, because
    # assertions are removed when Python runs with -O.
    if num_images != len(cls_true):
        raise ValueError(
            "images and cls_true must have the same length, "
            "got {0} and {1}".format(num_images, len(cls_true)))
    if num_images > max_images:
        raise ValueError(
            "at most {0} images can be plotted, got {1}".format(
                max_images, num_images))

    # Create figure with 3x3 sub-plots. `axes` is the array of Axes objects
    # (one per sub-plot), each bundling its own axis lines, ticks and labels.
    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for i, ax in enumerate(axes.flat):
        if i >= num_images:
            # No image for this cell (fewer than 9 were given): hide the
            # empty sub-plot instead of failing.
            ax.axis('off')
            continue

        # Plot image using the 'binary' colour-map.
        ax.imshow(images[i].reshape(img_shape), cmap='binary')

        # Show true and predicted classes.
        if cls_pred is None:
            xlabel = "True: {0}".format(cls_true[i])
        else:
            xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])

        # Show the classes as the label on the x-axis.
        ax.set_xlabel(xlabel)

        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell. imshow() only draws onto the Axes;
    # plt.show() is what actually displays the finished figure.
    plt.show()
# Preview the test-set: take the first nine images together with their true
# classes, then draw them with the plot_images() helper defined above.
images, cls_true = data.x_test[:9], data.y_test_cls[:9]
plot_images(images=images, cls_true=cls_true)
def plot_example_errors(cls_pred):
    """Plot the first nine mis-classified images from the test-set.

    Args:
        cls_pred: Array of the predicted class-number for all images in the
            test-set. It is compared element-wise with ``data.y_test_cls``
            and indexed with the resulting boolean mask, so it is presumably
            a NumPy array of the same length -- confirm against the caller.
    """
    true_classes = data.y_test_cls

    # Boolean mask that is True wherever the prediction is wrong.
    is_wrong = cls_pred != true_classes

    # Select the mis-classified images plus their true and predicted
    # classes, and hand the first nine of each to plot_images().
    plot_images(
        images=data.x_test[is_wrong][0:9],
        cls_true=true_classes[is_wrong][0:9],
        cls_pred=cls_pred[is_wrong][0:9],
    )
# Sequential Model
# The Keras API has two modes of constructing Neural Networks. The simplest is
# the Sequential Model, which only allows layers to be added in sequence.