TensorFlow rnn_cell API reading notes

GitHub repository

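Overview

The rnn_cell API described here comes from rnn_cell_impl.py; see also the "RNN and Cells" introduction on the official site.

These notes cover how rnn_cell is constructed.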

"""Module implementing RNN Cells.

This module provides a number of basic commonly used RNN cells, such as LSTM
(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
operators that allow adding dropouts, projections, or embeddings for inputs.
Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
calling the `rnn` ops several times.
"""

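RNNCell is the parent class of all cells. LayerRNNCell exists to make parameter management easier. The other cells, wrappers, and related structures are implemented according to specific papers and application scenarios.

Inheritance hierarchy of rnn_cell: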

RNNCell(tf.layers.Layer)

LayerRNNCell(RNNCell)

BasicRNNCell(LayerRNNCell)
BasicLSTMCell(LayerRNNCell)
LSTMCell(LayerRNNCell)
GRUCell(LayerRNNCell)

Class that stores the rnn_cell state:

LSTMStateTuple
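
As a rough picture of this state container, it can be thought of as a named tuple with two fields, c (the cell state) and h (the hidden output). The sketch below uses only the Python standard library. The name StateTupleSketch and the nested lists standing in for tensors are illustrative assumptions, not the TensorFlow implementation.

```python
import collections

# Hypothetical stand-in for LSTMStateTuple: a two-field named tuple (c, h).
StateTupleSketch = collections.namedtuple("StateTupleSketch", ("c", "h"))

# Plain nested lists stand in for [batch_size, num_units] tensors.
batch_size, num_units = 2, 3
zeros = [[0.0] * num_units for _ in range(batch_size)]
state = StateTupleSketch(c=zeros, h=[row[:] for row in zeros])

# Fields can be read by name or unpacked positionally.
c, h = state
```

Keeping c and h as separate named fields means code that consumes the state does not need to slice one concatenated tensor.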

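rnn_cell wrappers wrap an RNNCell. Every wrapper inherits from RNNCell, because a wrapper works by acting on the cell's input, output, and recurrent (self-loop) flow, which is the same recurrent process that rnn_cell itself follows: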

MultiRNNCell
DropoutWrapper
DeviceWrapper
ResidualWrapper
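
The wrapper idea can be sketched without TensorFlow: a wrapper exposes the same interface as the cell it wraps and only adjusts what flows in or out. The example below imitates the residual case, where the inputs are added to the wrapped cell's outputs. ToyCell and ResidualWrapperSketch are hypothetical names, and nested lists stand in for tensors.

```python
class ToyCell:
    """Hypothetical cell: output and next state are both input + state."""

    def __init__(self, num_units):
        self.state_size = num_units
        self.output_size = num_units

    def __call__(self, inputs, state):
        new_state = [[x + s for x, s in zip(x_row, s_row)]
                     for x_row, s_row in zip(inputs, state)]
        return new_state, new_state


class ResidualWrapperSketch:
    """Wraps a cell and adds the inputs to the wrapped cell's outputs."""

    def __init__(self, cell):
        self._cell = cell

    @property
    def state_size(self):
        # Delegate to the wrapped cell so the wrapper looks like a cell.
        return self._cell.state_size

    @property
    def output_size(self):
        return self._cell.output_size

    def __call__(self, inputs, state):
        outputs, new_state = self._cell(inputs, state)
        residual = [[x + o for x, o in zip(x_row, o_row)]
                    for x_row, o_row in zip(inputs, outputs)]
        # The state passes through untouched; only the output changes.
        return residual, new_state


cell = ResidualWrapperSketch(ToyCell(num_units=2))
outputs, next_state = cell([[1.0, 2.0]], [[0.5, 0.5]])
```

Because the wrapper has the same call signature and size properties as a cell, wrappers can be stacked on top of each other or passed anywhere a cell is expected.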

Commonly used rnn_cell classes

RNNCell

class RNNCell(base_layer.Layer):
  """Abstract object representing an RNN cell.

  Every `RNNCell` must have the properties below and implement `call` with
  the signature `(output, next_state) = call(input, state)`.  The optional
  third input argument, `scope`, is allowed for backwards compatibility
  purposes; but should be left off for new subclasses.

  This definition of cell differs from the definition used in the literature.
  In the literature, 'cell' refers to an object with a single scalar output.
  This definition refers to a horizontal array of such units.

  An RNN cell, in the most abstract setting, is anything that has
  a state and performs some operation that takes a matrix of inputs.
  This operation results in an output matrix with `self.output_size` columns.
  If `self.state_size` is an integer, this operation also results in a new
  state matrix with `self.state_size` columns.  If `self.state_size` is a
  (possibly nested tuple of) TensorShape object(s), then it should return a
  matching structure of Tensors having shape `[batch_size].concatenate(s)`
  for each `s` in `self.batch_size`.
  """

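RNNCell is the parent class of all RNN cells.

Every RNNCell must have the properties below and implement a function with the signature (output, next_state) = call(input, state). The optional third argument, scope, is accepted only for backwards compatibility and should be left off in new subclasses. The value passed as scope is a VariableScope (not a tf.Variable); it is used to manage the variables created in the cell's subgraph.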
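
To make the call contract concrete, here is a minimal pure-Python sketch of a cell that exposes state_size and output_size and returns (output, next_state). It computes tanh(x·W + h·U + b), the usual basic recurrent update. The class name TinyRNNCellSketch, the fixed weight values, and the nested-list "tensors" are assumptions made for illustration and are not taken from rnn_cell_impl.py.

```python
import math


class TinyRNNCellSketch:
    """Hypothetical cell computing tanh(x @ W + h @ U + b) on nested lists."""

    def __init__(self, input_size, num_units):
        self._num_units = num_units
        # Fixed illustrative weights; a real cell would create trainable variables.
        self._w = [[0.1] * num_units for _ in range(input_size)]
        self._u = [[0.2] * num_units for _ in range(num_units)]
        self._b = [0.0] * num_units

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def call(self, inputs, state):
        outputs = []
        for x_row, h_row in zip(inputs, state):
            row = []
            for j in range(self._num_units):
                pre = self._b[j]
                pre += sum(x * self._w[i][j] for i, x in enumerate(x_row))
                pre += sum(h * self._u[k][j] for k, h in enumerate(h_row))
                row.append(math.tanh(pre))
            outputs.append(row)
        # For this simple cell the output doubles as the next state.
        return outputs, outputs


cell = TinyRNNCellSketch(input_size=3, num_units=2)
inputs = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]  # [batch_size=2, input_size=3]
state = [[0.0, 0.0], [0.0, 0.0]]              # [batch_size=2, state_size=2]
output, next_state = cell.call(inputs, state)
```

The output has one row per batch element and output_size columns, and next_state can be fed straight back in as the state for the following time step.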

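The definition of cell in this code differs from the one used in the literature. In the literature, a cell has a single scalar output, whereas the cell defined here is a horizontal array (a row) of such units. (The concrete cells carry a _num_units attribute for the width of that row, unlike a single fully connected unit.)

... If self.state_size is an integer, the operation also produces a new state matrix with self.state_size columns. If self.state_size is a (possibly nested tuple of) TensorShape object(s), the cell should return a matching structure of tensors, each with shape [batch_size].concatenate(s), for each s in self.state_size (the docstring reads self.batch_size, which appears to be a typo).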
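
The shape rule can be illustrated with a small helper that maps a state_size to the shape or shapes the state should have. This is only a sketch: expected_state_shapes is a hypothetical function, and plain lists stand in for TensorShape objects.

```python
def expected_state_shapes(batch_size, state_size):
    """Return shape(s) mirroring [batch_size].concatenate(s) for each s."""
    if isinstance(state_size, int):
        # Integer state_size: one 2-D state of shape [batch_size, state_size].
        return [batch_size, state_size]
    if isinstance(state_size, tuple):
        # Recurse so that nested tuples keep their structure.
        return tuple(expected_state_shapes(batch_size, s) for s in state_size)
    # Treat a list as a stand-in for a TensorShape such as [3, 4].
    return [batch_size] + list(state_size)


single = expected_state_shapes(32, 128)
paired = expected_state_shapes(32, (128, 128))
shaped = expected_state_shapes(32, ([3, 4], 5))
```

Here single is one shape, while paired and shaped are tuples of shapes with the same structure as the state_size they were built from.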

  def __call__(self, inputs, state, scope=None):
    """Run this RNN cell on inputs, starting from the given state.

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: if `self.state_size` is an integer, this should be a `2-D Tensor`
        with shape `[batch_size, self.state_size]`.  Otherwise, if
        `self.state_size` is a tuple of integers, this should be a tuple
        with shapes `[batch_size, s] for s in self.state_size`.
      scope: VariableScope for the created subgraph; defaults to class name.

    Returns:
      A pair containing:

      - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
      - New state: Either a single `2-D` tensor, or a tuple of tensors matching
        the arity and shapes of `state`.
    """

Runs this RNN cell on the inputs, starting from the given state.

Args:

inputs: a 2-D tensor with shape [batch_size, input_size].
state: if self.state_size is an integer, state should be a 2-D tensor with shape [batch_size, self.state_size]; otherwise, if s
