# -*- coding: utf-8 -*-
"""
Train a small convolutional network on the MNIST handwritten-digit dataset.

Created in the Spyder editor as a temporary script file.
"""
import numpy as np
# Load the MNIST data. The default path below is where this author saved
# their copy of mnist.npz; change it to your own location.
def load_data(path='D:/360Downloads/python/mnist.npz'):
    """Load the MNIST train/test arrays from an ``.npz`` archive.

    Args:
        path: Filesystem path of the archive. It must contain the arrays
            ``x_train``, ``y_train``, ``x_test`` and ``y_test``.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` as NumPy arrays.

    Raises:
        KeyError: if one of the four expected arrays is missing.
    """
    # np.load returns an NpzFile for .npz archives. Indexing it reads the
    # array fully into memory, so the results stay valid after closing.
    # The context manager closes the file even when a key lookup raises.
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)
# Read the train/test splits into memory and report their shapes.
(train_images, train_labels), (test_images, test_labels) = load_data()
print(train_images.shape, train_labels.shape, test_images.shape, test_labels.shape)

# Sanity-check the data visually: show two training digits in grayscale on a
# 2x2 subplot grid (only the top row is used) and print their labels.
import matplotlib.pyplot as plt

preview_indices = (4545, 56)
print(*(train_labels[idx] for idx in preview_indices))
for grid_position, sample_index in zip((221, 222), preview_indices):
    plt.subplot(grid_position)
    plt.imshow(train_images[sample_index], cmap=plt.get_cmap('gray'))
plt.show()
import tensorflow as tf
from tensorflow import keras
def _build_network():
    """Return the compiled CNN used for MNIST digit classification.

    Layers: three 3x3 ReLU Conv2D layers (32, 64, 64 filters), with 2x2 max
    pooling after the first two, then Flatten, a 512-unit ReLU Dense layer
    and a 10-way softmax output. Compiled with RMSprop and categorical
    cross-entropy, and tracks accuracy.
    """
    model = keras.models.Sequential()
    model.add(keras.layers.Conv2D(32, (3, 3),
                                  activation='relu',
                                  input_shape=(28, 28, 1)))
    model.add(keras.layers.MaxPooling2D((2, 2)))
    model.add(keras.layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(keras.layers.MaxPooling2D((2, 2)))
    model.add(keras.layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(512, activation='relu'))
    model.add(keras.layers.Dense(10, activation='softmax'))
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def _prepare_images(images):
    """Reshape images to ``(N, 28, 28, 1)`` float32 and divide by 255.

    The trailing channel axis matches the network's ``input_shape``. Dividing
    by 255 assumes 8-bit pixel values (0-255), which it maps to [0, 1].
    """
    # -1 lets the sample count follow the data instead of hard-coding
    # 60000 / 10000, so a subset or differently sized archive still works.
    return images.reshape((-1, 28, 28, 1)).astype('float32') / 255


network = _build_network()

train_images = _prepare_images(train_images)
test_images = _prepare_images(test_images)
# One-hot encode the labels. An explicit num_classes keeps the width at 10,
# matching the softmax layer, even if a label subset lacks the highest digit.
train_labels = keras.utils.to_categorical(train_labels, num_classes=10)
test_labels = keras.utils.to_categorical(test_labels, num_classes=10)

network.fit(train_images, train_labels, epochs=5, batch_size=128)
network.summary()

# Save the trained weights.
# NOTE(review): Keras 3 (bundled with TensorFlow >= 2.16) requires this
# filename to end in '.weights.h5'. '.ckpt' is accepted by tf.keras 2.x
# (TF checkpoint format). Confirm against the installed version before
# renaming, since any loader elsewhere would need the same name.
network.save_weights('weights.ckpt')

# Evaluation on the held-out test split is off by default. Set this to True
# to report test accuracy after training.
EVALUATE_ON_TEST = False
if EVALUATE_ON_TEST:
    test_loss, test_acc = network.evaluate(test_images, test_labels)
    print('test_acc:', test_acc)