from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
from datasets.mnist import load_mnist
import matplotlib.pyplot as plt

# Settings that were previously repeated as bare literals throughout the script.
EPOCHS = 5          # number of passes over the training split
BATCH_SIZE = 64     # samples per gradient update
NUM_CLASSES = 10    # width of the softmax output layer / one-hot labels
VAL_SIZE = 10000    # leading training samples held out for validation

# Build the network: three Conv2D stages (max pooling after the first two)
# followed by a small dense classifier.
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPool2D(2, 2),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPool2D(2, 2),
    layers.Conv2D(64, (3, 3), activation='relu'),
    # The classifier works on 1D tensors while the convolutional output is a
    # 3D tensor, so it has to be flattened first.
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(NUM_CLASSES, activation='softmax'),
])
model.summary()

# Load the dataset without normalization, one-hot encoding or flattening;
# those steps are done explicitly below.
(train_images, train_labels), (test_images, test_labels) = load_mnist(
    normalize=False, one_hot_label=False, flatten=False)

# Preprocess the images: reshape to the (28, 28, 1) layout declared as the
# model's input_shape and divide by 255 (assumes 8-bit pixel values; confirm
# against load_mnist). The sample dimension is inferred with -1 rather than
# hard-coded, so a split of any size reshapes without error.
train_images = train_images.reshape(-1, 28, 28, 1).astype('float32') / 255
test_images = test_images.reshape(-1, 28, 28, 1).astype('float32') / 255

# One-hot encode the labels. num_classes is pinned so the label width always
# matches the output layer, even if some class is absent from a split.
train_labels = to_categorical(train_labels, num_classes=NUM_CLASSES)
test_labels = to_categorical(test_labels, num_classes=NUM_CLASSES)

# Compile the model.
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Simple hold-out validation: the first VAL_SIZE training samples are set
# aside for validation and the remainder is used for fitting.
x_val = train_images[:VAL_SIZE]
partial_x_train = train_images[VAL_SIZE:]
t_val = train_labels[:VAL_SIZE]
partial_t_train = train_labels[VAL_SIZE:]

# Train the model.
history = model.fit(
    partial_x_train,
    partial_t_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(x_val, t_val),
)

# Plot the loss curves. The x-axis is derived from the recorded history so it
# always matches the number of epochs that actually ran.
metrics_log = history.history
loss = metrics_log['loss']
val_loss = metrics_log['val_loss']
epochs = range(1, len(loss) + 1)

plt.figure()
plt.plot(epochs, loss, 'bo', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('Training and Validation loss')
plt.legend()
plt.show()

# Plot the accuracy curves. Depending on the Keras version, the 'accuracy'
# metric is logged as 'accuracy' or as the older 'acc'; use whichever key is
# present instead of raising KeyError.
acc_key = 'accuracy' if 'accuracy' in metrics_log else 'acc'
acc = metrics_log[acc_key]
val_acc = metrics_log['val_' + acc_key]

plt.figure()
plt.plot(epochs, acc, 'ro', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.xlabel('epochs')
plt.ylabel('accuracy')
plt.title('Training and Validation accuracy')
plt.legend()
plt.show()