使用版本:
Anaconda 2022.05、PyCharm 2022.2、TensorFlow 2.8.0、Python 3.9.12
项目结构
完整程序
训练程序(mnist.py)
#导入模块
import keras
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
# Digit labels indexed by class id (kept at module level, as in the original).
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']


def preprocess(images):
    """Convert a batch of grayscale images into CNN input.

    Adds a trailing channel axis, casts to float32 and scales the raw
    pixel range [0, 255] down to [0, 1].

    Args:
        images: array of shape (N, H, W) (or already (N, H, W, 1)) holding
            raw pixel values, e.g. the uint8 arrays from mnist.load_data().

    Returns:
        float32 array of shape (N, H, W, 1) with values in [0, 1].
    """
    with_channel = images.reshape((*images.shape[:3], 1))
    return with_channel.astype('float32') / 255.0


def build_model(input_shape=(28, 28, 1), num_classes=10):
    """Build the CNN classifier: two conv/max-pool stages, then a dense head.

    Args:
        input_shape: shape of a single input image (H, W, channels).
        num_classes: number of softmax outputs.

    Returns:
        An uncompiled tf.keras Sequential model.
    """
    return tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), padding='same',
                               input_shape=input_shape, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Conv2D(filters=36, kernel_size=(5, 5), padding='same',
                               activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])


def main(model_path='my_model.h5', epochs=20, batch_size=300):
    """Train the CNN on MNIST, report test accuracy and save it to model_path."""
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = preprocess(x_train), preprocess(x_test)

    model = build_model()
    # summary() prints the table itself and returns None; wrapping it in
    # print() emitted a stray "None" line.
    model.summary()

    # Integer labels (0-9), hence the sparse variant of the cross-entropy.
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer='adam', metrics=['accuracy'])
    # validation_split=0.2 holds out part of the training data for
    # per-epoch validation.
    model.fit(x=x_train, y=y_train, validation_split=0.2,
              epochs=epochs, batch_size=batch_size, verbose=2)
    # The test set was preprocessed but never used; report its accuracy.
    model.evaluate(x_test, y_test, verbose=2)

    model.save(model_path)


if __name__ == '__main__':
    main()
测试程序(test.py):先运行 mnist.py 生成 my_model.h5,再将自己手写的数字图片放在项目目录下的 image 文件夹中。
#导入模块
import keras
import os
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize as imresize
import tensorflow as tf
# Load the network saved by the training script (mnist.py). Use tf.keras, as
# the training script does, so the same Keras implementation reads the file.
# load_model compiles the model from the config stored in the file (when
# present), and predict() does not need a compiled model anyway, so the
# separate compile() call that used to follow is unnecessary.
new_model = tf.keras.models.load_model('my_model.h5')
new_model.summary()

# Digit labels, indexed by the argmax of the model's softmax output.
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

# Folder with the hand-written digit images: the "image" folder next to this
# script, rather than a hard-coded absolute path on one particular machine.
mypath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'image')
def to_ink_intensity(img):
    """Turn one image returned by plt.imread into a 2-D ink-intensity map.

    plt.imread returns PNG files as floats in [0, 1] but other formats (e.g.
    JPEG) as integer arrays, so integer images are first scaled by their
    dtype's maximum. Dividing a PNG by 255 (as before) left every pixel near
    1.0 after inversion, i.e. a blank image. The result is inverted so dark
    ink on light paper becomes high values on a low background.

    Args:
        img: 2-D grayscale image, or 3-D (H, W, channels) colour image.

    Returns:
        2-D float array in [0, 1]; for colour images only channel 0 is used.
    """
    img = np.asarray(img)
    if np.issubdtype(img.dtype, np.integer):
        img = img / np.iinfo(img.dtype).max
    if img.ndim == 3:
        img = img[:, :, 0]
    return 1.0 - img


def getimg(mypath, exts=('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff')):
    """Load every image file in mypath as a 28x28 ink-intensity array.

    Args:
        mypath: folder containing the hand-written digit images.
        exts: lower-case file extensions treated as images; other entries
            (sub-folders, Thumbs.db, .DS_Store, ...) are skipped.

    Returns:
        (imgs, count): imgs stacks the resized images (shape (count, 28, 28)
        when any were found); count is the number of images loaded.
    """
    imgs = []
    # os.listdir order is arbitrary; sort so the output order is stable.
    for name in sorted(os.listdir(mypath)):
        path = os.path.join(mypath, name)
        if not os.path.isfile(path) or not name.lower().endswith(tuple(exts)):
            continue
        img = to_ink_intensity(plt.imread(path))
        imgs.append(imresize(img, [28, 28]))
    return np.array(imgs), len(imgs)
# Predict each hand-written digit and show the images in a grid, each titled
# with its predicted class.
imgs, num_imgs = getimg(mypath)
if num_imgs == 0:
    print('No images found in', mypath)
else:
    # The model takes a batch of shape (N, 28, 28, 1).
    test_images = np.reshape(imgs, [-1, 28, 28, 1])
    predictions = new_model.predict(test_images)

    # Three panels per row and as many rows as needed; the old fixed 3x3
    # grid made plt.subplot raise ValueError for more than 9 images.
    n_cols = 3
    n_rows = (num_imgs + n_cols - 1) // n_cols
    plt.figure()
    for i, (image, probs) in enumerate(zip(test_images, predictions)):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.xticks([])
        plt.yticks([])
        # Gray colormap so intensity maps directly to brightness.
        plt.imshow(image[:, :, 0], cmap='gray')
        plt.title(class_names[int(np.argmax(probs))])
    plt.show()
运行结果: