TensorFlow2学习七、使用MNIST手写体识别数据集识别自己手写图片

一、说明

本示例使用mnist数据集。mnist来自美国国家标准与技术研究所,训练集来自 250 个不同人手写的数字构成。不过这些数据集与中国书写习惯略有差别,所以直接训练好的模型识别中国手写数字准确度并不高。

二、步骤

1. 加载mnist数据集:

mnist = tf.keras.datasets.mnist

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0

2. 显示部分图片

plt.figure(figsize=(10, 10))
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()

3. 训练和保存模型,如果模型存在了,再运行就直接加载

if os.path.exists('./model.h5'):
    model = tf.keras.models.load_model('./model.h5')
else:
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    #
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.fit(train_images, train_labels, epochs=5)
    model.save('model.h5')

4. 评估

print('在测试集上评估')
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('看看测试集测试结果')
predictions = model.predict(test_images)
print('预测值 = %i ; 正确值 = %i' % (np.argmax(predictions[0]), test_labels[0]))

5. 评估


def preprocess_image(image):
    image = tf.image.decode_jpeg(image, channels=1)
    image = tf.image.resize(image, [28, 28])
    image /= 255.0  # normalize to [0,1] range
    image = tf.reshape(image, [28, 28])
    return image


def load_and_preprocess_image(path):
    image = tf.io.read_file(path)
    return preprocess_image(image)


filepath = './3.png'
test_my_img = load_and_preprocess_image(filepath)
test_my_img = (np.expand_dims(test_my_img, 0))
my_result = model.predict(test_my_img)
print('自己的图片预测值 = %i ; 文件名 = ', (np.argmax(my_result[0]), filepath))

三、完整代码

from __future__ import absolute_import, division, print_function, unicode_literals

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os

mnist = tf.keras.datasets.mnist

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0

# 显示一部分图片数据
# plt.figure(figsize=(10, 10))
# class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# for i in range(25):
#     plt.subplot(5, 5, i + 1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(train_images[i], cmap=plt.cm.binary)
#     plt.xlabel(class_names[train_labels[i]])
# plt.show()

if os.path.exists('./model.h5'):
    model = tf.keras.models.load_model('./model.h5')
else:
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    #
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.fit(train_images, train_labels, epochs=5)
    model.save('model.h5')

print('在测试集上评估')
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('看看测试集测试结果')
predictions = model.predict(test_images)
print('预测值 = %i ; 正确值 = %i' % (np.argmax(predictions[0]), test_labels[0]))

print('从测试集取一个图片测试')
img = test_images[1]
img = (np.expand_dims(img, 0))
predictions_single = model.predict(img)
print(np.argmax(predictions_single[0]), test_labels[1])


def preprocess_image(image):
    image = tf.image.decode_jpeg(image, channels=1)
    image = tf.image.resize(image, [28, 28])
    image /= 255.0  # normalize to [0,1] range
    image = tf.reshape(image, [28, 28])
    return image


def load_and_preprocess_image(path):
    image = tf.io.read_file(path)
    return preprocess_image(image)


filepath = './3.png'
test_my_img = load_and_preprocess_image(filepath)
test_my_img = (np.expand_dims(test_my_img, 0))
my_result = model.predict(test_my_img)
print('自己的图片预测值 = %i ; 文件名 = ', (np.argmax(my_result[0]), filepath))


自己的测试图片:

在这里插入图片描述
在这里插入图片描述
自己的手写图片要点:

  • 透明图像,颜色偏白色,png图片
  • 这里图像是28*28大小的
  • 尽量靠近中间写
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

编程圈子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值