基于TensorFlow2的mnist数据集手写字体识别

使用版本:

Anaconda 2022.05、Pycharm 2022.02、TensorFlow 2.8.0、Python 3.9.12

项目结构

完整程序

训练程序(mnist.py)

#导入模块
import keras
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
#导入数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
class_names=['0','1','2','3','4','5','6','7','8','9']
#图像预处理
x_train4D = x_train.reshape(x_train.shape[0],28,28,1).astype('float32')
x_test4D = x_test.reshape(x_test.shape[0],28,28,1).astype('float32')
#像素标准化
x_train, x_test = x_train4D / 255.0, x_test4D / 255.0
#模型搭建
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=16, kernel_size=(5,5), padding='same',
                 input_shape=(28,28,1),  activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Conv2D(filters=36, kernel_size=(5,5), padding='same',
    			 activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10,activation='softmax')
])

#打印模型
print(model.summary())
#训练配置
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam', metrics=['accuracy'])
#开始训练
model.fit(x=x_train, y=y_train, validation_split=0.2,
                        epochs=20, batch_size=300, verbose=2)
#保存模型
model.save('my_model.h5')

测试程序(test.py),自己手写的图片放在项目下image文件夹下

#导入模块
import keras
import os
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize as imresize
import tensorflow as tf
#载入模型
new_model = keras.models.load_model('my_model.h5')
new_model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam', metrics=['accuracy'])
new_model.summary()
class_names=['0','1','2','3','4','5','6','7','8','9']
#预测
mypath = 'C:\\users\\admin\\desktop\\mnist-test\\image'
def getimg(mypath):
 listdir = os.listdir(mypath)
 imgs = []
 for p in listdir:
  img = plt.imread(mypath+'\\'+p)
  img = np.abs(img/255-1)
  img = imresize(img, [28, 28])
  imgs.append(img[:,:,0])
 return np.array(imgs),len(imgs)
imgs = getimg(mypath)
test_images = np.reshape(imgs[0],[-1,28,28,1])
predictions = new_model.predict(test_images)
plt.figure()
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

运行结果:

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值