Note: the unrolled RNN figure in that blog post contains an error; the initial-state input on the left side should not have a vertical data-flow connection.
Re-analysis of the data structure of the initial-state input
init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)
Pay attention to init_state above: for a single-layer cell it is a state of shape [batch_size, num_units]; for a multi-layer cell it is a nested list/tuple with one such state per layer. See the TensorFlow API excerpt below for why zero_state only needs the single argument batch_size, and the sketch after the API link for the concrete structure. The data structure of init_state is not as simple as it looks!
zero_state(batch_size, dtype)

Return zero-filled state tensor(s).

Args:
  batch_size: int, float, or unit Tensor representing the batch size.
  dtype: the data type to use for the state.

Returns:
  If state_size is an int or TensorShape, then the return value is an N-D tensor of shape [batch_size, state_size] filled with zeros.
  If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-D tensors with the shapes [batch_size, s] for each s in state_size.
mlstm_cell is a MultiRNNCell; its hidden_size and layer_num are already fixed inside it, so the only thing it still needs to know is the batch size, which is why only batch_size is passed here.
The return value follows [batch_size, state_size]: each state tensor has shape [batch_size, num_units]. Does the shape of the initial state depend on layer_num? Each individual tensor does not; it depends only on batch_size and hidden_size (i.e. num_units). layer_num only determines how many such (c, h) state pairs the nested tuple contains.
TensorFlow API:
https://tensorflow.google.cn/api_docs/python/tf/compat/v1/nn/rnn_cell/MultiRNNCell?hl=en
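To make the nested structure concrete, below is a minimal sketch (TF 1.x; the values hidden_size=256, layer_num=2 and batch_size=128 are illustrative assumptions, not taken from the excerpt) that builds a two-layer cell and inspects what zero_state returns:

import tensorflow as tf
from tensorflow.contrib import rnn

hidden_size, layer_num, batch_size = 256, 2, 128  # assumed example values
cells = [rnn.BasicLSTMCell(hidden_size, state_is_tuple=True) for _ in range(layer_num)]
mlstm_cell = rnn.MultiRNNCell(cells, state_is_tuple=True)

init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)
# init_state is a tuple with one LSTMStateTuple(c, h) per layer
print(len(init_state))        # 2, i.e. layer_num
print(init_state[0].c.shape)  # (128, 256), i.e. [batch_size, hidden_size]
print(init_state[0].h.shape)  # (128, 256)
# A single-layer cell's zero_state would instead return one LSTMStateTuple(c, h)
# directly, each tensor of shape [batch_size, num_units].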
# Reshape the 784 pixels of each character back into a 28 * 28 image
# The next few steps are the key to implementing the RNN / LSTM
####################################################################
# ** Step 1: the RNN input has shape = (batch_size, timestep_size, input_size)
X = tf.reshape(_X, [-1, 28, 28])
# ** Step 2: define an LSTM cell; only hidden_size needs to be specified,
# ** it automatically matches the dimensionality of the input X
# ** Step 3: add a dropout layer; usually only output_keep_prob is set
# ** Note: each layer must get its own cell object. Passing [lstm_cell] * layer_num
# ** to MultiRNNCell makes every layer share one cell, which raises an error
# ** (or silently ties weights) in TF >= 1.1, so we build a fresh cell per layer.
def lstm_cell():
    cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)
    return rnn.DropoutWrapper(cell=cell, input_keep_prob=1.0, output_keep_prob=keep_prob)
# ** Step 4: stack layer_num layers with MultiRNNCell to build a multi-layer LSTM
mlstm_cell = rnn.MultiRNNCell([lstm_cell() for _ in range(layer_num)], state_is_tuple=True)
# ** Step 5: initialize the state with all zeros
init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)
# ** Step 6, method one: call dynamic_rnn() to run the network we have built
# ** When time_major==False, outputs.shape = [batch_size, timestep_size, hidden_size],
# ** so we can take h_state = outputs[:, -1, :] as the final output.
# ** With state_is_tuple=True, state is a tuple of layer_num LSTMStateTuple(c, h)
# ** pairs, each tensor of shape [batch_size, hidden_size]
# ** (loosely written as [layer_num, 2, batch_size, hidden_size]).
# ** Equivalently, we can take h_state = state[-1][1] (the last layer's h) as the final output.
# ** Either way, the final output has shape [batch_size, hidden_size].
# outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)
# h_state = outputs[:, -1, :]  # or h_state = state[-1][1]
# *************** To better understand how the LSTM works, we now implement the function from Step 6 ourselves ***************
# If you check the documentation, you will find that every RNNCell provides a __call__() method (see the appendix at the end), which we can use to unroll the LSTM over the time steps manually.
# ** Step 6, method two: unroll the computation step by step over time
outputs = list()
state = init_state
with tf.variable_scope('RNN'):
    for timestep in range(timestep_size):
        if timestep > 0:
            tf.get_variable_scope().reuse_variables()
        # state here carries the state of every LSTM layer
        (cell_output, state) = mlstm_cell(X[:, timestep, :], state)
        outputs.append(cell_output)
h_state = outputs[-1]
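As a hedged sketch of the typical next step (not part of the excerpt; class_num, W, bias and y_pre are assumed names, with class_num = 10 for the MNIST digits), the final h_state would feed a softmax classification layer:

W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)
bias = tf.Variable(tf.constant(0.1, shape=[class_num]), dtype=tf.float32)
# h_state: [batch_size, hidden_size] -> y_pre: [batch_size, class_num]
y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)

Since h_state already summarizes the whole 28-step sequence, a single dense layer on top is enough for classification.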
————————————————
Copyright notice: this is an original article by the CSDN blogger 永永夜, released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/Jerr__y/article/details/61195257