# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
"""Module implementing RNN Cells. |
|
This module provides a number of basic commonly used RNN cells, such as LSTM |
|
(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of |
|
operators that allow adding dropouts, projections, or embeddings for inputs. |
|
Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by |
|
calling the `rnn` ops several times. |
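# A minimal usage sketch of the cells below (illustrative only, not part of
# this module; it assumes the public TF 1.x wrappers `tf.placeholder` and
# `tf.nn.dynamic_rnn`, and an input batch of shape [batch, time, depth]):
#
#   inputs = tf.placeholder(tf.float32, [None, None, 32])
#   cell = MultiRNNCell([BasicLSTMCell(64) for _ in range(2)])
#   outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)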
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import hashlib
import numbers
|
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
|
_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"
|
"""Checks that a given object is an RNNCell by using duck typing.""" |
|
conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"), |
|
hasattr(cell, "zero_state"), callable(cell)] |
|
def _concat(prefix, suffix, static=False):
  """Concat that enables int, Tensor, or TensorShape values.

  This function takes a size specification, which can be an integer, a
  TensorShape, or a Tensor, and converts it into a concatenated Tensor
  (if static = False) or a list of integers (if static = True).

  Args:
    prefix: The prefix; usually the batch size (and/or time step size).
      (TensorShape, int, or Tensor.)
    suffix: TensorShape, int, or Tensor.
    static: If `True`, return a python list with possibly unknown dimensions.
      Otherwise return a `Tensor`.

  Returns:
    shape: the concatenation of prefix and suffix.

  Raises:
    ValueError: if `suffix` is not a scalar or vector (or TensorShape).
    ValueError: if prefix or suffix was `None` and asked for dynamic
      Tensors out.
  """
  if isinstance(prefix, ops.Tensor):
    p = prefix
    p_static = tensor_util.constant_value(prefix)
    if p.shape.ndims == 0:
      p = array_ops.expand_dims(p, 0)
    elif p.shape.ndims != 1:
      raise ValueError("prefix tensor must be either a scalar or vector, "
                       "but saw tensor: %s" % p)
  else:
    p = tensor_shape.as_shape(prefix)
    p_static = p.as_list() if p.ndims is not None else None
    p = (constant_op.constant(p.as_list(), dtype=dtypes.int32)
         if p.is_fully_defined() else None)
  if isinstance(suffix, ops.Tensor):
    s = suffix
    s_static = tensor_util.constant_value(suffix)
    if s.shape.ndims == 0:
      s = array_ops.expand_dims(s, 0)
    elif s.shape.ndims != 1:
      raise ValueError("suffix tensor must be either a scalar or vector, "
                       "but saw tensor: %s" % s)
  else:
    s = tensor_shape.as_shape(suffix)
    s_static = s.as_list() if s.ndims is not None else None
    s = (constant_op.constant(s.as_list(), dtype=dtypes.int32)
         if s.is_fully_defined() else None)

  if static:
    shape = tensor_shape.as_shape(p_static).concatenate(s_static)
    shape = shape.as_list() if shape.ndims is not None else None
  else:
    if p is None or s is None:
      raise ValueError("Provided a prefix or suffix of None: %s and %s"
                       % (prefix, suffix))
    shape = array_ops.concat((p, s), 0)
  return shape
|
def _zero_state_tensors(state_size, batch_size, dtype):
  """Create tensors of zeros based on state_size, batch_size, and dtype."""
  def get_state_shape(s):
    """Combine s with batch_size to get a proper tensor shape."""
    c = _concat(batch_size, s)
    size = array_ops.zeros(c, dtype=dtype)
    if context.in_graph_mode():
      c_static = _concat(batch_size, s, static=True)
      size.set_shape(c_static)
    return size
  return nest.map_structure(get_state_shape, state_size)
|
class RNNCell(base_layer.Layer):
  """Abstract object representing an RNN cell.

  Every `RNNCell` must have the properties below and implement `call` with
  the signature `(output, next_state) = call(input, state)`.  The optional
  third input argument, `scope`, is allowed for backwards compatibility
  purposes; but should be left off for new subclasses.

  This definition of cell differs from the definition used in the literature.
  In the literature, 'cell' refers to an object with a single scalar output.
  This definition refers to a horizontal array of such units.

  An RNN cell, in the most abstract setting, is anything that has
  a state and performs some operation that takes a matrix of inputs.
  This operation results in an output matrix with `self.output_size` columns.
  If `self.state_size` is an integer, this operation also results in a new
  state matrix with `self.state_size` columns.  If `self.state_size` is a
  (possibly nested tuple of) TensorShape object(s), then it should return a
  matching structure of Tensors having shape `[batch_size].concatenate(s)`
  for each `s` in `self.state_size`.
  """

  def __call__(self, inputs, state, scope=None):
    """Run this RNN cell on inputs, starting from the given state.

    Args:
      inputs: `2-D` tensor with shape `[batch_size x input_size]`.
      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
        with shape `[batch_size x self.state_size]`.  Otherwise, if
        `self.state_size` is a tuple of integers, this should be a tuple
        with shapes `[batch_size x s] for s in self.state_size`.
      scope: VariableScope for the created subgraph; defaults to class name.

    Returns:
      A pair containing:

      - Output: A `2-D` tensor with shape `[batch_size x self.output_size]`.
      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
        the arity and shapes of `state`.
    """
    if scope is not None:
      with vs.variable_scope(scope,
                             custom_getter=self._rnn_get_variable) as scope:
        return super(RNNCell, self).__call__(inputs, state, scope=scope)
    else:
      with vs.variable_scope(vs.get_variable_scope(),
                             custom_getter=self._rnn_get_variable):
        return super(RNNCell, self).__call__(inputs, state)
|
  def _rnn_get_variable(self, getter, *args, **kwargs):
    variable = getter(*args, **kwargs)
    if context.in_graph_mode():
      trainable = (variable in tf_variables.trainable_variables() or
                   (isinstance(variable, tf_variables.PartitionedVariable) and
                    list(variable)[0] in tf_variables.trainable_variables()))
    else:
      trainable = variable._trainable  # pylint: disable=protected-access
    if trainable and variable not in self._trainable_weights:
      self._trainable_weights.append(variable)
    elif not trainable and variable not in self._non_trainable_weights:
      self._non_trainable_weights.append(variable)
    return variable
|
"""size(s) of state(s) used by this cell. |
|
It can be represented by an Integer, a TensorShape or a tuple of Integers |
|
raise NotImplementedError("Abstract method") |
|
"""Integer or TensorShape: size of outputs produced by this cell.""" |
|
raise NotImplementedError("Abstract method") |
|
# This tells the parent Layer object that it's OK to call |
|
# self.add_variable() inside the call() method. |
|
  def zero_state(self, batch_size, dtype):
    """Return zero-filled state tensor(s).

    Args:
      batch_size: int, float, or unit Tensor representing the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `state_size` is an int or TensorShape, then the return value is a
      `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.

      If `state_size` is a nested list or tuple, then the return value is
      a nested list or tuple (of the same structure) of `2-D` tensors with
      the shapes `[batch_size x s]` for each s in `state_size`.
    """
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      state_size = self.state_size
      return _zero_state_tensors(state_size, batch_size, dtype)
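# Sketch of a minimal custom cell built on the interface above (a hypothetical
# example, not part of this module): subclasses only need to provide
# `state_size`, `output_size`, and `call`.
#
#   class PlusOneCell(RNNCell):
#     @property
#     def state_size(self):
#       return 1
#
#     @property
#     def output_size(self):
#       return 1
#
#     def call(self, inputs, state):
#       new_state = state + 1.0
#       return new_state, new_state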
|
class BasicRNNCell(RNNCell):
  """The most basic RNN cell.

  Args:
    num_units: int, The number of units in the RNN cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
  """

  def __init__(self, num_units, activation=None, reuse=None):
    super(BasicRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh
    self._linear = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    if self._linear is None:
      self._linear = _Linear([inputs, state], self._num_units, True)

    output = self._activation(self._linear([inputs, state]))
    return output, output
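# Example single step with BasicRNNCell (a sketch; the batch size and input
# depth below are arbitrary and chosen only for illustration):
#
#   cell = BasicRNNCell(num_units=16)
#   x = array_ops.zeros([4, 8])                          # [batch, depth]
#   state = cell.zero_state(batch_size=4, dtype=dtypes.float32)
#   output, next_state = cell(x, state)                  # both [4, 16]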
|
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078). |
|
num_units: int, The number of units in the GRU cell. |
|
activation: Nonlinearity to use. Default: `tanh`. |
|
reuse: (optional) Python boolean describing whether to reuse variables |
|
in an existing scope. If not `True`, and the existing scope already has |
|
the given variables, an error is raised. |
|
    kernel_initializer: (optional) The initializer to use for the weight and
      projection matrices.
    bias_initializer: (optional) The initializer to use for the bias.
  """

  def __init__(self, num_units, activation=None, reuse=None,
               kernel_initializer=None, bias_initializer=None):
    super(GRUCell, self).__init__(_reuse=reuse)
|
self._num_units = num_units |
|
self._activation = activation or math_ops.tanh |
|
self._kernel_initializer = kernel_initializer |
|
    self._bias_initializer = bias_initializer
    self._gate_linear = None
    self._candidate_linear = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units
|
  def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""
    if self._gate_linear is None:
      bias_ones = self._bias_initializer
      if self._bias_initializer is None:
        bias_ones = init_ops.constant_initializer(1.0, dtype=inputs.dtype)
      with vs.variable_scope("gates"):  # Reset gate and update gate.
        self._gate_linear = _Linear(
            [inputs, state],
            2 * self._num_units,
            True,
            bias_initializer=bias_ones,
            kernel_initializer=self._kernel_initializer)

    value = math_ops.sigmoid(self._gate_linear([inputs, state]))
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    r_state = r * state
    if self._candidate_linear is None:
      with vs.variable_scope("candidate"):
        self._candidate_linear = _Linear(
            [inputs, r_state],
            self._num_units,
            True,
            bias_initializer=self._bias_initializer,
            kernel_initializer=self._kernel_initializer)
    c = self._activation(self._candidate_linear([inputs, r_state]))
    new_h = u * state + (1 - u) * c
    return new_h, new_h
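# The computation above corresponds to the standard GRU update (a sketch in
# pseudocode, with `[x, h]` denoting concatenation along the feature axis):
#
#   r, u = split(sigmoid(W_gates . [x, h] + b_gates), 2)   # reset/update gates
#   c    = tanh(W_cand . [x, r * h] + b_cand)              # candidate state
#   h'   = u * h + (1 - u) * c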
|
_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))


class LSTMStateTuple(_LSTMStateTuple):
  """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.

  Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state
  and `h` is the output.

  Only used when `state_is_tuple=True`.
  """
  __slots__ = ()

  @property
  def dtype(self):
    (c, h) = self
    if c.dtype != h.dtype:
      raise TypeError("Inconsistent internal state: %s vs %s" %
                      (str(c.dtype), str(h.dtype)))
    return c.dtype
|
class BasicLSTMCell(RNNCell): |
|
"""Basic LSTM recurrent network cell. |
|
The implementation is based on: http://arxiv.org/abs/1409.2329. |
|
We add forget_bias (default: 1) to the biases of the forget gate in order to |
|
reduce the scale of forgetting in the beginning of the training. |
|
It does not allow cell clipping, a projection layer, and does not |
|
use peep-hole connections: it is the basic baseline. |
|
  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.
  """
|
def __init__(self, num_units, forget_bias=1.0, |
|
state_is_tuple=True, activation=None, reuse=None): |
|
"""Initialize the basic LSTM cell. |
|
num_units: int, The number of units in the LSTM cell. |
|
forget_bias: float, The bias added to forget gates (see above). |
|
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
|
the `c_state` and `m_state`. If False, they are concatenated |
|
along the column axis. The latter behavior will soon be deprecated. |
|
activation: Activation function of the inner states. Default: `tanh`. |
|
reuse: (optional) Python boolean describing whether to reuse variables |
|
in an existing scope. If not `True`, and the existing scope already has |
|
the given variables, an error is raised. |
|
      When restoring from CudnnLSTM-trained checkpoints, must use
      CudnnCompatibleLSTMCell instead.
    """
    super(BasicLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
|
self._num_units = num_units |
|
self._forget_bias = forget_bias |
|
self._state_is_tuple = state_is_tuple |
|
    self._activation = activation or math_ops.tanh
    self._linear = None

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units
|
  def call(self, inputs, state):
    """Long short-term memory cell (LSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size x input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size x self.state_size]`, if `state_is_tuple` has been set to
        `True`.  Otherwise, a `Tensor` shaped
        `[batch_size x 2 * self.state_size]`.

    Returns:
      A pair containing the new hidden state, and the new state (either a
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """
    sigmoid = math_ops.sigmoid
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

    if self._linear is None:
      self._linear = _Linear([inputs, h], 4 * self._num_units, True)
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

    new_c = (
        c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
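# The update above is the standard (non-peephole) LSTM step (a sketch in
# pseudocode, with `[x, h]` denoting concatenation along the feature axis):
#
#   i, j, f, o = split(W . [x, h] + b, 4)
#   c' = c * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j)
#   h' = tanh(c') * sigmoid(o)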
|
"""Long short-term memory unit (LSTM) recurrent network cell. |
|
The default non-peephole implementation is based on: |
|
http://www.bioinf.jku.at/publications/older/2604.pdf |
|
S. Hochreiter and J. Schmidhuber. |
|
"Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. |
|
The peephole implementation is based on: |
|
https://research.google.com/pubs/archive/43905.pdf |
|
Hasim Sak, Andrew Senior, and Francoise Beaufays. |
|
"Long short-term memory recurrent neural network architectures for |
|
large scale acoustic modeling." INTERSPEECH, 2014. |
|
The class uses optional peep-hole connections, optional cell clipping, and |
|
an optional projection layer. |
|
def __init__(self, num_units, |
|
use_peepholes=False, cell_clip=None, |
|
initializer=None, num_proj=None, proj_clip=None, |
|
num_unit_shards=None, num_proj_shards=None, |
|
forget_bias=1.0, state_is_tuple=True, |
|
activation=None, reuse=None): |
|
"""Initialize the parameters for an LSTM cell. |
|
num_units: int, The number of units in the LSTM cell. |
|
use_peepholes: bool, set True to enable diagonal/peephole connections. |
|
cell_clip: (optional) A float value, if provided the cell state is clipped |
|
by this value prior to the cell output activation. |
|
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
|
matrices. If None, no projection is performed. |
|
proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is |
|
provided, then the projected values are clipped elementwise to within |
|
`[-proj_clip, proj_clip]`. |
|
num_unit_shards: Deprecated, will be removed by Jan. 2017. |
|
Use a variable_scope partitioner instead. |
|
num_proj_shards: Deprecated, will be removed by Jan. 2017. |
|
Use a variable_scope partitioner instead. |
|
forget_bias: Biases of the forget gate are initialized by default to 1 |
|
in order to reduce the scale of forgetting at the beginning of |
|
the training. Must set it manually to `0.0` when restoring from |
|
CudnnLSTM trained checkpoints. |
|
state_is_tuple: If True, accepted and returned states are 2-tuples of |
|
the `c_state` and `m_state`. If False, they are concatenated |
|
along the column axis. This latter behavior will soon be deprecated. |
|
activation: Activation function of the inner states. Default: `tanh`. |
|
reuse: (optional) Python boolean describing whether to reuse variables |
|
in an existing scope. If not `True`, and the existing scope already has |
|
the given variables, an error is raised. |
|
      When restoring from CudnnLSTM-trained checkpoints, must use
      CudnnCompatibleLSTMCell instead.
    """
    super(LSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
|
    if num_unit_shards is not None or num_proj_shards is not None:
      logging.warn(
          "%s: The num_unit_shards and proj_unit_shards parameters are "
          "deprecated and will be removed in Jan 2017.  "
          "Use a variable scope with a partitioner instead.", self)
|
self._num_units = num_units |
|
self._use_peepholes = use_peepholes |
|
self._cell_clip = cell_clip |
|
self._initializer = initializer |
|
self._num_proj = num_proj |
|
self._proj_clip = proj_clip |
|
self._num_unit_shards = num_unit_shards |
|
self._num_proj_shards = num_proj_shards |
|
self._forget_bias = forget_bias |
|
self._state_is_tuple = state_is_tuple |
|
self._activation = activation or math_ops.tanh |
|
    if num_proj:
      self._state_size = (
          LSTMStateTuple(num_units, num_proj)
          if state_is_tuple else num_units + num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          LSTMStateTuple(num_units, num_units)
          if state_is_tuple else 2 * num_units)
      self._output_size = num_units
    self._linear1 = None
    self._linear2 = None
    if self._use_peepholes:
      self._w_f_diag = None
      self._w_i_diag = None
      self._w_o_diag = None

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size
|
  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
|
    if self._linear1 is None:
      scope = vs.get_variable_scope()
      with vs.variable_scope(
          scope, initializer=self._initializer) as unit_scope:
        if self._num_unit_shards is not None:
          unit_scope.set_partitioner(
              partitioned_variables.fixed_size_partitioner(
                  self._num_unit_shards))
        self._linear1 = _Linear([inputs, m_prev], 4 * self._num_units, True)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    lstm_matrix = self._linear1([inputs, m_prev])
    i, j, f, o = array_ops.split(
        value=lstm_matrix, num_or_size_splits=4, axis=1)
|
    # Diagonal connections
    if self._use_peepholes and not self._w_f_diag:
      scope = vs.get_variable_scope()
      with vs.variable_scope(
          scope, initializer=self._initializer) as unit_scope:
        with vs.variable_scope(unit_scope):
          self._w_f_diag = vs.get_variable(
              "w_f_diag", shape=[self._num_units], dtype=dtype)
          self._w_i_diag = vs.get_variable(
              "w_i_diag", shape=[self._num_units], dtype=dtype)
          self._w_o_diag = vs.get_variable(
              "w_o_diag", shape=[self._num_units], dtype=dtype)

    if self._use_peepholes:
      c = (sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
           sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
           self._activation(j))
|
    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      if self._linear2 is None:
        scope = vs.get_variable_scope()
        with vs.variable_scope(scope, initializer=self._initializer):
          with vs.variable_scope("projection") as proj_scope:
            if self._num_proj_shards is not None:
              proj_scope.set_partitioner(
                  partitioned_variables.fixed_size_partitioner(
                      self._num_proj_shards))
            self._linear2 = _Linear(m, self._num_proj, False)
      m = self._linear2(m)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
                 array_ops.concat([c, m], 1))
    return m, new_state
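# Example construction of a projected, clipped LSTM (a sketch; the sizes are
# arbitrary and chosen only for illustration):
#
#   cell = LSTMCell(num_units=256, num_proj=128, cell_clip=3.0,
#                   use_peepholes=True)
#   # state_size is LSTMStateTuple(c=256, h=128); output_size is 128.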
|
def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
  ix = [0]
  def enumerated_fn(*inner_args, **inner_kwargs):
    r = map_fn(ix[0], *inner_args, **inner_kwargs)
    ix[0] += 1
    return r
  return nest.map_structure_up_to(shallow_structure,
                                  enumerated_fn, *args, **kwargs)
|
def _default_dropout_state_filter_visitor(substate):
  if isinstance(substate, LSTMStateTuple):
    # Do not perform dropout on the memory state.
    return LSTMStateTuple(c=False, h=True)
  elif isinstance(substate, tensor_array_ops.TensorArray):
    return False
  return True
|
class DropoutWrapper(RNNCell): |
|
"""Operator adding dropout to inputs and outputs of the given cell.""" |
|
def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0, |
|
state_keep_prob=1.0, variational_recurrent=False, |
|
input_size=None, dtype=None, seed=None, |
|
dropout_state_filter_visitor=None): |
|
"""Create a cell with added input, state, and/or output dropout. |
|
If `variational_recurrent` is set to `True` (**NOT** the default behavior), |
|
then the same dropout mask is applied at every step, as described in: |
|
Y. Gal, Z Ghahramani. "A Theoretically Grounded Application of Dropout in |
|
Recurrent Neural Networks". https://arxiv.org/abs/1512.05287 |
|
Otherwise a different dropout mask is applied at every time step. |
|
Note, by default (unless a custom `dropout_state_filter` is provided), |
|
the memory state (`c` component of any `LSTMStateTuple`) passing through |
|
    a `DropoutWrapper` is never modified.  This behavior is described in the
    above article.

    Args:
      cell: an RNNCell, a projection to output_size is added to it.
      input_keep_prob: unit Tensor or float between 0 and 1, input keep
|
probability; if it is constant and 1, no input dropout will be added. |
|
output_keep_prob: unit Tensor or float between 0 and 1, output keep |
|
probability; if it is constant and 1, no output dropout will be added. |
|
state_keep_prob: unit Tensor or float between 0 and 1, output keep |
|
probability; if it is constant and 1, no output dropout will be added. |
|
State dropout is performed on the outgoing states of the cell. |
|
**Note** the state components to which dropout is applied when |
|
`state_keep_prob` is in `(0, 1)` are also determined by |
|
the argument `dropout_state_filter_visitor` (e.g. by default dropout |
|
is never applied to the `c` component of an `LSTMStateTuple`). |
|
variational_recurrent: Python bool. If `True`, then the same |
|
dropout pattern is applied across all time steps per run call. |
|
If this parameter is set, `input_size` **must** be provided. |
|
input_size: (optional) (possibly nested tuple of) `TensorShape` objects |
|
containing the depth(s) of the input tensors expected to be passed in to |
|
the `DropoutWrapper`. Required and used **iff** |
|
`variational_recurrent = True` and `input_keep_prob < 1`. |
|
dtype: (optional) The `dtype` of the input, state, and output tensors. |
|
Required and used **iff** `variational_recurrent = True`. |
|
seed: (optional) integer, the randomness seed. |
|
dropout_state_filter_visitor: (optional), default: (see below). Function |
|
that takes any hierarchical level of the state and returns |
|
a scalar or depth=1 structure of Python booleans describing |
|
which terms in the state should be dropped out. In addition, if the |
|
function returns `True`, dropout is applied across this sublevel. If |
|
        the function returns `False`, dropout is not applied across this entire
        sublevel.
        Default behavior: perform dropout on all terms except the memory (`c`)
        state of `LSTMCellState` objects, and don't try to apply dropout to
        `TensorArray` objects:

          def dropout_state_filter_visitor(s):
            if isinstance(s, LSTMCellState):
              # Never perform dropout on the c state.
              return LSTMCellState(c=False, h=True)
            elif isinstance(s, TensorArray):
              return False
            return True

    Raises:
      TypeError: if `cell` is not an `RNNCell`, or `keep_state_fn` is provided
        but not `callable`.
      ValueError: if any of the keep_probs are not between 0 and 1.
    """
|
if not _like_rnncell(cell): |
|
raise TypeError("The parameter cell is not a RNNCell.") |
|
if (dropout_state_filter_visitor is not None |
|
and not callable(dropout_state_filter_visitor)): |
|
raise TypeError("dropout_state_filter_visitor must be callable") |
|
self._dropout_state_filter = ( |
|
dropout_state_filter_visitor or _default_dropout_state_filter_visitor) |
|
with ops.name_scope("DropoutWrapperInit"): |
|
def tensor_and_const_value(v): |
|
tensor_value = ops.convert_to_tensor(v) |
|
const_value = tensor_util.constant_value(tensor_value) |
|
return (tensor_value, const_value) |
|
for prob, attr in [(input_keep_prob, "input_keep_prob"), |
|
(state_keep_prob, "state_keep_prob"), |
|
(output_keep_prob, "output_keep_prob")]: |
|
tensor_prob, const_prob = tensor_and_const_value(prob) |
|
        if const_prob is not None:
          if const_prob < 0 or const_prob > 1:
            raise ValueError("Parameter %s must be between 0 and 1: %d"
                             % (attr, const_prob))
          setattr(self, "_%s" % attr, float(const_prob))
        else:
          setattr(self, "_%s" % attr, tensor_prob)
|
    # Set cell, variational_recurrent, seed before running the code below
    self._cell = cell
    self._variational_recurrent = variational_recurrent
    self._seed = seed

    self._recurrent_input_noise = None
    self._recurrent_state_noise = None
    self._recurrent_output_noise = None

    if variational_recurrent:
      if dtype is None:
        raise ValueError(
            "When variational_recurrent=True, dtype must be provided")
|
      def convert_to_batch_shape(s):
        # Prepend a 1 for the batch dimension; for recurrent
        # variational dropout we use the same dropout mask for all
        # batch elements.
        return array_ops.concat(
            ([1], tensor_shape.TensorShape(s).as_list()), 0)

      def batch_noise(s, inner_seed):
        shape = convert_to_batch_shape(s)
        return random_ops.random_uniform(shape, seed=inner_seed, dtype=dtype)

      if (not isinstance(self._input_keep_prob, numbers.Real) or
          self._input_keep_prob < 1.0):
        if input_size is None:
          raise ValueError(
              "When variational_recurrent=True and input_keep_prob < 1.0 or "
              "is unknown, input_size must be provided")
        self._recurrent_input_noise = _enumerated_map_structure_up_to(
            input_size,
            lambda i, s: batch_noise(s, inner_seed=self._gen_seed("input", i)),
            input_size)
      self._recurrent_state_noise = _enumerated_map_structure_up_to(
          cell.state_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("state", i)),
          cell.state_size)
      self._recurrent_output_noise = _enumerated_map_structure_up_to(
          cell.output_size,
          lambda i, s: batch_noise(s, inner_seed=self._gen_seed("output", i)),
          cell.output_size)
|
  def _gen_seed(self, salt_prefix, index):
    if self._seed is None:
      return None
    salt = "%s_%d" % (salt_prefix, index)
    string = (str(self._seed) + salt).encode("utf-8")
    return int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)
|
def _variational_recurrent_dropout_value( |
|
self, index, value, noise, keep_prob): |
|
"""Performs dropout given the pre-calculated noise tensor.""" |
|
# uniform [keep_prob, 1.0 + keep_prob) |
|
random_tensor = keep_prob + noise |
|
# 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob) |
|
binary_tensor = math_ops.floor(random_tensor) |
|
    ret = math_ops.div(value, keep_prob) * binary_tensor
    ret.set_shape(value.get_shape())
    return ret
|
def _dropout(self, values, salt_prefix, recurrent_noise, keep_prob, |
|
shallow_filtered_substructure=None): |
|
"""Decides whether to perform standard dropout or recurrent dropout.""" |
|
if shallow_filtered_substructure is None: |
|
# Put something so we traverse the entire structure; inside the |
|
# dropout function we check to see if leafs of this are bool or not. |
|
shallow_filtered_substructure = values |
|
if not self._variational_recurrent: |
|
      def dropout(i, do_dropout, v):
        if not isinstance(do_dropout, bool) or do_dropout:
          return nn_ops.dropout(
              v, keep_prob=keep_prob, seed=self._gen_seed(salt_prefix, i))
        else:
          return v
      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values])
    else:
      def dropout(i, do_dropout, v, n):
        if not isinstance(do_dropout, bool) or do_dropout:
          return self._variational_recurrent_dropout_value(i, v, n, keep_prob)
        else:
          return v
      return _enumerated_map_structure_up_to(
          shallow_filtered_substructure, dropout,
          *[shallow_filtered_substructure, values, recurrent_noise])
|
  def __call__(self, inputs, state, scope=None):
    """Run the cell with the declared dropouts."""
    def _should_dropout(p):
      return (not isinstance(p, float)) or p < 1

    if _should_dropout(self._input_keep_prob):
      inputs = self._dropout(inputs, "input",
                             self._recurrent_input_noise,
                             self._input_keep_prob)
    output, new_state = self._cell(inputs, state, scope)
    if _should_dropout(self._state_keep_prob):
      # Identify which subsets of the state to perform dropout on and
      # which ones to keep.
      shallow_filtered_substructure = nest.get_traverse_shallow_structure(
          self._dropout_state_filter, new_state)
      new_state = self._dropout(new_state, "state",
                                self._recurrent_state_noise,
                                self._state_keep_prob,
                                shallow_filtered_substructure)
    if _should_dropout(self._output_keep_prob):
      output = self._dropout(output, "output",
                             self._recurrent_output_noise,
                             self._output_keep_prob)
    return output, new_state
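# Example of wrapping a cell with variational (per-sequence) dropout (a
# sketch; the keep probabilities and sizes are illustrative only):
#
#   base = BasicLSTMCell(128)
#   cell = DropoutWrapper(base, input_keep_prob=0.9, output_keep_prob=0.9,
#                         state_keep_prob=0.9, variational_recurrent=True,
#                         input_size=32, dtype=dtypes.float32)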
|
class ResidualWrapper(RNNCell):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, cell, residual_fn=None):
    """Constructs a `ResidualWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      residual_fn: (Optional) The function to map raw cell inputs and raw cell
        outputs to the actual cell outputs of the residual network.
        Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs
        and outputs.
    """
    self._cell = cell
    self._residual_fn = residual_fn

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)
|
def __call__(self, inputs, state, scope=None): |
|
"""Run the cell and then apply the residual_fn on its inputs to its outputs. |
|
scope: optional cell scope. |
|
Tuple of cell outputs and new state. |
|
TypeError: If cell inputs and outputs have different structure (type). |
|
ValueError: If cell inputs and outputs have different structure (value). |
|
outputs, new_state = self._cell(inputs, state, scope=scope) |
|
def assert_shape_match(inp, out): |
|
inp.get_shape().assert_is_compatible_with(out.get_shape()) |
|
def default_residual_fn(inputs, outputs): |
|
nest.assert_same_structure(inputs, outputs) |
|
nest.map_structure(assert_shape_match, inputs, outputs) |
|
return nest.map_structure(lambda inp, out: inp + out, inputs, outputs) |
|
res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs) |
|
return (res_outputs, new_state) |
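# Example residual stack (a sketch): the wrapped cell's output_size must match
# the input depth so that `inputs + outputs` is well defined.
#
#   cell = ResidualWrapper(GRUCell(num_units=32))   # inputs must have depth 32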
|
class DeviceWrapper(RNNCell): |
|
"""Operator that ensures an RNNCell runs on a particular device.""" |
|
def __init__(self, cell, device): |
|
"""Construct a `DeviceWrapper` for `cell` with device `device`. |
|
Ensures the wrapped `cell` is called with `tf.device(device)`. |
|
cell: An instance of `RNNCell`. |
|
device: A device string or function, for passing to `tf.device`. |
|
return self._cell.state_size |
|
return self._cell.output_size |
|
def zero_state(self, batch_size, dtype): |
|
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): |
|
with ops.device(self._device): |
|
return self._cell.zero_state(batch_size, dtype) |
|
def __call__(self, inputs, state, scope=None): |
|
"""Run the cell on specified device.""" |
|
with ops.device(self._device): |
|
return self._cell(inputs, state, scope=scope) |
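# Example pinning a cell's ops to a device (a sketch; the device string is
# illustrative):
#
#   cell = DeviceWrapper(GRUCell(64), "/device:GPU:0")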
|
class MultiRNNCell(RNNCell):
  """RNN cell composed sequentially of multiple simple cells."""

  def __init__(self, cells, state_is_tuple=True):
    """Create a RNN cell composed sequentially of a number of RNNCells.

    Args:
      cells: list of RNNCells that will be composed in this order.
      state_is_tuple: If True, accepted and returned states are n-tuples, where
        `n = len(cells)`.  If False, the states are all
        concatenated along the column axis.  This latter behavior will soon be
        deprecated.

    Raises:
      ValueError: if cells is empty (not allowed), or at least one of the cells
        returns a state tuple but the flag `state_is_tuple` is `False`.
    """
    super(MultiRNNCell, self).__init__()
    if not cells:
      raise ValueError("Must specify at least one cell for MultiRNNCell.")
    if not nest.is_sequence(cells):
      raise TypeError(
          "cells must be a list or tuple, but saw: %s." % cells)

    self._cells = cells
    self._state_is_tuple = state_is_tuple
    if not state_is_tuple:
      if any(nest.is_sequence(c.state_size) for c in self._cells):
        raise ValueError("Some cells return tuples of states, but the flag "
                         "state_is_tuple is not set.  State sizes are: %s"
                         % str([c.state_size for c in self._cells]))
|
  @property
  def state_size(self):
    if self._state_is_tuple:
      return tuple(cell.state_size for cell in self._cells)
    else:
      return sum([cell.state_size for cell in self._cells])

  @property
  def output_size(self):
    return self._cells[-1].output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      if self._state_is_tuple:
        return tuple(cell.zero_state(batch_size, dtype) for cell in self._cells)
      else:
        # We know here that state_size of each cell is not a tuple and
        # presumably does not contain TensorArrays or anything else fancy
        return super(MultiRNNCell, self).zero_state(batch_size, dtype)
|
  def call(self, inputs, state):
    """Run this multi-layer cell on inputs, starting from state."""
    cur_state_pos = 0
    cur_inp = inputs
    new_states = []
    for i, cell in enumerate(self._cells):
      with vs.variable_scope("cell_%d" % i):
        if self._state_is_tuple:
          if not nest.is_sequence(state):
            raise ValueError(
                "Expected state to be a tuple of length %d, but received: %s" %
                (len(self.state_size), state))
          cur_state = state[i]
        else:
          cur_state = array_ops.slice(state, [0, cur_state_pos],
                                      [-1, cell.state_size])
          cur_state_pos += cell.state_size
        cur_inp, new_state = cell(cur_inp, cur_state)
        new_states.append(new_state)

    new_states = (tuple(new_states) if self._state_is_tuple else
                  array_ops.concat(new_states, 1))

    return cur_inp, new_states
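# With state_is_tuple=True (the default), the state of a two-layer LSTM stack
# is a 2-tuple of LSTMStateTuples, one per layer (a sketch; sizes are
# illustrative):
#
#   stack = MultiRNNCell([BasicLSTMCell(64), BasicLSTMCell(64)])
#   init = stack.zero_state(batch_size=8, dtype=dtypes.float32)
#   # init == (LSTMStateTuple(c, h), LSTMStateTuple(c, h)), each [8, 64]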
|
class _SlimRNNCell(RNNCell): |
|
"""A simple wrapper for slim.rnn_cells.""" |
|
def __init__(self, cell_fn): |
|
"""Create a SlimRNNCell from a cell_fn. |
|
cell_fn: a function which takes (inputs, state, scope) and produces the |
|
outputs and the new_state. Additionally when called with inputs=None and |
|
state=None it should return (initial_outputs, initial_state). |
|
TypeError: if cell_fn is not callable |
|
ValueError: if cell_fn cannot produce a valid initial state. |
|
if not callable(cell_fn): |
|
raise TypeError("cell_fn %s needs to be callable", cell_fn) |
|
self._cell_name = cell_fn.func.__name__ |
|
init_output, init_state = self._cell_fn(None, None) |
|
output_shape = init_output.get_shape() |
|
state_shape = init_state.get_shape() |
|
self._output_size = output_shape.with_rank(2)[1].value |
|
self._state_size = state_shape.with_rank(2)[1].value |
|
if self._output_size is None: |
|
raise ValueError("Initial output created by %s has invalid shape %s" % |
|
(self._cell_name, output_shape)) |
|
if self._state_size is None: |
|
raise ValueError("Initial state created by %s has invalid shape %s" % |
|
                       (self._cell_name, state_shape))

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def __call__(self, inputs, state, scope=None):
    scope = scope or self._cell_name
    output, state = self._cell_fn(inputs, state, scope=scope)
    return output, state
|
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. |
|
args: a 2D Tensor or a list of 2D, batch x n, Tensors. |
|
output_size: int, second dimension of weight variable. |
|
dtype: data type for variables. |
|
build_bias: boolean, whether to build a bias variable. |
|
bias_initializer: starting value to initialize the bias |
|
kernel_initializer: starting value to initialize the weight. |
|
ValueError: if inputs_shape is wrong. |
|
kernel_initializer=None): |
|
self._build_bias = build_bias |
|
    if args is None or (nest.is_sequence(args) and not args):
      raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
      args = [args]
      self._is_sequence = False
    else:
      self._is_sequence = True

    # Calculate the total size of arguments on dimension 1.
    total_arg_size = 0
    shapes = [a.get_shape() for a in args]
    for shape in shapes:
      if shape.ndims != 2:
        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
      if shape[1].value is None:
        raise ValueError("linear expects shape[1] to be provided for shape %s, "
                         "but saw %s" % (shape, shape[1]))
      else:
        total_arg_size += shape[1].value

    dtype = [a.dtype for a in args][0]

    scope = vs.get_variable_scope()
    with vs.variable_scope(scope) as outer_scope:
      self._weights = vs.get_variable(
          _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
          dtype=dtype,
          initializer=kernel_initializer)
      if build_bias:
        with vs.variable_scope(outer_scope) as inner_scope:
          inner_scope.set_partitioner(None)
          if bias_initializer is None:
            bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
          self._biases = vs.get_variable(
              _BIAS_VARIABLE_NAME, [output_size],
              dtype=dtype,
              initializer=bias_initializer)
|
  def __call__(self, args):
    if not self._is_sequence:
      args = [args]

    if len(args) == 1:
      res = math_ops.matmul(args[0], self._weights)
    else:
      res = math_ops.matmul(array_ops.concat(args, 1), self._weights)
    if self._build_bias:
      res = nn_ops.bias_add(res, self._biases)
    return res
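# _Linear is the shared helper the cells above use for their gate
# computations: it concatenates its inputs along axis 1 and applies a single
# kernel plus an optional bias. Rough equivalent (a sketch in pseudocode):
#
#   y = matmul(concat([x, h], 1), kernel) + bias
#   # kernel has shape [x_depth + h_depth, output_size]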
|
# TODO(xpan): Remove this function in a follow up.
def _linear(args,
            output_size,
            bias,
            bias_initializer=None,
            kernel_initializer=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
    output_size: int, second dimension of W[i].
    bias: boolean, whether to add a bias term or not.
    bias_initializer: starting value to initialize the bias
      (default is all zeros).
    kernel_initializer: starting value to initialize the weight.

  Returns:
    A 2D Tensor with shape [batch x output_size] equal to
    sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.
  """
|
  if args is None or (nest.is_sequence(args) and not args):
    raise ValueError("`args` must be specified")
  if not nest.is_sequence(args):
    args = [args]

  # Calculate the total size of arguments on dimension 1.
  total_arg_size = 0
  shapes = [a.get_shape() for a in args]
  for shape in shapes:
    if shape.ndims != 2:
      raise ValueError("linear is expecting 2D arguments: %s" % shapes)
    if shape[1].value is None:
      raise ValueError("linear expects shape[1] to be provided for shape %s, "
                       "but saw %s" % (shape, shape[1]))
    else:
      total_arg_size += shape[1].value

  dtype = [a.dtype for a in args][0]

  # Now the computation.
  scope = vs.get_variable_scope()
  with vs.variable_scope(scope) as outer_scope:
    weights = vs.get_variable(
        _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
        dtype=dtype,
        initializer=kernel_initializer)
    if len(args) == 1:
      res = math_ops.matmul(args[0], weights)
    else:
      res = math_ops.matmul(array_ops.concat(args, 1), weights)
    if not bias:
      return res
    with vs.variable_scope(outer_scope) as inner_scope:
      inner_scope.set_partitioner(None)
      if bias_initializer is None:
        bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
      biases = vs.get_variable(
          _BIAS_VARIABLE_NAME, [output_size],
          dtype=dtype,
          initializer=bias_initializer)
    return nn_ops.bias_add(res, biases)