基于RNN的MNIST实例

MNIST 数据集来自美国国家标准与技术研究所,National Institute of Standards and Technology (NIST)。训练集 (training set) 由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局 (the Census Bureau) 的工作人员。测试集(test set) 也是同样比例的手写数字数据。

一、数据准备

同样从keras的datasets中导入mnist数据集。

from keras.datasets import mnist

# 首次使用时会在线进行数据集下载
(X_train, y_train), (X_test, y_test) = mnist.load_data() 

print('图像数据格式:', X_train.shape)
print("训练集:{:.0f},测试集:{:.0f}".format(X_train.shape[0], X_test.shape[0]))

在这里插入图片描述
且每个图片大小是28*28的。

其次还要对数据进行一定的转换

# 对自变量数据做归一化处理以改善拟合效果
X_train = X_train / 255.  
X_test = X_test / 255.
from keras.utils import to_categorical

# 将因变量转换为哑变量组
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

二、定义模型

用函数式API定义模型。

from keras.layers import Input, Dense, SimpleRNN 
from keras.models import Model 

# 定义用于输入的张量
input = Input(shape = X_train[0].shape)  

x = SimpleRNN(20, activation = 'relu')(input) 

output = Dense(10, activation = 'softmax')(x)  # 定义输出层

model = Model(inputs = input, outputs = output)  # 定义模型

选择最简单的SimpleRNN模型

可以通过summary查看下模型

model.summary()

在这里插入图片描述
其中每一层的Param数都可以具体计算出来:

如:simple_rnn_1这一层的980 = (28 + 20 + 1) * 20。
其中加的1位偏置单元。

同样dense_1中210 = (20 + 1)* 10。

然后是对模型的编译:

model.compile(loss = 'categorical_crossentropy', 
              optimizer = 'rmsprop', 
              metrics = ['accuracy']) 

一个合适的优化函数对于模型非常重要。

在RNN中,优化函数通常选用RMSprop

三、模型训练

model.fit(X_train, y_train, batch_size = 200, epochs = 20)

在这里插入图片描述
训练完后对训练集的准确率达到了87.33%

四、模型评估与预测

对模型进行评估:

score = model.evaluate(X_test, y_test, batch_size = 200, verbose = 1)
print("测试集损失函数:%f,预测准确率:%2.2f%%" % (score[0], score[1] * 100))

在这里插入图片描述
测试集的准确率也达到了87%左右,说明过拟合的情况不是很严重。整个模型还是比较好的

最后对测试集进行预测:

result = model.predict(X_test, batch_size = 200, verbose = 1)
result[:2]

在这里插入图片描述
查看前两条预测结果,同样,预测结果是概率值

可以看到:
对于第一个测试数据,第8个概率值最大。说明第一个测试数据预测结果为7

但是这样看起来并不是很直观。

可以通过numpy定位最大值的索引的方式来优化结果。

import numpy as np

np.argmax(result, axis = 1)[:10]

在这里插入图片描述
这样就很明显了。前10个数据的预测结果分别为7、2、…

和之前根据概率查看预测结果是一致的。


虽然准确率达到87%左右,但是模型的设定也极为简单,通过复杂模型,使准确率达到90%以上应该还是挺容易的。

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值