原理参考:https://blog.csdn.net/banxin1995/article/details/85332465
代码:
with tf.variable_scope('lstm_nn', initializer=lstm_init):
    # Hand-written, single-layer LSTM unrolled over `num_timesteps` steps.
    # Only hps.num_lstm_nodes[0] is used, and no dropout is applied here.
    #
    # Earlier implementation built on the tf.contrib.rnn cell API, kept for
    # reference only (it is not executed):
    #
    #   cells = []
    #   for i in range(hps.num_lstm_layers):
    #       cell = tf.contrib.rnn.BasicLSTMCell(
    #           hps.num_lstm_nodes[i],
    #           state_is_tuple = True)
    #       cell = tf.contrib.rnn.DropoutWrapper(
    #           cell,
    #           output_keep_prob = keep_prob)
    #       cells.append(cell)
    #   cell = tf.contrib.rnn.MultiRNNCell(cells)
    #   initial_state = cell.zero_state(batch_size, tf.float32)
    #   # rnn_outputs: [batch_size, num_timesteps, lstm_outputs[-1]]
    #   rnn_outputs, _ = tf.nn.dynamic_rnn(
    #       cell, embed_inputs, initial_state = initial_state)
    #   last = rnn_outputs[:, -1, :]

    num_nodes = hps.num_lstm_nodes[0]

    # One (x-weights, h-weights, bias) triple per gate, each created in its
    # own sub-scope. The scope names and the creation order (inputs, outputs,
    # forget, memory) are kept exactly as before so variable names do not
    # change. Fresh shape lists are built for every call.
    gate_params = {}
    for gate_scope in ('inputs', 'outputs', 'forget', 'memory'):
        with tf.variable_scope(gate_scope):
            gate_params[gate_scope] = _generate_params_for_lstm_cell(
                x_size=[hps.num_embedding_size, num_nodes],
                h_size=[num_nodes, num_nodes],
                bias_size=[1, num_nodes],
            )
    ix, ih, ib = gate_params['inputs']
    ox, oh, ob = gate_params['outputs']
    fx, fh, fb = gate_params['forget']
    cx, ch, cb = gate_params['memory']

    # Zero initial cell state and hidden state. Plain tensors are enough:
    # both names are rebound to new tensors on every step and nothing ever
    # assigns to the initial values, so they need not be graph variables.
    state = tf.zeros([batch_size, num_nodes])
    h = tf.zeros([batch_size, num_nodes])

    for i in range(num_timesteps):
        # Slice out timestep i. Indexing the middle axis drops it, so the
        # result is [batch_size, embed_size]; the reshape restates that
        # shape explicitly.
        embed_input = embed_inputs[:, i, :]
        embed_input = tf.reshape(
            embed_input, [batch_size, hps.num_embedding_size])

        forget_gate = tf.sigmoid(
            tf.matmul(embed_input, fx) + tf.matmul(h, fh) + fb)
        input_gate = tf.sigmoid(
            tf.matmul(embed_input, ix) + tf.matmul(h, ih) + ib)
        output_gate = tf.sigmoid(
            tf.matmul(embed_input, ox) + tf.matmul(h, oh) + ob)
        # Candidate cell content computed from the current input and the
        # previous hidden state.
        mid_state = tf.tanh(
            tf.matmul(embed_input, cx) + tf.matmul(h, ch) + cb)

        # New cell state: gated candidate plus gated previous cell state.
        state = mid_state * input_gate + state * forget_gate
        h = output_gate * tf.tanh(state)

    # Hidden state after the final timestep.
    last = h

# NOTE(review): presumably the initializer for a fully connected layer that
# follows this excerpt - confirm against the rest of the file.
fc_init = tf.uniform_unit_scaling_initializer(factor=1.0)