最近在看Morvan老师的视频,入门RNN,在这里也贴上自己根据老师的课程修改过的RNN代码,作为学习~
用到的是RNN 神经网络,mnist数据集
# -*- coding: utf-8 -*-
import numpy as np
np.random.seed(1337)
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense
from keras.optimizers import Adam
#每个图片就是28行,一个时间段就读取一行
TIME_STEPS = 28
#输入就是一个时间段输入28个列
INPUT_SIZE = 28
#一次循环放入多少张图片
BATCH_SIZE = 50
#
BATCH_INDEX = 0
#输出的标签大小是多少
OUTPUT_SIZE = 10
#RNN中间神经元的数量
CELL_SIZE = 50
#学习率
LR = 0.001
(X_train,y_train),(X_test,y_test) = mnist.load_data()
#data process---------------------------
X_train = X_train.reshape(-1, 28, 28) / 255. # normalize
X_test = X_test.reshape(-1, 28, 28) / 255. # normalize
#对标签进行one-hot 编码
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)
#st