在TensorFlow上实现基础LSTM网络
此笔记的主要目的就是使读者熟悉在TensorFlow
上实现基础LSTM
网络的详细过程。我们将选用MNIST
作为数据集,它包括手写数字的图像和对应的标签,我们可以根据以下内置功能从TensorFlow
上下载并读取数据:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/", one_hot=True)
数据被分成3
个部分:训练数据(mnist.train
)有55000
张图像,测试数据(mnist.test
)有10000
张图像,验证数据(mnist.validation
)有5000
张图像。
数据的形态
训练数据集包括55000
张28 * 28
像素的图像,这些784
(28 * 28
)像素值被展开成一个维度为784
的单一向量,所有55000
个像素向量被储存为形态为(55000, 784)
的numpy
数组,并命名为mnist.train.images
。所有这55000
张图像都关联了一个类别标签(表示其所属类别),一共有10
个类别(0
至9
),类别标签使用独热编码的形式表示。因此标签将作为形态为(55000, 10)
的数组保存,并命名为mnist.train.labels
。
为什么要选择MNIST?
LSTM
通常用来解决复杂的序列处理问题,比如包含了NLP
概念(词嵌入、编码器等)的语言建模问题。这些问题本身需要大量理解,那么将问题简化并集中于在TensorFlow
上实现LSTM
的细节(比如输入格式化、LSTM
单元格以及网络结构设计),会是个不错的选择。MNIST
就正好提供了这样的机会,其中的输入数据是一个像素值的集合。我们可以轻易地将其格式化,将注意力集中在LSTM
实现细节上。
VANILLA RNN
循环神经网络按时间轴展开的时候,如下图所示:
其中,xt
代表时间步t
的输入;st
代表时间步t
的隐藏状态,可看作该网络的记忆
;ot
作为时间步t
时刻的输出;U
、V
、W
是所有时间步共享的参数,共享的重要性在于我们的模型在每一时间步以不同的输入执行相同的任务。当把RNN
展开的时候,网络可被看作每一个时间步都受上一时间步输出影响(时间步之间存在连接)的前馈网络。
LSTM单元格的解释
在TensorFlow
中,基础的LSTM
单元格声明为:
lstm_layer = rnn.BasicLSTMCell(num_units, forget_bias=1)
这里的num_units
指一个LSTM
单元格中的单元数。num_units
可以比作前馈神经网络中的隐藏层,前馈神经网络的隐藏层的节点数量等于每一个时间步中一个LSTM
单元格内LSTM
单元的num_units
数量:
每一个num_units
的LSTM
单元都可以看作一个标准的LSTM
单元:
在TensorFlow
中,最简单的RNN
形式是static_rnn
:
outputs, _ = rnn.static_rnn(lstm_layer, input, dtype="float32")
input
接受形态为[batch_size, input_size]
的张量列表,列表的长度为将网络展开后的时间步数,即列表中每一个元素都分别对应网络展开的时间步。比如在MNIST
数据集中,我们有28 * 28
像素的图像,每一张都可以看成拥有28
行28
个像素的图像。我们将网络按28
个时间步展开,以使在每一个时间步中,可以输入一行28
个像素(input_size
),从而经过28
个时间步输入整张图像。给定图像的batch_size
值,则每一个时间步将分别收到batch_size
个图像:
由static_rnn
生成的输出是一个形态为[batch_size, n_hidden]
的张量列表。列表的长度为将网络展开后的时间步数,即每一个时间步输出一个张量。在这个实现中我们只需关心最后一个时间步的输出,因为一张图像的所有行都输入到RNN
,预测即将在最后一个时间步生成。
在开始的时候,先导入一些必要的依赖关系、数据集,并声明一些常量。设定batch_size
为128
,num_units
为128
:
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/", one_hot=True)
time_steps = 28 # unrolled through 28 time steps
num_units = 128 # hidden LSTM units
n_input = 28 # rows of 28 pixels
learning_rate = 0.001 # learning rate for adam
n_classes = 10 # mnist is meant to be classified in 10 classes(0 - 9)
batch_size = 128 # size of batch
现在设置占位、权重以及偏置变量(用于将输出的形态从[batch_size, num_units]
转换为[batch_size, n_classes]
),从而可以预测正确的类别:
# weights and biases of appropriate shape to accomplish above task
out_weights = tf.Variable(tf.random_normal([num_units, n_classes]))
out_bias = tf.Variable(tf.random_normal([n_classes]))
x = tf.placeholder("float", [None, time_steps, n_input]) # input image placeholder
y = tf.placeholder("float", [None, n_classes]) # input label placeholder
现在我们得到了形态为[batch_size, time_steps, n_input]
的输入,我们需要将其转换成形态为[batch_size, n_inputs]
,长度为time_steps
的张量列表,从而可以将其输入static_rnn
:
# processing the input tensor from [batch_size, time_steps, n_input] to
# "time_steps" number of [batch_size, n_input] tensors
input = tf.unstack(x, time_steps, 1)
现在我们可以定义网络了,利用BasicLSTMCell
的一个层,将static_rnn
从中提取出来:
# defining the network
lstm_layer = rnn.BasicLSTMCell(num_units, forget_bias=1)
outputs, _ = rnn.static_rnn(lstm_layer, input, dtype="float32")
我们只考虑最后一个时间步的输入,从中生成预测:
# converting last output of dimension [batch_size, num_units] to
# [batch_size, n_classes] by out_weight multiplication
prediction = tf.matmul(outputs[-1], out_weights) + out_bias
定义损失函数、优化器和准确率:
# loss_function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))
opt = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss) # optimization
# model evaluation
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
现在我们已经完成定义,可以开始运行了。需要注意的是,我们的每一张图像在开始时被平坦化为784
维的单一向量,函数next_batch(batch_size)
必须返回这些784
维向量的batch_size
批次数。因此它们的形态要被改造成[batch_size, time_steps, n_input]
,从而可以被占位符接受:
init = tf.global_variables_initializer() # initialize variables
with tf.Session() as sess:
sess.run(init)
iter = 1
while iter < 800:
batch_x, batch_y = mnist.train.next_batch(batch_size=batch_size)
batch_x = batch_x.reshape((batch_size, time_steps, n_input))
sess.run(opt, feed_dict={x: batch_x, y: batch_y})
if iter % 10 == 0:
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
los = sess.run(loss, feed_dict={x: batch_x, y: batch_y})
print("For iter ", iter)
print("Accuracy ", acc)
print("Loss ", los)
print("-------------------")
iter = iter + 1
# calculating test accuracy
test_data = mnist.test.images[:128].reshape((-1, time_steps, n_input))
test_label = mnist.test.labels[:128]
print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))