LSTM模型简介及Tensorflow实现

LSTM模型在RNN模型的基础上新增加了单元状态C(cell state)。

一. 模型的输入和输出

在t时刻,LSTM的输入有3个:
(1) 当前时刻LSTM的输入值x(t);
(2) 上一时刻LSTM的输出值h(t-1);
(3) 上一时刻的单元状态c(t-1);

LSTM的输出有2个:
(1) 当前时刻LSTM的输出值h(t);
(2) 当前时刻的单元状态c(t);

二. 模型的计算

这里写图片描述

(1) 遗忘门:forget gate,控制上一时刻的单元状态有多少传入:

这里写图片描述

(2) 输入门:input gate,控制上一时刻LSTM的输出有多少传入:

这里写图片描述

(3) 当前时刻输入的单元状态:

这里写图片描述

(4) 当前时刻LSTM的单元状态:

这里写图片描述

(5) 输出门:output gate,控制有多少传入到LSTM当前时刻的输出:

这里写图片描述

(6) 当前时刻LSTM的输出:

这里写图片描述

note:公式中的X表示对应元素相乘;

三. TensorFlow实现LSTM-regression模型

# load module
from tensorflow.example.tutorial.mmist import input_data
import tensorflow as tf
import numpy as np

# definite hyperparameters
BATCH_SIZE = 64
TIME_STEP = 28
INPUT_SIZE = 28
LR = 0.01

# load data
mnist = input_data.read_data_sets('mnist', one_hot=True)

# test data
test_x = mnist.test.images[:2000]
test_y = mnist.test.labels[:2000]

# placeholder
tf_x = tf.placeholder(tf.float32, [None, TIME_STEP * INPUT_SIZE])
image = tf.reshape(tf_x, [-1, TIME_STEP, INPUT_SIZE])
tf_y = tf.placeholder(tf.int32, [None, 10])

# RNN
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)
outputs, (h_c, h_n) = tf.nn.dynamic_rnn(rnn_cell, image, dtype=tf.float32)
loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=output)
train_op = tf.train.AdamOptimizer(LR).minimize(loss)
accuracy = tf.metrics.accuracy(labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(output, axis=1),)[1]

# open an tf session
sess = tf.Session()
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init_op)

# train
for step in range(1200):
    b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
    _, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y})
    if step % 50 == 0:
        accuracy_ = sess.run(accuracy, {tf_x: test_x, tf_y: test_y})
        print('train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)

test_output = sess.run(output, {tf_x: test_x[: 10]})
pred_y = np.argmax(test_output, 1)
print(pred_y, 'prediction_number')
print(np.argmax(test_y[: 10], 1), 'real number')

四. 参考

(1) 韩炳涛系列文章:https://www.zybuluo.com/hanbingtao/note/581764
(2) 莫烦系列教程: https://github.com/MorvanZhou/Tensorflow-Tutorial/tree/master/tutorial-contents

  • 3
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值