- RNN神经网络多用于自然语言处理,但也可以用于简单的图像分类,这里做一个小小的尝试
- 由于对RNN网络理解不够深刻,只能做一些简单的解释
代码如下
# 导入全连接层函数,也可以使用tf.layers.dense()来完成最终的分类
from tensorflow.contrib.layers import fully_connected
# 导入mnist数据
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# 这里理解为rnn网络输出数据时候的时刻,一个数据分为28个时刻
n_steps = 28
# 每个时刻有28个特征
n_inputs = 28
# 神经元个数
n_neurons = 150
# 分类的类别,共10类
n_outputs = 10
# 学习率
learning_rate = 0.001
# 读取数据 如果数据存在,直接读取,如果数据不存在,则会自动联网下载数据
mnist = input_data.read_data_sets("MNIST_data")
# 获取测试集数据,转为(-1,28,28)的矩阵格式,用于喂给循环神经网络
X_test = mnist.test.images.reshape((-1, n_steps, n_inputs))
# 获取测试集标签
y_test = mnist.test.labels
# 占位符
X = tf