基于tensorflow2.0+keras的CNN手写数字识别(mnist数据集)

第一步:构建用来训练数据的模型

#coding:utf-8 -*-
#首先导入相关模块
import tensorflow as tf
from keras import datasets, layers, models
#开始搭建CNN模型
class CNNmodel(object):
	def __init__(self):
		model = models.Sequential() #创建一个Sequential对象,方便我们堆叠各个卷积层。
		#给模型添加第一个卷积层,32个3*3的卷积核
		model.add(layers.Conv2D(32,(3,3),activation='relu', input_shape(28,28,1)))
		#紧跟一个池化层,使用MaxPooling2D提取图像特征最强的像素
		model.add(layers.MaxPlooling2D((2,2)))
		#第二层卷积
        model.add(layers.Conv2D(64,(3,3), activation='relu'))
        #第二层池化。
        model.add(layers.MaxPooling2D((2,2)))
        #第三层卷积
        model.add(layers.Conv2D(64,(3,3), activation='relu'))
        # Flatten,将输入展平为一维向量3*3*64=576,常用在卷积层到全连接层的过度阶段
        model.add(layers.Flatten())
        #全连接层,共两层后半部分相当于是构建了一个隐藏层为64,输入层为576
        #输出层为10的普通的神经网络。最后一层的激活函数是softmax,10位恰好可以表达0-9十个数字。
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(10, activation='softmax'))
		model.summary()
        self.model = model 

第二步:数据集处理

#数据集处理
class train_Data(object):
	def __init__(self):
        #mnist数据集的存储位置,如不存在将自动下载。
        data_path = os.path.abspath(os.path.dirname(__file__)) + '/MNIST_data/mnsit.npz'
        (train_images,train_labels), (test_images, test_labels) = datasets.mnist.load_data(path=data_path)
       
        #6万张训练图片,1万张测试图片。
        train_images = train_images.reshape((60000,28,28,1))
        test_images = test_images.reshape((10000,28,28,1))


        #像素映射到 0 - 1 之间。
        train_images, test_images = train_images / 255.0 , test_images / 255.0
        
        self.train_images, self.train_labels = train_images, train_labels
        self.test_images, self.test_labels = test_images, test_labels
 

第三步:开始训练

#执行器(即执行数据训练的任务)
class Run_train:
	def __init__(self):
		self.cnnmodel = CNNmodel()
		self.data = trainData()
	def train(self):
		#模型保存格式
		path = 'ckpt/cp-{epoch:04d}.ckpt'
        # filepath: 字符串,保存模型的路径。
        # verbose: 详细信息模式,0 或者 1 。0为不打印输出信息,1打印
        # save_weights_only: 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)), 否则的话,整个模型会被保存 (model.save(filepath))。
        # period: 每个检查点之间的间隔(训练轮数)。
		save_model_cb = tf.keras.callbacks.ModelCheckpoint(check_path, save_weights_only=True, verbose=1, save_freq=5)
        # optimizer:优化器,如Adam
        # loss:计算损失,binary_crossentropy sparse_categorical_crossentropy
        # metrics: 列表,包含评估模型在训练和测试时的性能的指标,典型用法是metrics=[‘accuracy’]
        #使用compile编译模型
        self.cnnmodel.model.compile(optimizer='adam',
                               loss='sparse_categorical_crossentropy',
                               metrics=['accuracy'])
        #开始训练(fit)    
		self.cnnmodel.model.fit(self.data.train_images,self.data.train_labels,epochs=5,callbacks=[save_model_cb])
        test_loss, test_acc = self.cnnmodel.model.evaluate(self.data.test_images,self.data.test_labels)
        print("准确率: %.4f,共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
        print('损失精度:%.4f' % (test_loss))
if __name__ == "__main__":
    app = Run_train()
    app.train()	

训练结果:
在这里插入图片描述

最后:开始尝试预测

#coding:utf-8 -*-
from train import CNNmodel
import tensorflow as tf
import numpy as np
from PIL import Image
import cv2 as cv
import time

class Predict(object):
    def __init__(self):
        latest = tf.train.latest_checkpoint('ckpt')
        self.cnn = CNNmodel()

        #恢复网络权重
        self.cnn.model.load_weights(latest)
    def predict(self, image_path):
        # 以黑白方式读取图片
        img = Image.open(image_path).convert('L')
        flatten_img = np.reshape(img, (28, 28, 1))
        x = np.array([1 - flatten_img])

        # API refer: https://keras.io/models/model/
        y = self.cnn.model.predict(x,verbose=1)

        # 因为x只传入了一张图片,取y[0]即可
        # np.argmax()取得最大值的下标,即代表的数字
        print(y)
        
        print('        神经网络 -> 预测结果:您写入的数字为:', np.argmax(y))


if __name__ == "__main__":
    time_start=time.time()

    app = Predict()
    app.predict('Test_IMAGE/text9.png')
    app.predict('Test_IMAGE/text7.png') 
    app.predict('Test_IMAGE/text5.png')
    app.predict('Test_IMAGE/text3.png')

    time_end=time.time()
    print('totally cost',time_end-time_start)

预测结果如图:

在这里插入图片描述

在这里插入图片描述
可以看到,准确率还是可以的。

参考:https://geektutu.com/post/tensorflow-mnist-simplest.html

  • 0
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值