Overview
Recurrent neural networks (RNNs) are used for sequence labeling and similar problems, and they are widely applied in natural language processing, speech recognition, and other areas. LSTM and GRU are currently the two most widely used RNN variants. TensorFlow already provides well-packaged implementations of these common models, but in practice we often need to modify LSTM or GRU for specific requirements, or even implement a new kind of RNN model from scratch. This article walks through how the RNN family is implemented in TensorFlow, so that readers have a reference when implementing their own RNN models.
Basic Concepts
This article focuses on the source code of TensorFlow's RNN family. Below we only briefly recap the mathematical formulas of the relevant models; please consult other resources for the details of the underlying theory.
RNN
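The basic RNN performs one fully connected update per time step; matching the BasicRNNCell implementation analyzed below, with tanh as the default activation:
$$h_t = \tanh(W \cdot [x_t, h_{t-1}] + b)$$
The output at each step is simply $h_t$ itself.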
LSTM
Forget gate: $f_t = \sigma(W_f \cdot [x_t, h_{t-1}] + b_f)$
Input gate: $i_t = \sigma(W_i \cdot [x_t, h_{t-1}] + b_i)$
Output gate: $o_t = \sigma(W_o \cdot [x_t, h_{t-1}] + b_o)$
Cell state: $c_t = f_t \odot c_{t-1} + i_t \odot \tanh(W_j \cdot [x_t, h_{t-1}] + b_j)$
Output: $h_t = o_t \odot \tanh(c_t)$
GRU
Reset gate: $r_t = \sigma(W_r \cdot [x_t, h_{t-1}] + b_r)$
Update gate: $u_t = \sigma(W_u \cdot [x_t, h_{t-1}] + b_u)$
Output: $\tilde{h}_t = \tanh(W \cdot [x_t, r_t \odot h_{t-1}] + b), \quad h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t$
Basic Workflow
In TensorFlow, the RNN-related source code falls into two categories. The first consists of the cell classes that implement the per-step computation; they all inherit from RNNCell and mainly include BasicRNNCell, BasicLSTMCell, GRUCell, and so on. The second consists of the flow-control routines that run a cell along the time axis, including the dynamic unidirectional RNN function tf.nn.dynamic_rnn and the dynamic bidirectional RNN function tf.nn.bidirectional_dynamic_rnn. This article focuses on the simpler of the two, tf.nn.dynamic_rnn.
RNNCells
RNNCell
RNNCell is the base class from which all cells in the RNN family inherit. Its original code is as follows:
class RNNCell(base_layer.Layer):
def __call__(self, inputs, state, scope=None):
if scope is not None:
with vs.variable_scope(scope,
custom_getter=self._rnn_get_variable) as scope:
return super(RNNCell, self).__call__(inputs, state, scope=scope)
else:
scope_attrname = "rnncell_scope"
scope = getattr(self, scope_attrname, None)
if scope is None:
scope = vs.variable_scope(vs.get_variable_scope(),
custom_getter=self._rnn_get_variable)
setattr(self, scope_attrname, scope)
with scope:
return super(RNNCell, self).__call__(inputs, state)
def _rnn_get_variable(self, getter, *args, **kwargs):
variable = getter(*args, **kwargs)
if context.in_graph_mode():
trainable = (variable in tf_variables.trainable_variables() or
(isinstance(variable, tf_variables.PartitionedVariable) and
list(variable)[0] in tf_variables.trainable_variables()))
else:
trainable = variable._trainable # pylint: disable=protected-access
if trainable and variable not in self._trainable_weights:
self._trainable_weights.append(variable)
elif not trainable and variable not in self._non_trainable_weights:
self._non_trainable_weights.append(variable)
return variable
@property
def state_size(self):
raise NotImplementedError("Abstract method")
@property
def output_size(self):
raise NotImplementedError("Abstract method")
def build(self, _):
pass
def zero_state(self, batch_size, dtype):
state_size = self.state_size
if hasattr(self, "_last_zero_state"):
(last_state_size, last_batch_size, last_dtype,
last_output) = getattr(self, "_last_zero_state")
if (last_batch_size == batch_size and
last_dtype == dtype and
last_state_size == state_size):
return last_output
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
output = _zero_state_tensors(state_size, batch_size, dtype)
self._last_zero_state = (state_size, batch_size, dtype, output)
return output
The __call__ method of RNNCell takes inputs and state as arguments. inputs is the data fed in at each time step: a 2-D Tensor whose two dimensions are the batch size and the concrete input features, so inputs has shape [batch_size, input_size]. state is the state saved from the previous step's cell. Its first dimension is also the batch size; depending on how state_size is configured, the second part can be a plain state vector (state then has shape [batch_size, state_size]) or a tuple containing the state vector together with arbitrary user-defined data (state then has shape [batch_size, tuple(state_size, ...)]).
Inside RNNCell.__call__, the _rnn_get_variable method is first used to sort the variables that are created into trainable and non-trainable parameters, which are stored in the class's _trainable_weights and _non_trainable_weights members respectively. It then calls the __call__ method of the parent class base_layer.Layer, which in turn invokes the call method of the concrete cell. Every cell class that inherits from RNNCell must implement a call method whose arguments are inputs and state and whose body contains that cell's actual computation. Thanks to this wrapping in RNNCell.__call__, a cell instance can be invoked like a function.
def __call__(self, inputs, state, scope=None):
if scope is not None:
with vs.variable_scope(scope,
custom_getter=self._rnn_get_variable) as scope:
return super(RNNCell, self).__call__(inputs, state, scope=scope)
else:
scope_attrname = "rnncell_scope"
scope = getattr(self, scope_attrname, None)
if scope is None:
scope = vs.variable_scope(vs.get_variable_scope(),
custom_getter=self._rnn_get_variable)
setattr(self, scope_attrname, scope)
with scope:
return super(RNNCell, self).__call__(inputs, state)
def _rnn_get_variable(self, getter, *args, **kwargs):
variable = getter(*args, **kwargs)
if context.in_graph_mode():
trainable = (variable in tf_variables.trainable_variables() or
(isinstance(variable, tf_variables.PartitionedVariable) and
list(variable)[0] in tf_variables.trainable_variables()))
else:
trainable = variable._trainable # pylint: disable=protected-access
if trainable and variable not in self._trainable_weights:
self._trainable_weights.append(variable)
elif not trainable and variable not in self._non_trainable_weights:
self._non_trainable_weights.append(variable)
return variable
The state_size and output_size methods must be overridden by the concrete cell subclasses; they define the dimensions of the state vector and the output vector respectively. Because both are decorated with @property, other classes can access them as attributes.
@property
def state_size(self):
raise NotImplementedError("Abstract method")
@property
def output_size(self):
raise NotImplementedError("Abstract method")
The build method also has to be implemented by the concrete cell. build is the hook recommended by base_layer.Layer: it is called from __call__ once the framework knows the shape and dtype of the layer's inputs, and it is where add_variable() should be called to initialize the variables to be learned.
def build(self, _):
pass
The last method in RNNCell is zero_state, which produces an all-zero state tensor. If the concrete cell does not define its own way of initializing the state, this method is used to generate the initial state value.
def zero_state(self, batch_size, dtype):
state_size = self.state_size
if hasattr(self, "_last_zero_state"):
(last_state_size, last_batch_size, last_dtype,
last_output) = getattr(self, "_last_zero_state")
if (last_batch_size == batch_size and
last_dtype == dtype and
last_state_size == state_size):
return last_output
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
output = _zero_state_tensors(state_size, batch_size, dtype)
self._last_zero_state = (state_size, batch_size, dtype, output)
return output
To summarize:
Every concrete cell implementation should inherit from RNNCell and implement the __init__, build, call, state_size, and output_size methods. The first three follow the general pattern TensorFlow recommends for implementing layers: __init__ stores the configuration, build initializes the learnable parameters once the shape and dtype of the inputs are known, and call contains the actual computation. state_size and output_size are the two methods specific to RNN cells; they define the dimensions of the state vector and the output vector. A minimal skeleton of such a custom cell is sketched right below.
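As an illustration (not part of the TensorFlow source), here is a minimal sketch of a custom cell, with the hypothetical name MinimalRNNCell, written against the same TF 1.x-era API used throughout this article; it simply reproduces the basic RNN update.
import tensorflow as tf

class MinimalRNNCell(tf.nn.rnn_cell.RNNCell):
  """Hypothetical minimal cell: h_t = tanh([x_t, h_{t-1}] W + b)."""

  def __init__(self, num_units, reuse=None, name=None):
    super(MinimalRNNCell, self).__init__(_reuse=reuse, name=name)
    self._num_units = num_units            # only store configuration here

  @property
  def state_size(self):
    return self._num_units                 # plain vector state

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    input_depth = inputs_shape[1].value    # input size is known at this point
    self._kernel = self.add_variable(
        "kernel", shape=[input_depth + self._num_units, self._num_units])
    self._bias = self.add_variable(
        "bias", shape=[self._num_units],
        initializer=tf.zeros_initializer(dtype=self.dtype))
    self.built = True

  def call(self, inputs, state):
    # one fully connected step over the concatenated [inputs, state]
    h = tf.tanh(tf.nn.bias_add(
        tf.matmul(tf.concat([inputs, state], 1), self._kernel), self._bias))
    return h, h                            # output and new state coincide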
BasicRNNCell
The BasicRNNCell class in rnn_cell_impl implements the basic RNN.
A basic RNN is just a fully connected network whose forward pass is repeated once per time step; the hidden-layer computation is a single fully connected layer.
In __init__, it stores the number of hidden units num_units and the activation function activation to use.
The layer-building pattern recommended in TensorFlow also suggests setting an input_spec of type base_layer.InputSpec in __init__ to describe the expected input, including its number of dimensions, the size of each dimension, and the dtype. BasicRNNCell sets the number of input dimensions to 2.
def __init__(self, num_units, activation=None, reuse=None, name=None):
super(BasicRNNCell, self).__init__(_reuse=reuse, name=name)
self.input_spec = base_layer.InputSpec(ndim=2)
self._num_units = num_units
self._activation = activation or math_ops.tanh
BasicRNNCell's state_size and output_size both return num_units, i.e. the state layer and the output layer have the same number of neurons.
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
In build, the learnable weights and bias are created from the input shape. The inputs_shape passed in is [batch_size, input_size], and input_depth is the dimensionality of each input vector. _kernel has shape [input_depth + self._num_units, self._num_units]: the number of input neurons is input_depth + num_units (the concatenation of the input and the previous state), and the number of output neurons is the value returned by state_size, i.e. num_units.
def build(self, inputs_shape):
if inputs_shape[1].value is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
input_depth = inputs_shape[1].value
self._kernel = self.add_variable(
_WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, self._num_units])
self._bias = self.add_variable(
_BIAS_VARIABLE_NAME,
shape=[self._num_units],
initializer=init_ops.zeros_initializer(dtype=self.dtype))
self.built = True
In call, the input and the state from the previous time step are concatenated and multiplied by _kernel, the bias is added, and the activation function is applied; in other words, it is exactly one fully connected layer. call returns both the output vector and the state vector of the current time step; in a plain RNN these two are the same.
def call(self, inputs, state):
gate_inputs = math_ops.matmul(
array_ops.concat([inputs, state], 1), self._kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
output = self._activation(gate_inputs)
return output, output
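To make the concat-then-matmul trick concrete, the following sketch (plain NumPy, with illustrative sizes) checks that multiplying the concatenated [x, h] by a single kernel of shape [input_size + num_units, num_units] gives the same result as computing x·W_x + h·W_h with the kernel split into its two blocks:
import numpy as np

batch_size, input_size, num_units = 2, 3, 4
rng = np.random.RandomState(0)

x = rng.randn(batch_size, input_size)        # inputs at one time step
h = rng.randn(batch_size, num_units)         # state from the previous step
kernel = rng.randn(input_size + num_units, num_units)
bias = rng.randn(num_units)

# what BasicRNNCell.call computes: tanh([x, h] @ kernel + bias)
fused = np.tanh(np.concatenate([x, h], axis=1) @ kernel + bias)

# the same computation with the kernel viewed as two stacked blocks W_x and W_h
w_x, w_h = kernel[:input_size], kernel[input_size:]
split = np.tanh(x @ w_x + h @ w_h + bias)

assert np.allclose(fused, split)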
BasicLSTMCell
The BasicLSTMCell class in rnn_cell_impl implements the basic LSTM. rnn_cell_impl also contains LSTMCell, which implements more advanced features such as cell clipping and a projection layer; in this article we only cover BasicLSTMCell.
In __init__, besides the same basic configuration as BasicRNNCell, BasicLSTMCell adds two options: forget_bias and state_is_tuple. If state_is_tuple is True, the state passed into call and the state returned for the current time step are both LSTMStateTuple objects, a tuple holding the current cell state and the output (hidden) state, which in an LSTM are different vectors. If state_is_tuple is False, the state returned by call is the cell state and the output state concatenated column-wise into a single vector.
def __init__(self, num_units, forget_bias=1.0,
state_is_tuple=True, activation=None, reuse=None, name=None):
super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name)
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
self.input_spec = base_layer.InputSpec(ndim=2)
self._num_units = num_units
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
self._activation = activation or math_ops.tanh
As mentioned in the discussion of state_is_tuple: in state_size, if state_is_tuple is True the state has a tuple-like structure described by LSTMStateTuple; otherwise the cell state and the output state are concatenated, so the size is 2 * num_units.
@property
def state_size(self):
return (LSTMStateTuple(self._num_units, self._num_units)
if self._state_is_tuple else 2 * self._num_units)
@property
def output_size(self):
return self._num_units
In build, note that the kernel has 4 * num_units output neurons. The factor of 4 comes from the LSTM formulas: the forget gate, the input gate, the output gate, and the candidate state all share the same linear base $W \cdot [x_t, h_{t-1}] + b$, so the four computations can be fused into a single matmul and split apart after the multiplication.
def build(self, inputs_shape):
if inputs_shape[1].value is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
input_depth = inputs_shape[1].value
h_depth = self._num_units
self._kernel = self.add_variable(
_WEIGHTS_VARIABLE_NAME,
shape=[input_depth + h_depth, 4 * self._num_units])
self._bias = self.add_variable(
_BIAS_VARIABLE_NAME,
shape=[4 * self._num_units],
initializer=init_ops.zeros_initializer(dtype=self.dtype))
self.built = True
Finally, look at BasicLSTMCell's call method. gate_inputs fuses the four shared computations into one matmul; array_ops.split then splits the result into the input gate, the candidate state, the forget gate, and the output gate. The new cell state and the new output vector are computed according to the formulas, and depending on state_is_tuple the two vectors are either wrapped in an LSTMStateTuple or concatenated before being returned.
def call(self, inputs, state):
sigmoid = math_ops.sigmoid
one = constant_op.constant(1, dtype=dtypes.int32)
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)
gate_inputs = math_ops.matmul(
array_ops.concat([inputs, h], 1), self._kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
value=gate_inputs, num_or_size_splits=4, axis=one)
forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
# Note that using `add` and `multiply` instead of `+` and `*` gives a
# performance improvement. So using those at the cost of readability.
add = math_ops.add
multiply = math_ops.multiply
new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
multiply(sigmoid(i), self._activation(j)))
new_h = multiply(self._activation(new_c), sigmoid(o))
if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
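As a sanity check on the 4 * num_units fusion, here is a small NumPy sketch (illustrative sizes, not TensorFlow code) that reproduces one BasicLSTMCell step: a single matmul against a kernel of shape [input_size + num_units, 4 * num_units], a split into i, j, f, o, then the state update with forget_bias added to f:
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

batch_size, input_size, num_units, forget_bias = 2, 3, 4, 1.0
rng = np.random.RandomState(0)

x = rng.randn(batch_size, input_size)
c = rng.randn(batch_size, num_units)           # previous cell state
h = rng.randn(batch_size, num_units)           # previous output
kernel = rng.randn(input_size + num_units, 4 * num_units)
bias = np.zeros(4 * num_units)

gate_inputs = np.concatenate([x, h], axis=1) @ kernel + bias
i, j, f, o = np.split(gate_inputs, 4, axis=1)  # input gate, candidate, forget gate, output gate

new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
new_h = np.tanh(new_c) * sigmoid(o)
print(new_h.shape, new_c.shape)                # (2, 4) (2, 4)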
GRUCell
GRUCell's __init__ is similar to those of the previous cells; the only difference is that the user can also specify a kernel initializer and a bias initializer here.
def __init__(self,
num_units,
activation=None,
reuse=None,
kernel_initializer=None,
bias_initializer=None,
name=None):
super(GRUCell, self).__init__(_reuse=reuse, name=name)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
self._num_units = num_units
self._activation = activation or math_ops.tanh
self._kernel_initializer = kernel_initializer
self._bias_initializer = bias_initializer
GRUCell's state vector and output vector both have size num_units.
@property
def state_size(self):
return self._num_units
@property
def output_size(self):
return self._num_units
In the GRU formulas, the reset gate and the update gate share the base $W \cdot [x_t, h_{t-1}] + b$, so they can be computed together; this is why _gate_kernel has 2 * num_units output neurons. The candidate state, however, has the base $W \cdot [x_t, r_t \odot h_{t-1}] + b$, which depends on the reset gate, so it cannot be fused with the gates and a separate _candidate_kernel has to be created.
def build(self, inputs_shape):
if inputs_shape[1].value is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
input_depth = inputs_shape[1].value
self._gate_kernel = self.add_variable(
"gates/%s" % _WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, 2 * self._num_units],
initializer=self._kernel_initializer)
self._gate_bias = self.add_variable(
"gates/%s" % _BIAS_VARIABLE_NAME,
shape=[2 * self._num_units],
initializer=(
self._bias_initializer
if self._bias_initializer is not None
else init_ops.constant_initializer(1.0, dtype=self.dtype)))
self._candidate_kernel = self.add_variable(
"candidate/%s" % _WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, self._num_units],
initializer=self._kernel_initializer)
self._candidate_bias = self.add_variable(
"candidate/%s" % _BIAS_VARIABLE_NAME,
shape=[self._num_units],
initializer=(
self._bias_initializer
if self._bias_initializer is not None
else init_ops.zeros_initializer(dtype=self.dtype)))
self.built = True
In call, the reset gate and the update gate are computed first, then the candidate state, and finally the new output.
def call(self, inputs, state):
gate_inputs = math_ops.matmul(
array_ops.concat([inputs, state], 1), self._gate_kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
value = math_ops.sigmoid(gate_inputs)
r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
r_state = r * state
candidate = math_ops.matmul(
array_ops.concat([inputs, r_state], 1), self._candidate_kernel)
candidate = nn_ops.bias_add(candidate, self._candidate_bias)
c = self._activation(candidate)
new_h = u * state + (1 - u) * c
return new_h, new_h
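The following NumPy sketch (illustrative sizes) mirrors one GRUCell step and shows why the two gates can be fused into one matmul while the candidate needs its own: the candidate's input depends on r * h, which is only known after the gates have been computed.
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

batch_size, input_size, num_units = 2, 3, 4
rng = np.random.RandomState(0)

x = rng.randn(batch_size, input_size)
h = rng.randn(batch_size, num_units)                  # previous state
gate_kernel = rng.randn(input_size + num_units, 2 * num_units)
gate_bias = np.ones(2 * num_units)                    # GRUCell defaults the gate bias to 1.0
cand_kernel = rng.randn(input_size + num_units, num_units)
cand_bias = np.zeros(num_units)

# r and u share the base [x, h], so one matmul covers both gates
r, u = np.split(sigmoid(np.concatenate([x, h], 1) @ gate_kernel + gate_bias), 2, axis=1)

# the candidate depends on r * h, so it needs its own kernel and matmul
c = np.tanh(np.concatenate([x, r * h], 1) @ cand_kernel + cand_bias)

new_h = u * h + (1 - u) * c
print(new_h.shape)                                    # (2, 4)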
Flow Control
An RNN is a model for sequence data, and the cells introduced above only describe what has to be computed at a single step of the sequence. A flow-control routine is therefore still needed: starting from the first step of the input, it runs the cell at every step, stores the return values and passes them on to the cell at the next step, and finally returns the overall result as the output of the RNN model. In terms of direction, RNNs today are either unidirectional or bidirectional; in terms of depth, single-layer or multi-layer. In TensorFlow, the unidirectional flow is handled by tf.nn.dynamic_rnn and the bidirectional flow by tf.nn.bidirectional_dynamic_rnn; a multi-layer RNN is built by wrapping ordinary RNN cells in tensorflow.contrib.rnn.MultiRNNCell and then driving the result with dynamic_rnn or bidirectional_dynamic_rnn depending on the direction (a short usage sketch follows). This article only covers tf.nn.dynamic_rnn; please refer to the source code for the others.
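For orientation, here is a brief usage sketch of the two variants not covered in detail here, stacking with MultiRNNCell and running a bidirectional RNN; the sizes are illustrative and both APIs are standard TF 1.x:
import tensorflow as tf

batch_size, max_time, input_size, hidden_size = 32, 20, 50, 128
inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_size])

# multi-layer: wrap several cells in MultiRNNCell and drive it like a single cell
stacked_cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.GRUCell(hidden_size) for _ in range(2)])
outputs, state = tf.nn.dynamic_rnn(stacked_cell, inputs, dtype=tf.float32)

# bidirectional: one forward cell and one backward cell, outputs come back as a pair
cell_fw = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
cell_bw = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
(out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs, dtype=tf.float32)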
dynamic_rnn
When building an RNN model, once we have created the RNNCell we need, passing the cell and the input data to tf.nn.dynamic_rnn completes the construction of the RNN over that input. tf.nn.dynamic_rnn returns the outputs of every step and the state of the last step, as shown below:
# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
initial_state=initial_state,
dtype=tf.float32)
Besides cell and inputs, dynamic_rnn accepts sequence_length, which records the true sequence length of every sample in inputs. Because inputs is a batch, its second dimension has to be sized to the longest sequence in the batch, and the shorter samples are padded at the end; if sequence_length is given, only the real (unpadded) part of each sample is computed and the padding is skipped. initial_state can be used to provide the initial value of the state. A small usage sketch of sequence_length follows the signature below.
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
dtype=None, parallel_iterations=None, swap_memory=False,
time_major=False, scope=None):
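A minimal illustration of the padding behavior (toy shapes chosen only for the example): the second sample has true length 2, so its outputs beyond step 2 come back as zeros and its final state is the state after step 2.
import numpy as np
import tensorflow as tf

# a toy padded batch: two sequences with real lengths 3 and 2, padded to max_time = 4
input_data = tf.constant(np.random.randn(2, 4, 5), dtype=tf.float32)
seq_len = tf.constant([3, 2], dtype=tf.int32)

cell = tf.nn.rnn_cell.GRUCell(8)
outputs, state = tf.nn.dynamic_rnn(cell, input_data,
                                   sequence_length=seq_len,
                                   dtype=tf.float32)
# outputs has shape [2, 4, 8]; outputs[1, 2:] is all zeros because the
# padded steps of the second sample are never actually computed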
Users usually arrange the input data as [batch_size, time_step, input_size]. To make the computation convenient, dynamic_rnn transposes it to [time_step, batch_size, input_size]; the _transpose_batch_time call shown below is responsible for this conversion:
flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
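_transpose_batch_time essentially just swaps the first two axes; a minimal equivalent of what it does to a 3-D input would be:
import tensorflow as tf

x = tf.zeros([32, 20, 50])          # [batch_size, time_step, input_size]
x_t = tf.transpose(x, [1, 0, 2])    # [time_step, batch_size, input_size]
print(x_t.shape)                    # (20, 32, 50)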
If the user passed an initial state, it is used directly as the initial value of state; otherwise the zero_state method of the RNNCell base class is called to initialize state to all zeros.
if initial_state is not None:
state = initial_state
else:
if not dtype:
raise ValueError("If there is no initial_state, you must give a dtype.")
state = cell.zero_state(batch_size, dtype)
After that, the _dynamic_rnn_loop method takes over. It returns the outputs of every step and the state of the last step, final_state.
(outputs, final_state) = _dynamic_rnn_loop(
cell,
inputs,
state,
parallel_iterations=parallel_iterations,
swap_memory=swap_memory,
sequence_length=sequence_length,
dtype=dtype)
So dynamic_rnn itself only validates the arguments, reorders the input dimensions as required by the computation, sets the initial value of state, and passes along parameters such as parallel_iterations (the degree of parallelism) and sequence_length, before handing everything over to _dynamic_rnn_loop.
_dynamic_rnn_loop
The _dynamic_rnn_loop method mainly creates an output TensorArray and an input TensorArray. In TensorFlow, a TensorArray can be thought of as an array of Tensors with a dynamic size; both TensorArrays here have size time_steps. The output TensorArray is created empty, and the concrete values are written into it as each step finishes; the input TensorArray is filled with the values of the input right after it is created.
# create a TensorArray
def _create_ta(name, element_shape, dtype):
return tensor_array_ops.TensorArray(dtype=dtype,
size=time_steps,
element_shape=element_shape,
tensor_array_name=base_name + name)
# build the output TensorArray and the input TensorArray
in_graph_mode = context.in_graph_mode()
if in_graph_mode:
output_ta = tuple(
_create_ta(
"output_%d" % i,
element_shape=(tensor_shape.TensorShape([const_batch_size])
.concatenate(
_maybe_tensor_shape_from_tensor(out_size))),
dtype=_infer_state_dtype(dtype, state))
for i, out_size in enumerate(flat_output_size))
input_ta = tuple(
_create_ta(
"input_%d" % i,
element_shape=flat_input_i.shape[1:],
dtype=flat_input_i.dtype)
for i, flat_input_i in enumerate(flat_input))
# feed the input into the TensorArray
input_ta = tuple(ta.unstack(input_)
for ta, input_ in zip(input_ta, flat_input))
else:
output_ta = tuple([0 for _ in range(time_steps.numpy())]
for i in range(len(flat_output_size)))
input_ta = flat_input
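The read/write pattern used here can be seen in isolation in the following minimal sketch (a plain Python loop standing in for the while_loop):
import tensorflow as tf

time_steps = 3
x = tf.random_normal([time_steps, 2, 4])          # [time, batch, depth]

# input TensorArray: unstack fills one element per time step
input_ta = tf.TensorArray(dtype=tf.float32, size=time_steps).unstack(x)

# output TensorArray: starts empty and is written to step by step
output_ta = tf.TensorArray(dtype=tf.float32, size=time_steps)
for t in range(time_steps):
    step_input = input_ta.read(t)                 # what _time_step reads
    output_ta = output_ta.write(t, step_input * 2.0)

outputs = output_ta.stack()                       # back to a [time, batch, depth] tensor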
Then TensorFlow's while_loop is called: it loops time_steps times, invoking _time_step on each iteration to run the cell and to update time, output_ta, and state.
loop_bound = time_steps
_, output_final_ta, final_state = control_flow_ops.while_loop(
cond=lambda time, *_: time < loop_bound,
body=_time_step,
loop_vars=(time, output_ta, state),
parallel_iterations=parallel_iterations,
maximum_iterations=time_steps,
swap_memory=swap_memory)
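Stripped of the RNN details, the loop has the same shape as this minimal while_loop sketch, where body plays the role of _time_step and the loop variables are threaded through each iteration:
import tensorflow as tf

time_steps = tf.constant(5)

def body(time, acc):
    # in _time_step this is where the cell is called and output_ta is written
    return time + 1, acc + 1

_, total = tf.while_loop(
    cond=lambda time, acc: time < time_steps,
    body=body,
    loop_vars=(tf.constant(0), tf.constant(0)))
# total evaluates to 5 once the graph is run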
_time_step
First, the input of the current time step is read from the input TensorArray.
input_t = tuple(ta.read(time) for ta in input_ta)
Then a lambda is defined that calls the cell with the current step's input and the previous step's state.
call_cell = lambda: cell(input_t, state)
If the sequence_length parameter was set, _rnn_step is called so that each sample is only computed up to its real length; otherwise call_cell() is invoked directly, i.e. the call method of the corresponding cell.
if sequence_length is not None:
(output, new_state) = _rnn_step(
time=time,
sequence_length=sequence_length,
min_sequence_length=min_sequence_length,
max_sequence_length=max_sequence_length,
zero_output=zero_output,
state=state,
call_cell=call_cell,
state_size=state_size,
skip_conditionals=True)
else:
(output, new_state) = call_cell()
The output of the current time step is written into the output TensorArray.
output_ta_t = tuple(
ta.write(time, out) for ta, out in zip(output_ta_t, output))
Note that the returned tuple advances time by 1; the other two return values are the updated output TensorArray and the state produced at the current time step.
return (time + 1, output_ta_t, new_state)
_rnn_step
Guided by sequence_length, this method performs the step computation while masking each sample in the batch: once the current time step is past a sample's sequence_length, that sample's output is replaced with zeros and its previous state is simply copied through.
def _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
zero_output, state, call_cell, state_size, skip_conditionals=False):
"""Calculate one step of a dynamic RNN minibatch.
Returns an (output, state) pair conditioned on `sequence_length`.
When skip_conditionals=False, the pseudocode is something like:
if t >= max_sequence_length:
return (zero_output, state)
if t < min_sequence_length:
return call_cell()
# Selectively output zeros or output, old state or new state depending
# on whether we've finished calculating each row.
new_output, new_state = call_cell()
final_output = np.vstack([
zero_output if time >= sequence_length[r] else new_output_r
for r, new_output_r in enumerate(new_output)
])
final_state = np.vstack([
state[r] if time >= sequence_length[r] else new_state_r
for r, new_state_r in enumerate(new_state)
])
return (final_output, final_state)
Args:
time: Python int, the current time step
sequence_length: int32 `Tensor` vector of size [batch_size]
min_sequence_length: int32 `Tensor` scalar, min of sequence_length
max_sequence_length: int32 `Tensor` scalar, max of sequence_length
zero_output: `Tensor` vector of shape [output_size]
state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
or a list/tuple of such tensors.
call_cell: lambda returning tuple of (new_output, new_state) where
new_output is a `Tensor` matrix of shape `[batch_size, output_size]`.
new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
state_size: The `cell.state_size` associated with the state.
skip_conditionals: Python bool, whether to skip using the conditional
calculations. This is useful for `dynamic_rnn`, where the input tensor
matches `max_sequence_length`, and using conditionals just slows
everything down.
Returns:
A tuple of (`final_output`, `final_state`) as given by the pseudocode above:
final_output is a `Tensor` matrix of shape [batch_size, output_size]
final_state is either a single `Tensor` matrix, or a tuple of such
matrices (matching length and shapes of input `state`).
Raises:
ValueError: If the cell returns a state tuple whose length does not match
that returned by `state_size`.
"""
# Convert state to a list for ease of use
flat_state = nest.flatten(state)
flat_zero_output = nest.flatten(zero_output)
def _copy_one_through(output, new_output):
# TensorArray and scalar get passed through.
if isinstance(output, tensor_array_ops.TensorArray):
return new_output
if output.shape.ndims == 0:
return new_output
# Otherwise propagate the old or the new value.
copy_cond = (time >= sequence_length)
with ops.colocate_with(new_output):
return array_ops.where(copy_cond, output, new_output)
def _copy_some_through(flat_new_output, flat_new_state):
# Use broadcasting select to determine which values should get
# the previous state & zero output, and which values should get
# a calculated state & output.
flat_new_output = [
_copy_one_through(zero_output, new_output)
for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
flat_new_state = [
_copy_one_through(state, new_state)
for state, new_state in zip(flat_state, flat_new_state)]
return flat_new_output + flat_new_state
def _maybe_copy_some_through():
"""Run RNN step. Pass through either no or some past state."""
new_output, new_state = call_cell()
nest.assert_same_structure(state, new_state)
flat_new_state = nest.flatten(new_state)
flat_new_output = nest.flatten(new_output)
return control_flow_ops.cond(
# if t < min_seq_len: calculate and return everything
time < min_sequence_length, lambda: flat_new_output + flat_new_state,
# else copy some of it through
lambda: _copy_some_through(flat_new_output, flat_new_state))
# TODO(ebrevdo): skipping these conditionals may cause a slowdown,
# but benefits from removing cond() and its gradient. We should
# profile with and without this switch here.
if skip_conditionals:
# Instead of using conditionals, perform the selective copy at all time
# steps. This is faster when max_seq_len is equal to the number of unrolls
# (which is typical for dynamic_rnn).
new_output, new_state = call_cell()
nest.assert_same_structure(state, new_state)
new_state = nest.flatten(new_state)
new_output = nest.flatten(new_output)
final_output_and_state = _copy_some_through(new_output, new_state)
else:
empty_update = lambda: flat_zero_output + flat_state
final_output_and_state = control_flow_ops.cond(
# if t >= max_seq_len: copy all state through, output zeros
time >= max_sequence_length, empty_update,
# otherwise calculation is required: copy some or all of it through
_maybe_copy_some_through)
if len(final_output_and_state) != len(flat_zero_output) + len(flat_state):
raise ValueError("Internal error: state and output were not concatenated "
"correctly.")
final_output = final_output_and_state[:len(flat_zero_output)]
final_state = final_output_and_state[len(flat_zero_output):]
for output, flat_output in zip(final_output, flat_zero_output):
output.set_shape(flat_output.get_shape())
for substate, flat_substate in zip(final_state, flat_state):
if not isinstance(substate, tensor_array_ops.TensorArray):
substate.set_shape(flat_substate.get_shape())
final_output = nest.pack_sequence_as(
structure=zero_output, flat_sequence=final_output)
final_state = nest.pack_sequence_as(
structure=state, flat_sequence=final_state)
return final_output, final_state
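The heart of the selective copy is the array_ops.where call inside _copy_one_through: rows whose sequence has already ended keep the old value (the zero output or the previous state), while the remaining rows take the freshly computed value. A standalone sketch of that selection, with illustrative values:
import tensorflow as tf

time = tf.constant(2)
sequence_length = tf.constant([4, 2, 1])          # true length of each sample in the batch
new_output = tf.constant([[1.0], [2.0], [3.0]])   # what the cell just computed
zero_output = tf.zeros_like(new_output)

finished = time >= sequence_length                # samples whose sequence already ended
final_output = tf.where(finished, zero_output, new_output)
# final_output == [[1.0], [0.0], [0.0]]: samples 2 and 3 are past their length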