循环神经网络系列(五)Tensorflow中BasicLSTMCell

本文介绍了Tensorflow中BasicLSTMCell的工作原理,详细解析了细胞状态、权重参数和计算过程,包括输入和隐藏状态的融合、线性变换以及门控机制。通过源码分析,展示了如何从LSTM单元获取参数形状。
摘要由CSDN通过智能技术生成

1.结论

照惯例,先上结论,再说过程,不想看过程的可直接略过。

在这里插入图片描述

从这个图我们可以知道,一个LSTM cell中有4个参数,并且形状都是一样的shape=[output_size+n,output_size],其中n表示输入张量的维度,output_size通过函数BasicLSTMCell(num_units=output_size)获得。

2.怎么来的?

让我们一步一步从Tensorflow的源码中来获得这些信息!

2.1 cell.state_size

首先,需要明白Tensorflow中,state表示的是cell中有几个状态。例如在BasicRNNCell中,state就只有h这一个状态;而在BasicLSTMCell中,state就有h和c这两个状态。其次,state_size表示的是每个状态的第二维度,也就是output_size。

举例:

import tensorflow as tf

output_size = 10
batch_size = 32
dim = 50
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=output_size)
print(cell.state_size)

>>
LSTMStateTuple(c=10, h=10) 

LSTMStateTuple(c=10, h=10)就表示,c和h的output_size都为10,即[batch_size,10]。另外Tensorflow在实现的时候,都将c,h困在一起了,即以Tuple的方式,这也是Tensorflow所推荐的。

2.2 cell.zero_state

在LSTM中,zero_state就自然对应两个部分了, h 0 , c 0 h_0,c_0 h0,c0

import tensorflow as tf

output_size = 10
batch_size = 32
dim = 50
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=output_size)
input = tf.placeholder(dtype=tf.float32, shape=[batch_size, 50])
h0 = cell.zero_state(batch_size=batch_size, dtype=tf.float32)
print(h0)

>>
LSTMStateTuple(c=
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值