循环神经网络也可以实现手写数字的识别。本文是在python3 28.keras使用循环神经网络(SimpleRNN)对MNIST数据集进行分类 学习笔记 一文的基础上进行的改进。
主要的改进是去掉了独热编码,loss函数使用mse。输出层使用1个神经元,激活函数是relu。
与原方法的比较,迭代3次就可以达到97%的准确率。
详细的解释如下:
- my_load_data:是定制化的数据加载函数。也可以换成keras中读取mnist的函数。不过我的网络有的时候会连接失败,于是我就开发了一个可以从缓存目录中读取数据的函数。
- x_train和x_test是28*28的图像数据。数值是0~255的整型数。要归一化到0-1。因此除以255.
- y_train和y_test是输出值,为0~9。因为是分类问题采用了独热编码。
- input_size和time_step都是28,正好和图片的大小一致。
- cell_size 隐含层的神经网络数目
- 损失函数使用categorical_crossentropy。
总体来说,循环神经网络对于数字识别的效果比多层神经网络和卷积神经网络要差一些。
import tensorflow as tf
import numpy as np
from tensorflow import keras
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
#读取本地mnist数据
def my_load_data(path='mnist.npz'):
origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = tf.keras.utils