关于RNN的介绍可以参考:RNN介绍
下面将RNN用于前文所提到的MNIST手写数字识别中。
1.获取数据集
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("D:\BaiDu\MNIST_data",one_hot=True)
2.参数定义
# 输入图片是28*28
n_inputs = 28 #输入一行,一行有28个数据
max_time = 28 #一共28行
lstm_size = 100 #隐层单元
n_classes = 10 # 10个分类
batch_size = 50 #每批次50个样本
n_batch = mnist.train.num_examples // batch_size #计算一共有多少个批次
#这里的none表示第一个维度可以是任意的长度
x = tf.placeholder(tf.float32,[None,784])
#正确的标签
y = tf.placeholder(tf.float32,[None,10])
#初始化权值
weights = tf.Variable(tf.truncated_normal([lstm_siz