tensorflow2.0 CNN实现mnist数据集训练、模型保存加载以及自身数据预测

本博客主要内容为tensorflow2.0 CNN实现mnist数据集训练、模型保存加载以及自身数据预测,分为模型的训练生成及模型的保存,以及模型的加载及自身数据预测

一、模型的训练生成及模型的保存

#相关库的导入
from tensorflow import keras
import tensorflow.keras.layers as layers
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np 
import cv2
# 导入数据
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# 观察数据
print (x_train.shape)
plt.imshow(x_train[10000])
print (y_train[10000])
#数据处理,维度要一致
x_train = x_train.reshape((-1,28,28,1)).astype('float32') 
x_test = x_test.reshape((-1,28,28,1)).astype('float32') #-1代表那个地方由其余几个值算来的
x_train = x_train/255
x_test = x_test/255
print (x_train.shape)

#定义模型方法一-add()
#这里我的模型比较简单,但准确率也有98%多,感兴趣的可以扩大深度或者引入其他方法优化下
model = keras.Sequential()
model.add(layers.Conv2D(input_shape=(28, 28, 1),
                        filters=32, kernel_size=(3,3), strides=(1,1), padding='valid',
                        activation='relu'))#卷积层加激活
model.add(layers.MaxPool2D(pool_size=(2,2)))#池化层
model.add(layers.Flatten())#全连接层
model.add(layers.Dense(32, activation='relu'))
# 分类层
model.add(layers.Dense(10, activation='softmax'))
#训练模型配置
model.compile(optimizer=keras.optimizers.Adam(),
             # loss=keras.losses.CategoricalCrossentropy(),  
            # 损失函数多分类使用交叉熵(这里还要看标签是否为one-hot编码)
           # 回归问题用均方差
             loss=keras.losses.SparseCategoricalCrossentropy(),
             metrics=['accuracy'])
model.summary()
#进行模型训练
history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.1)
#保存下模型(这个方法比较常用,也可以考虑适合部署的SavedModel 方式)
model.save('cnn1_save1.h5')#保存模型,名字可以任取,但要由.h5后缀
#测试模型
model.evaluate(x_test, y_test)

二、模型的加载及自身数据预测

#相关库的导入
from tensorflow import keras
import tensorflow.keras.layers as layers
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np 
import cv2
#加载模型
model=tf.keras.models.load_model('cnn1_save1.h5')
#自身数据加载(随便去百度或者那截张图,然后读取处理后就可以)

def output(y_pre , y):
    temp = np.argmax(y_pre)
    print ('预测结果为'+str(temp))
    print ('实际结果为'+str(y))
    if str(temp)==str(y):
        print('预测结果正确')
    else:
        print('预测结果错误')

def readnum(path):#读取自己的数字数据进行预测
    img = cv2.imread(path, 0)#灰度图读入
    #查看数据
    img = cv2.resize(img,(28,28))
#    cv2.imshow('img', img)
#    cv2.waitKey(0)
#    cv2.destroyAllWindows()
    img = np.array(img)
    img = img.reshape((-1,28,28,1)).astype('float32')
    img = 1-img/255.0 #因为我的数字是相反颜色的所以有个反转
    return img
#预测并输出结果
path = 'C:\\photo\\5.jpg'
x_pre = readnum(path)
y_pre = model.predict(x_pre)
output(y_pre, 5)


在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述


在这里插入图片描述
在这里插入图片描述

④当然有时也会搞错
在这里插入图片描述

在这里插入图片描述

  • 5
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值