keras 中lstm 和bidirectional lstm (Bilstm)的return_sequence和return_state之间的关系

15 篇文章 0 订阅
12 篇文章 0 订阅

import tensorflow as tf
tf.enable_eager_execution()



embedding = tf.Variable(tf.truncated_normal((2, 3, 4)))

lstm = tf.keras.layers.LSTM(units=5, return_sequences=False, return_state=False)
outputs = lstm(embedding)  # return_sequences=False, return_state=False
print(outputs)  # 只有每个样本的最后一个time step的输出, shape:(batch, hidden_units)

lstm = tf.keras.layers.LSTM(units=5, return_sequences=True, return_state=False)

outputs = lstm(embedding)
print(outputs)   # (batch, seq_len, hidden_units)


lstm = tf.keras.layers.LSTM(units=5, return_sequences=True, return_state=True)
outputs, hidden, state = lstm(embedding)
"""
outputs: (batch, seq_len, hidden_units)
hidden: (batch, hidden_units), 是每个样本对应最后一个time step的输出, 这个输出对应着各自样本的最后一个time step
state: (batch, hidden_units), 是每个样本对应的记忆状态, 代表着cell state。
"""
print(outputs)
print(hidden)
print(state)




bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(5, return_sequences=False, return_state=False), merge_mode="concat")
outputs = bilstm(embedding)  # shape: (batch, hidden_units * 2), 这个取决于merge_mode参数的设置
print(outputs)
"""
outputs: shape: (batch, hidden_units * 2), 这个取决于merge_mode参数的设置
"""




bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(5, return_sequences=True, return_state=False), merge_mode="concat")
outputs = bilstm(embedding)
print(outputs)
"""
outputs: shape=(batch,seq_len,hidden_units * 2), 最后一个维度值取决于merge_mode
"""
print("====")
bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(2, return_sequences=True, return_state=True), merge_mode="concat")
outputs = bilstm(embedding)
print(outputs)
"""
o, f_o, f_c, b_o, b_c = outputs
输出包括三部分:
1、layer output
2、 (h, c) for forward lstm
3、 (h, c) for backward lstm

"""

 

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值