使用TensorFlow搭建循环神经网络
- TensorFlow版本1.4.0
- Python版本>3.5.0
循环神经网络RNN的原理可以参考这篇文章。
本教程搭建的网络包含一个LSTM层和一个全连接层。
网络结构图如下:
输入 → LSTM → 全连接层 → 输出
1.载入MNIST数据集
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST with one-hot encoded labels; the script later reads the
# `mnist.train` and `mnist.test` splits from this object.
mnist = input_data.read_data_sets(train_dir="MNIST_data/", one_hot=True)
2.定义参数
RNN的输入是一个时间序列。MNIST数据集中的图片大小为28px×28px,可以把图片的每一行(28个像素)看作一个时间步的输入,即每个时间步的输入长度为28;图片共有28行,因此时间步长为28。
# Hyper-parameters for the LSTM MNIST classifier.
batch_size = 64
time_steps = 28  # number of time steps: one per image row
input_size = 28  # features per time step: the pixels of one image row
n_input = time_steps * input_size  # flattened image size (784)
num_classes = 10
rnn_size = 128  # number of hidden units in the LSTM cell
lr = 0.01  # learning rate for gradient descent
train_epochs = 10  # passes over the training set; required by the training loop
3.定义网络输入
# Flattened images, n_input values per example; batch dimension left open.
x = tf.placeholder(dtype=tf.float32, shape=(None, n_input))
# One-hot class labels, one row per example.
y = tf.placeholder(dtype=tf.float32, shape=(None, num_classes))
4.定义网络主结构
def rnn_model(x):
    """Build the LSTM classifier graph and return unnormalized class scores.

    Args:
        x: float32 tensor of flattened images, ``n_input`` values per row
            (as fed through the ``x`` placeholder).

    Returns:
        Logits tensor of shape ``[batch_size, num_classes]``.
    """
    # View each flattened image as a sequence: time_steps rows of input_size pixels.
    seq = tf.reshape(x, [-1, time_steps, input_size])

    lstm = tf.nn.rnn_cell.BasicLSTMCell(num_units=rnn_size)
    # outputs has shape [batch_size, time_steps, rnn_size]; the final state is unused.
    outputs, _ = tf.nn.dynamic_rnn(cell=lstm, inputs=seq, dtype=tf.float32)
    # Keep only the output of the last time step: [batch_size, rnn_size].
    last = outputs[:, -1, :]

    # Fully connected projection to [batch_size, num_classes].
    weights = tf.Variable(tf.random_normal([rnn_size, num_classes]))
    bias = tf.Variable(tf.random_normal([num_classes]))
    return tf.matmul(last, weights) + bias
5.构建网络
# Raw class scores from the LSTM model, and their softmax probabilities.
logits = rnn_model(x)
prediction = tf.nn.softmax(logits=logits)
6.定义损失函数与优化器
# Batch-mean softmax cross-entropy, minimized with plain gradient descent at rate `lr`.
loss_op = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
)
optimizer = tf.train.GradientDescentOptimizer(lr)
train_op = optimizer.minimize(loss=loss_op)
7.定义评价指标
# Fraction of examples whose highest-scoring class matches the one-hot label.
correct_pred = tf.equal(tf.argmax(logits, axis=1), tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, dtype=tf.float32))
8.训练网络
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # Mini-batches run per epoch (integer division of the training-set size).
    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(train_epochs):
        for batch in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            feed = {x: batch_x, y: batch_y}
            sess.run(train_op, feed_dict=feed)
            # Every 200 batches, report loss/accuracy on the batch just trained on.
            if not batch % 200:
                loss, acc = sess.run([loss_op, accuracy], feed_dict=feed)
                print("epoch {}, batch {}, loss {:.4f}, accuracy {:.3f}".format(epoch, batch, loss, acc))
    print("optimization finished")
    # Evaluate on the entire test split in a single run.
    test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
    print("test acc", test_acc)
GitHub源代码:
https://github.com/gamersover/tensorflow_basic_tutorial/blob/master/basic_model/rnn_mnist.py