Network structure:
An LSTM followed by a single fully connected (dense) layer.
Input/output analysis:
MNIST dataset details:
- Each handwritten digit in the dataset is a (28, 28) two-dimensional array of grayscale pixels.
- The 28 rows of an image can be viewed as a sequence of 28 time steps, so the code sets TIME_STEP = 28; each row is the LSTM's input at time t, so the input dimension is INPUT_SIZE = 28.
- NUM_UNITS = 128 in the code means the LSTM output h_t is a 128-dimensional vector; num_units is exactly this output dimension.
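To make INPUT_SIZE and NUM_UNITS concrete, here is a minimal NumPy sketch of one LSTM layer unrolled over the 28 rows of a single image, followed by a dense layer. The weights are random and the gate layout is illustrative (not TensorFlow's internal parameterization); the point is only the shapes: the final hidden state h_t has NUM_UNITS = 128 entries, which the dense layer maps to 10 logits.

```python
import numpy as np

INPUT_SIZE, NUM_UNITS = 28, 128  # same sizes as in the text
rng = np.random.default_rng(0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# One LSTM step: all four gates come from [x_t, h_{t-1}] through a single
# weight matrix of shape (INPUT_SIZE + NUM_UNITS, 4 * NUM_UNITS).
W = rng.standard_normal((INPUT_SIZE + NUM_UNITS, 4 * NUM_UNITS)) * 0.1
b = np.zeros(4 * NUM_UNITS)

def lstm_step(x_t, h_prev, c_prev):
    z = np.concatenate([x_t, h_prev]) @ W + b
    i, f, o, g = np.split(z, 4)                       # input/forget/output gates, candidate
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)  # new cell state
    h = sigmoid(o) * np.tanh(c)                        # new hidden state
    return h, c

# Unroll over the 28 rows of one (random stand-in) image: each row is x_t.
image = rng.standard_normal((28, INPUT_SIZE))
h = np.zeros(NUM_UNITS)
c = np.zeros(NUM_UNITS)
for x_t in image:
    h, c = lstm_step(x_t, h, c)

# Dense layer on the last hidden state: 128 -> 10 logits.
logits = h @ (rng.standard_normal((NUM_UNITS, 10)) * 0.1)

print(h.shape, logits.shape)  # (128,) (10,)
```

This mirrors what `tf.nn.dynamic_rnn` plus `tf.layers.dense` do in the code below, just for one image instead of a batch.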
Code:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# properties: train / test / validation --> images (n, 784), labels (n, 10)
"""Load data"""
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
"""Hyperparameters"""
BATCH_SIZE = 128   # batch size: 128 images are processed per step
TIME_STEP = 28     # sequence length fed to the LSTM: an image has 28 rows
INPUT_SIZE = 28    # length of each input vector x_t: an image has 28 columns
LR = 0.001         # learning rate
NUM_UNITS = 128    # number of units in the LSTM cell
ITERATIONS = 8000  # number of training iterations
N_CLASSES = 10     # output size: probabilities for the ten digits 0-9
"""定义计算"""
# 定义 placeholders 以便接收x,y
# 维度是[BATCH_SIZE,TIME_STEP * INPUT_SIZE]
train_x = tf.placeholder(tf.float32, [None, TIME_STEP * INPUT_SIZE])
# 输入的是二维数据,将其还原为三维,维度是[BATCH_SIZE, TIME_STEP, INPUT_SIZE]
image = tf.reshape(train_x, [-1, TIME_STEP, INPUT_SIZE])
train_y = tf.placeholder(tf.int32, [None, N_CLASSES])
# Define the RNN (LSTM) structure
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=NUM_UNITS)
outputs, final_state = tf.nn.dynamic_rnn(
    cell=rnn_cell,       # the cell to unroll
    inputs=image,        # input data
    initial_state=None,  # initial state
    dtype=tf.float32,    # data type
    # False: (batch, time_step, x_input); True: (time_step, batch, x_input).
    # False matches the layout of `image` here.
    # If False, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
    time_major=False,
)
# Take the output at the last time step and feed it to the dense layer
output = tf.layers.dense(inputs=outputs[:, -1, :], units=N_CLASSES)
"""定义损失和优化方法"""
loss = tf.losses.softmax_cross_entropy(
onehot_labels=train_y,
logits=output) # 计算loss
train_op = tf.train.AdamOptimizer(LR).minimize(loss) # 选择优化方法
correct_prediction = tf.equal(
tf.argmax(
train_y, axis=1), tf.argmax(
output, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float')) # 计算正确率
"""训练"""
with tf.Session() as sess:
sess.run(tf.global_variables_initializer()) # 初始化计算图中的变量
for step in range(ITERATIONS): # 开始训练
x, y = mnist.train.next_batch(BATCH_SIZE)
_, loss_ = sess.run([train_op, loss], {train_x: x, train_y: y})
if step % 500 == 0: # test(validation)
test_x, test_y = mnist.test.next_batch(5000)
accuracy_ = sess.run(accuracy, {train_x: test_x, train_y: test_y})
print('train loss: %f' % loss_, '| validation accuracy: %f' % accuracy_)
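As a sanity check on what the graph computes, here is a small NumPy sketch of the two quantities above: `softmax_cross_entropy` (softmax over the logits, then the negative log-probability of the true class) and the argmax-based accuracy. The logits and labels are toy values invented for illustration: three examples, two classified correctly.

```python
import numpy as np

# Toy batch: 3 examples, 10 classes; true labels are 3, 1, 7.
logits = np.zeros((3, 10))
logits[0, 3] = 5.0  # predicts class 3 -- correct
logits[1, 1] = 5.0  # predicts class 1 -- correct
logits[2, 2] = 5.0  # predicts class 2 -- wrong (label is 7)
labels_onehot = np.eye(10)[[3, 1, 7]]

# softmax_cross_entropy: softmax over logits, then -sum(onehot * log p),
# averaged over the batch (max subtracted for numerical stability).
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)
loss = -(labels_onehot * np.log(probs)).sum(axis=1).mean()

# accuracy: compare argmax of predictions and labels, average the matches.
accuracy = (probs.argmax(axis=1) == labels_onehot.argmax(axis=1)).mean()
print(accuracy)  # 0.6666666666666666
```

The wrong third example dominates the loss (its true-class probability is tiny), which is exactly why cross-entropy gives a much stronger gradient signal on misclassified examples than accuracy alone would suggest.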
Reference:
https://www.cnblogs.com/sandy-t/p/6930608.html [Building an RNN (LSTM) for MNIST handwritten digit recognition]