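TensorFlow rnn_cell API reading notes
GitHub link
Overview
This rnn_cell API comes from rnn_cell_impl.py. The official site covers it in its "RNN and Cells" introduction.
These notes cover how rnn_cells are constructed.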
"""Module implementing RNN Cells.
This module provides a number of basic commonly used RNN cells, such as LSTM
(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
operators that allow adding dropouts, projections, or embeddings for inputs.
Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
calling the `rnn` ops several times.
"""
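RNNCell is the parent class of all cells. LayerRNNCell exists to make parameter (variable) management easier. The remaining cells, wrappers and related structures are each implemented according to a specific paper or application scenario.
Inheritance hierarchy of rnn_cell: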
RNNCell(tf.layers.Layer)
LayerRNNCell(RNNCell)
BasicRNNCell(LayerRNNCell)
BasicLSTMCell(LayerRNNCell)
LSTMCell(LayerRNNCell)
GRUCell(LayerRNNCell)
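To make the hierarchy concrete, here is a pure-Python toy sketch (no TensorFlow needed) of the contract RNNCell imposes and how a BasicRNNCell-style subclass fulfils it. The name ToyBasicRNNCell, the seeded random initialisation and the list-of-lists "tensors" are hypothetical. The update rule output = new_state = tanh([x, h]·W + b) follows the basic RNN cell, but variable creation, dtypes and the Layer machinery of the real classes are left out.

```python
import math
import random


class RNNCell:
    """Toy stand-in for the abstract base class (not TensorFlow's RNNCell)."""

    @property
    def state_size(self):
        raise NotImplementedError("subclasses must define state_size")

    @property
    def output_size(self):
        raise NotImplementedError("subclasses must define output_size")

    def call(self, inputs, state):
        raise NotImplementedError("subclasses must implement call")

    def __call__(self, inputs, state):
        # The contract: (output, next_state) = call(input, state)
        return self.call(inputs, state)


class ToyBasicRNNCell(RNNCell):
    """Hypothetical BasicRNNCell-like cell: output = new_state = tanh([x, h] @ W + b)."""

    def __init__(self, num_units, input_size, seed=0):
        self._num_units = num_units
        rng = random.Random(seed)
        # Kernel shape: [input_size + num_units, num_units]
        self._kernel = [[rng.uniform(-0.5, 0.5) for _ in range(num_units)]
                        for _ in range(input_size + num_units)]
        self._bias = [0.0] * num_units

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def call(self, inputs, state):
        new_state = []
        for x_row, h_row in zip(inputs, state):
            xh = x_row + h_row  # list concatenation, like concat([inputs, state], axis=1)
            new_state.append([
                math.tanh(sum(v * self._kernel[i][j] for i, v in enumerate(xh)) + self._bias[j])
                for j in range(self._num_units)
            ])
        # As in a basic RNN cell, the output and the new state are the same matrix.
        return new_state, new_state


cell = ToyBasicRNNCell(num_units=4, input_size=3)
x = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]           # [batch_size=2, input_size=3]
h0 = [[0.0] * cell.state_size for _ in range(2)]  # [batch_size=2, state_size=4]
output, h1 = cell(x, h0)
```

Subclasses only implement call plus the two size properties. The base class's __call__ is what callers use, which mirrors how the real RNNCell routes __call__ to call.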
Class that stores the state of an rnn_cell:
LSTMStateTuple
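In TF 1.x, LSTMStateTuple is built on collections.namedtuple with the fields c (cell state) and h (hidden state / output). The real class also exposes a dtype property, which is omitted here. The plain-Python sketch below uses lists as stand-ins for tensors. It shows how a BasicLSTMCell with state_is_tuple=True describes its state_size as a pair, and what a zero state with that structure looks like. With state_is_tuple=False, BasicLSTMCell instead uses a single state of size 2 * num_units.

```python
import collections

# Simplified re-creation: TF 1.x builds LSTMStateTuple on a namedtuple with fields (c, h).
LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))

num_units, batch_size = 4, 2

# BasicLSTMCell(state_is_tuple=True) reports state_size as LSTMStateTuple(num_units, num_units).
state_size = LSTMStateTuple(c=num_units, h=num_units)

# A zero initial state with the same structure: c and h are each [batch_size, num_units].
zero_state = LSTMStateTuple(
    c=[[0.0] * state_size.c for _ in range(batch_size)],
    h=[[0.0] * state_size.h for _ in range(batch_size)],
)

c, h = zero_state  # unpacks like an ordinary tuple; fields are also reachable by name
```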
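rnn_cell wrappers hold (wrap) an rnn_cell. Every wrapper inherits from RNNCell. A wrapper operates on the cell's input, its output and its recurrent state loop, which is the same recurrent process an rnn_cell itself goes through: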
MultiRNNCell
DropoutWrapper
DeviceWrapper
ResidualWrapper
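The wrapper idea can be sketched in plain Python. A wrapper exposes the same (output, next_state) = cell(inputs, state) interface as the cell it holds, and only intervenes on inputs, outputs or state. The classes below are hypothetical simplifications of two entries in the list:

- ToyMultiRNNCell stacks cells, so the state is a tuple with one entry per layer.
- ToyResidualWrapper adds the input to the cell's output and leaves the state untouched.

DropoutWrapper and DeviceWrapper follow the same pattern (applying dropout or device placement) and are not shown. ToyCell is an invented elementwise cell used only to keep the example small.

```python
import math


class ToyCell:
    """Hypothetical elementwise cell: new_h = tanh(x + h); needs input_size == num_units."""

    def __init__(self, num_units):
        self.state_size = num_units
        self.output_size = num_units

    def __call__(self, inputs, state):
        new_h = [[math.tanh(x + h) for x, h in zip(x_row, h_row)]
                 for x_row, h_row in zip(inputs, state)]
        return new_h, new_h


class ToyResidualWrapper:
    """ResidualWrapper idea: output = cell_output + inputs; the state passes through."""

    def __init__(self, cell):
        self._cell = cell
        self.state_size = cell.state_size
        self.output_size = cell.output_size

    def __call__(self, inputs, state):
        output, new_state = self._cell(inputs, state)
        output = [[o + x for o, x in zip(o_row, x_row)]
                  for o_row, x_row in zip(output, inputs)]
        return output, new_state


class ToyMultiRNNCell:
    """MultiRNNCell idea: stacked layers; the state is a tuple with one entry per layer."""

    def __init__(self, cells):
        self._cells = list(cells)
        self.state_size = tuple(c.state_size for c in self._cells)
        self.output_size = self._cells[-1].output_size

    def __call__(self, inputs, state):
        cur_inp, new_states = inputs, []
        for cell, cell_state in zip(self._cells, state):
            cur_inp, new_state = cell(cur_inp, cell_state)  # layer i feeds layer i + 1
            new_states.append(new_state)
        return cur_inp, tuple(new_states)


stack = ToyMultiRNNCell([ToyCell(2), ToyResidualWrapper(ToyCell(2))])
x = [[0.5, -0.5]]                                          # [batch_size=1, input_size=2]
zero = tuple([[0.0] * size] for size in stack.state_size)  # one zero state per layer
out, states = stack(x, zero)
```

Because every wrapper exposes the same call interface, wrappers and stacks can be nested freely. In the example, a residual-wrapped cell sits inside a multi-layer cell.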
Common rnn_cells
RNNCell
class RNNCell(base_layer.Layer):
"""Abstract object representing an RNN cell.
Every `RNNCell` must have the properties below and implement `call` with
the signature `(output, next_state) = call(input, state)`. The optional
third input argument, `scope`, is allowed for backwards compatibility
purposes; but should be left off for new subclasses.
This definition of cell differs from the definition used in the literature.
In the literature, 'cell' refers to an object with a single scalar output.
This definition refers to a horizontal array of such units.
An RNN cell, in the most abstract setting, is anything that has
a state and performs some operation that takes a matrix of inputs.
This operation results in an output matrix with `self.output_size` columns.
If `self.state_size` is an integer, this operation also results in a new
state matrix with `self.state_size` columns. If `self.state_size` is a
(possibly nested tuple of) TensorShape object(s), then it should return a
matching structure of Tensors having shape `[batch_size].concatenate(s)`
for each `s` in `self.batch_size`.
"""
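RNNCell is the parent class of all RNN cells.
Every RNNCell must have the properties below (state_size and output_size) and implement call with the signature (output, next_state) = call(input, state). The optional third input argument, scope, is allowed only for backward compatibility and should be left off in new subclasses. When it is given, scope is a VariableScope (not a tf.Variable) that groups the variables the cell creates, which makes them easier to manage.
The definition of "cell" in this code differs from the one used in the literature. In the literature, a cell is an object with a single scalar output. Here, a cell means a horizontal array (a whole layer) of such units. This is why concrete cells such as BasicRNNCell have a _num_units attribute giving the number of units in that layer, instead of representing a single neuron.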
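... If self.state_size is an integer, this operation also produces a new state matrix with self.state_size columns. If self.state_size is a TensorShape object (or a possibly nested tuple of them), the cell should return a matching structure of tensors. Each tensor has shape [batch_size].concatenate(s) for the corresponding s in self.state_size; the docstring's "self.batch_size" at this point is a typo for self.state_size.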
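The shape rule can be illustrated without TensorFlow. In this sketch, Shape is a hypothetical stand-in for tf.TensorShape that mimics only its concatenate method. The hypothetical helper state_shapes walks a possibly nested state_size and preserves tuple and namedtuple structure:

- For an integer s, it returns [batch_size, s].
- For a shape s, it returns [batch_size] followed by s.

```python
import collections


class Shape(tuple):
    """Hypothetical stand-in for tf.TensorShape: an immutable tuple of dimensions."""

    def concatenate(self, other):
        return Shape(tuple(self) + tuple(other))


def state_shapes(state_size, batch_size):
    """Return the required state shape(s) for a possibly nested state_size."""
    if isinstance(state_size, int):
        # Integer state_size: one 2-D matrix with state_size columns.
        return Shape((batch_size, state_size))
    if isinstance(state_size, Shape):
        # Shape s: the state tensor has shape [batch_size].concatenate(s).
        return Shape((batch_size,)).concatenate(state_size)
    # Otherwise a (nested) tuple: recurse and keep the same structure.
    mapped = [state_shapes(s, batch_size) for s in state_size]
    if hasattr(state_size, "_fields"):  # namedtuple, e.g. LSTMStateTuple
        return type(state_size)(*mapped)
    return tuple(mapped)


LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
nested_size = (LSTMStateTuple(c=4, h=4), 8, Shape((3, 5)))
shapes = state_shapes(nested_size, batch_size=2)
```

Here shapes equals (LSTMStateTuple(c=(2, 4), h=(2, 4)), (2, 8), (2, 3, 5)). The structure of the returned states matches the structure of state_size exactly, which is what the docstring requires.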
def __call__(self, inputs, state, scope=None):
"""Run this RNN cell on inputs, starting from the given state.
Args:
inputs: `2-D` tensor with shape `[batch_size, input_size]`.
state: if `self.state_size` is an integer, this should be a `2-D Tensor`
with shape `[batch_size, self.state_size]`. Otherwise, if
`self.state_size` is a tuple of integers, this should be a tuple
with shapes `[batch_size, s] for s in self.state_size`.
scope: VariableScope for the created subgraph; defaults to class name.
Returns:
A pair containing:
- Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
- New state: Either a single `2-D` tensor, or a tuple of tensors matching
the arity and shapes of `state`.
"""
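The Args above describe a single time step. The sketch below calls a cell repeatedly, feeding each step a [batch_size, input_size] matrix and threading the state, and checks the documented shapes at every step. The names ToyCell, unroll and check_shape are hypothetical, and the code is plain Python. It is roughly what an RNN loop such as tf.nn.static_rnn does, greatly simplified: there are no sequence lengths, variable scopes or real tensors.

```python
import math


def check_shape(matrix, rows, cols):
    """Assert that a list-of-lists matrix has shape [rows, cols]."""
    assert len(matrix) == rows and all(len(r) == cols for r in matrix), "bad shape"


class ToyCell:
    """Hypothetical cell with an integer state_size: new_h = tanh(mean(x) + h)."""

    def __init__(self, num_units):
        self.state_size = num_units
        self.output_size = num_units

    def __call__(self, inputs, state):
        new_h = [[math.tanh(sum(x_row) / len(x_row) + h) for h in h_row]
                 for x_row, h_row in zip(inputs, state)]
        return new_h, new_h


def unroll(cell, input_seq, initial_state):
    """Call the cell once per time step, threading the state through."""
    batch_size = len(initial_state)
    state = initial_state
    outputs = []
    for inputs in input_seq:                                # each step: [batch_size, input_size]
        check_shape(state, batch_size, cell.state_size)     # [batch_size, state_size]
        output, state = cell(inputs, state)
        check_shape(output, batch_size, cell.output_size)   # [batch_size, output_size]
        outputs.append(output)
    return outputs, state


cell = ToyCell(num_units=3)
seq = [[[1.0, 0.0]], [[0.0, 1.0]], [[0.5, 0.5]]]  # 3 steps, batch_size=1, input_size=2
h0 = [[0.0] * cell.state_size]
outputs, final_state = unroll(cell, seq, h0)
```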
Runs this RNN cell on the inputs, starting from the given state.
Args:
inputs: a 2-D tensor with shape [batch_size, input_size].
state: if self.state_size is an integer, this should be a 2-D tensor with shape [batch_size, self.state_size]; otherwise, if s