LSTM在Keras和Tensorflow中的统一

本文记录了将Tensorflow实现的LSTM模型转换为Keras过程中的挑战和解决方案。通过比较Tensorflow和纯Python版本的LSTM,确保在Keras中配置正确,并探讨了两者中激活函数和权重的对应关系。
摘要由CSDN通过智能技术生成

最近想把一个用到Tensorflow的LSTM的模型改成Keras,崩溃,好在解决了问题,小笔记记录一下

目的

KerasLSTM的输出

Tensorflow的用LSTMCelldynamic_rnn组成的LSTM结果一样。

首先是固定seed然后做一个简单的tf的LSTM的模型,如下

令人抓狂的过程

Tensorflow参考例

这边的输出当作Keras的配置的正确答案
其权重当作Keras初始权重看看输出的答案是否与正确答案一样

forget_bias设置为0的原因在于keras中并没有提供。(但是不影响bias的训练)

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple

import numpy as np

np.random.seed(0)
tf.set_random_seed(0)
batch_size = 1
seq_length = 5
inputs = tf.placeholder(shape=[None, seq_length, 1], dtype=tf.float32)

cell = LSTMCell(num_units=1,
                state_is_tuple=True,
                forget_bias=0.0,
                initializer=None)

rnn_outputs, rnn_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float32,
    sequence_length=[seq_length] * batch_size,
    inputs=inputs)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

total_parameters = 0
for variable in tf.trainable_variables():
    print("---- ", variable, " ----")
    print(repr(sess.run(variable)))
print("===========================================")

rnn_outputs_, rnn_states_ = sess.run([rnn_outputs, rnn_states], 
                                     feed
  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值