Keras实现mnist手写数字识别

一、编译环境
tf1.14/1.6.0
python3.6.5
keras2.1.5

二、代码

#coding=utf-8
[1]
#加载并打印数据结构
import keras
from keras.datasets import mnist
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
print("训练集数据结构{};\n训练集标签数据结构{};\n测试集数据结构{};\n测试集标签数据结构{};\n".format(train_images.shape,train_labels.shape,test_images.shape,test_labels.shape))
#注:后面models.fit()函数中设置了shuffle=True,所以此处不再设置

[2]
#打印训练集的第一张图片
import matplotlib.pyplot as plt
digit=train_images[0]
plt.imshow(digit,cmap=plt.cm.binary)   #cmap=plt.cm.binary:二值图
plt.show()

[3]
#构建网络结构
'''
1、layers:表示神经网络中的一个数据处理层(dense:全连接层)
2、models.Sequential():表示把每一个数据处理层串联起来
3、layers.Dense(……):构建一个数据处理层
4、input_shape(28*28):表示当前处理层接收到的数据格式为长和宽都是28的二维数组,后面的“,”表示数组里面的每个元素到底包含了多少个数字都没有关系
'''
from keras import layers,models
network=models.Sequential()  #sequential:adj 连续的,有序的
network.add(layers.Dense(512,activation='relu',input_shape=(28*28,)))
network.add(layers.Dense(10,activation='softmax'))
network.compile(optimizer='RMSProp',loss='categorical_crossentropy',metrics=['accuracy'])
'''
model.compile(   #compile:汇编、编制
optimizer=优化器  可以字符串形式给出优化器的名字,也可以是函数形式  #RMSProp 算法通过累计各个变量的梯度的平方和r,然后用每个变量的梯度除以r,即可有效的缓解梯度变量间的梯度差异
loss=损失函数  可以字符串形式给出损失函数的名字,也可以是函数形式  categorical_crossentropy :交叉熵损失函数,因为数据标签有么是1要么是0,所以不用考虑log(0)的情况
metrics=['准确率']  标注网络评价指标,'accuracy':真实值y和预测值y_都有数值,如y=[1],如y_=[1]
)
'''

[4]
#数据传入网络模型前,先做归一化处理
'''
1.reshape(60000, 28*28):train_images数组原来含有60000个元素,每个元素是一个28行,28列的二维数组,现在把每个二维数组转变为一个含有28*28个元素的一维数组.
2.由于数字图案是一个灰度图,图片中每个像素点值的大小范围在0到255之间.
3.train_images.astype(“float32”)/255 把每个像素点的值从范围0-255转变为范围在0-1之间的浮点值。
'''
train_images=train_images.reshape(len(train_images),784)
train_images=train_images.astype('float32')/255
test_images=test_images.reshape(len(test_images),784)
test_images=test_images.astype('float32')/255

'''
把图片对应的标记也做一个更改:
目前所有图片的数字图案对应的是0到9。
例如test_images[0]对应的是数字7的手写图案,那么其对应的标记test_labels[0]的值就是7。
我们需要把数值7变成一个含有10个元素的数组,然后在第7个元素设置为1,其他元素设置为0。
例如test_lables[0] 的值由7转变为数组[0,0,0,0,0,0,0,1,0,0,]
'''
from keras.utils import to_categorical   #to_categorical就是将类别向量转换为二进制(只有0和1)的矩阵类型表示。其表现为将原有的类别向量转换为独热编码的形式。
print('原测试集第一张图片的标签为:',test_labels[0])
train_labels=to_categorical(train_labels,10)
test_labels=to_categorical(test_labels,10)
print('one_hot处理后测试集一张图片标签为:',test_labels[0])

[5]
'''
把数据输入网络进行训练:
train_images:用于训练的手写数字图片;
train_labels:对应的是图片的标记;
batch_size:每次网络从输入的图片数组中随机选取128个作为一组进行计算。
epochs:每次计算的循环是五次
'''
network.fit(train_images,train_labels,epochs=2,batch_size=128,shuffle=True)

[6]
'''
model.evaluate()函数:在测试的模式下返回模型的误差值和评估标准值,计算是分批进行的
evaluate(x=None,y=None,batch=None,verbose=1,sample_weight=None,steps=None)
verbose:0或者1,日志显示模式。0—安静模式;1——进度条
参考链接:https://keras.io/zh/models/model/
'''
test_loss,test_acc=network.evaluate(test_images,test_labels,verbose=0)
print('误差(loss)值:{};\n正确率(acc):{};'.format(test_loss,test_acc))

[7]
#随便输入一张图片,测试网络的识别效果
import random
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
#在测试集中随便抽取一张
n=int(random.randint(0,len(test_images)))
digit=test_images[n]
plt.imshow(digit,cmap=plt.cm.binary)
plt.show()
test_image=digit.reshape(1,28*28)
res=network.predict(test_image)
'''
model.predict():为输入样本生成输出预测,计算分批进行
predict(x,batch_size=None,verbose=0,steps=None)
batch_size:整数,如未指定,默认为32
verbose:日志显示模式,0或1
steps:声明预测结束之前的总步数(批次样本),默认值为0
'''
print(res)
print(res.shape)
for i in range(res.shape[1]):  #此处res.shape[1]与one_hot的深度有关,最好不要写成10
    if(res[0][i]==1):
        print('图片上的数字是:{};\n预测成功率为:{}'.format(i,test_acc))
        break

三、输出结果
训练集数据结构(60000, 28, 28);
训练集标签数据结构(60000,);
测试集数据结构(10000, 28, 28);
测试集标签数据结构(10000,);

原测试集第一张图片的标签为: 7
one_hot处理后测试集一张图片标签为: [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]

误差(loss)值:0.09232732669711113;
正确率(acc):0.9707;
[[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]]
(1, 10)
图片上的数字是:2;
预测成功率为:0.9707

进程已结束,退出代码 0

在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值