Conclusion: for a batch of variable-length sequences, static_rnn takes longer to build the graph, but it runs faster than dynamic_rnn.
结论,对于一个batch的变长的序列,static_rnn建图的时候慢,但跑的时候快。
At first, I thought static_rnn was just the new name for tf.nn.rnn, and that the only difference between static_rnn and dynamic_rnn was that dynamic_rnn supports variable-length sequences within one batch. I also assumed that dynamic_rnn would compute only the first n timesteps of each sequence, where n is the length of that sequence.
However, that is not the case. I used the following code to compare their performance.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
# Show INFO-level TensorFlow log messages.
tf.logging.set_verbosity(tf.logging.INFO)
# Command-line flags controlling the benchmark.
# batch_size is used for dataset batching and as dynamic_rnn's parallel_iterations.
tf.flags.DEFINE_integer('batch_size', 200, 'batch size')
# False (default) benchmarks tf.nn.static_rnn; True benchmarks tf.nn.dynamic_rnn.
tf.flags.DEFINE_bool('use_dynamic', False, 'whether to use dynamic rnn')
FLAGS = tf.flags.FLAGS
def model_fn(features, labels, mode):
    """Build the LSTM benchmark model for `tf.estimator.Estimator`.

    Args:
        features: Two-element sequence `[inputs, lengths]`. `inputs` is
            unstacked along axis 1 for the static path, so axis 1 is treated
            as the time axis; `lengths` is passed as the RNN sequence length.
        labels: Integer class ids fed to the sparse softmax cross-entropy loss.
        mode: Estimator mode key, forwarded unchanged to the returned spec.

    Returns:
        A `tf.estimator.EstimatorSpec` carrying the loss and an Adam train op.
    """
    inputs, seq_lengths = features[0], features[1]
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(200)

    # The `use_dynamic` flag selects which RNN implementation is benchmarked.
    if not FLAGS.use_dynamic:
        # static_rnn expects a Python list with one tensor per timestep.
        per_step_inputs = tf.unstack(inputs, axis=1)
        _, final_state = tf.nn.static_rnn(
            lstm_cell,
            per_step_inputs,
            sequence_length=seq_lengths,
            dtype=tf.float32)
    else:
        _, final_state = tf.nn.dynamic_rnn(
            lstm_cell,
            inputs,
            sequence_length=seq_lengths,
            dtype=tf.float32,
            parallel_iterations=FLAGS.batch_size)

    # Element 1 of the final state is used directly as the logits
    # (presumably the `h` part of the LSTM state tuple; the cell has 200 units).
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=final_state[1])
    train_op = tf.train.AdamOptimizer().minimize(
        loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def input_fn():
    """Produce an endless stream of random training batches.

    Generates 1000 random float32 examples of shape (300, 16), each with a
    declared sequence length of 20 and a random int32 label in [0, 200), then
    repeats and batches them by `FLAGS.batch_size`.

    Returns:
        A `([inputs, lengths], labels)` pair of tensors from a one-shot
        iterator, matching the feature layout `model_fn` indexes into.
    """
    num_examples = 1000
    inputs = np.random.rand(num_examples, 300, 16).astype(np.float32)
    # Every example declares the same length (20), well below the 300 padded steps.
    seq_lengths = np.full(num_examples, 20, dtype=np.int32)
    targets = np.random.randint(200, size=num_examples, dtype=np.int32)

    batches = (
        tf.data.Dataset.from_tensor_slices((inputs, seq_lengths, targets))
        .repeat()
        .batch(FLAGS.batch_size)
    )
    feat, length, label = batches.make_one_shot_iterator().get_next()
    return [feat, length], label
def main(_):
    """Train the benchmark model with an Estimator until global step 3801.

    The single positional argument is ignored.
    """
    # allow_growth=True is set on the GPU options of the session config.
    gpu_options = tf.GPUOptions(allow_growth=True)
    run_config = tf.estimator.RunConfig(
        keep_checkpoint_max=100,
        save_summary_steps=100,
        session_config=tf.ConfigProto(gpu_options=gpu_options),
    )
    # Checkpoints and summaries are written under the relative 'saving' directory.
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir='saving',
        config=run_config,
    )
    estimator.train(input_fn=input_fn, max_steps=3801)
if __name__ == '__main__':
    # Script entry point; presumably tf.app.run() parses the flags and then
    # calls main() — confirm against the TensorFlow version in use.
    tf.app.run()
The running time for static_rnn is:
INFO:tensorflow:Saving checkpoints for 1 into saving/model.ckpt.
INFO:tensorflow:loss = 5.30413, step = 1
INFO:tensorflow:global_step/sec: 2.93677
INFO:tensorflow:loss = 5.1139, step = 101 (34.051 sec)
INFO:tensorflow:global_step/sec: 20.3018
INFO:tensorflow:loss = 4.93272, step = 201 (4.925 sec)
The running time for dynamic_rnn is:
INFO:tensorflow:Saving checkpoints for 1 into saving/model.ckpt.
INFO:tensorflow:loss = 5.30566, step = 1
INFO:tensorflow:global_step/sec: 5.26866
INFO:tensorflow:loss = 5.08169, step = 101 (18.981 sec)
INFO:tensorflow:global_step/sec: 5.4056
INFO:tensorflow:loss = 4.65338, step = 201 (18.499 sec)