BasicLSTMCell Source Code Analysis

class BasicLSTMCell(RNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full LSTMCell that follows.
  """

  def __init__(self, num_units, forget_bias=1.0, input_size=None,
               state_is_tuple=True, activation=tanh):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.
    """
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units
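
As a quick illustration of these two properties (and of driving the cell for one step), here is a minimal usage sketch; the TF 1.x namespace tf.nn.rnn_cell and the concrete shapes are assumptions for illustration:

'''
import tensorflow as tf

num_units, batch_size, input_size = 64, 32, 128
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
print(cell.state_size)    # LSTMStateTuple(c=64, h=64)
print(cell.output_size)   # 64

inputs = tf.placeholder(tf.float32, [batch_size, input_size])
state = cell.zero_state(batch_size, tf.float32)   # (c, h), each (32, 64)
output, new_state = cell(inputs, state)           # output shape: (32, 64)
'''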

  # Forward pass only: error-gradient back-propagation and weight updates are
  # not handled here.
  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM)."""
    with vs.variable_scope(scope or "basic_lstm_cell"):
      # Parameters of gates are concatenated into one multiply for efficiency.
      if self._state_is_tuple:
        # c and h are the previous step's cell state C_{t-1} and hidden output
        # H_{t-1}; by default each has the hidden-layer dimension num_units.
        c, h = state
      else:
        c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
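
For the non-tuple path, the split simply cuts the concatenated (batch, 2*num_units) state back into c and h, e.g. (a NumPy analogue with assumed shapes):

'''
import numpy as np

state = np.zeros((32, 128))        # batch=32, 2*num_units with num_units=64
c, h = np.split(state, 2, axis=1)  # each piece has shape (32, 64)
'''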

      # W and b below hold the weights and biases of all four gates of the
      # LSTM cell in one fused parameter matrix and one bias vector.
      # Linear computation: concat = [inputs, h] W + b
      # W has shape (input_size + num_units, 4*num_units) and b has shape
      # (4*num_units,), i.e. four sets of gate parameters fused together.
      # concat has shape (batch_size, 4*num_units).
      # Note: _linear concatenates inputs and h along axis 1, so the single
      # fused W works even when the input size differs from num_units; when
      # they are equal, W's shape is (2*num_units, 4*num_units).


      

'''
_linear function definition:

def _linear(args,
            output_size,
            bias,
            bias_initializer=None,
            kernel_initializer=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  Args:
    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
    output_size: int, second dimension of W[i].
    bias: boolean, whether to add a bias term or not.
    bias_initializer: starting value to initialize the bias
      (default is all zeros).
    kernel_initializer: starting value to initialize the weight.
  """
'''
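
To make the shape bookkeeping concrete, here is a simplified NumPy sketch of what _linear computes (an illustrative re-implementation, not the TF source; zero initialization stands in for tf.get_variable):

'''
import numpy as np

def linear_sketch(args, output_size, bias=True):
    # Concatenate all inputs along axis 1, then one fused matmul:
    # [inputs, h] W + b, with W: (sum of arg widths, output_size).
    x = np.concatenate(args, axis=1)
    W = np.zeros((x.shape[1], output_size))   # tf.get_variable in the real code
    b = np.zeros(output_size)
    return x @ W + (b if bias else 0.0)

batch, input_size, num_units = 32, 128, 64
inputs = np.zeros((batch, input_size))
h = np.zeros((batch, num_units))
concat = linear_sketch([inputs, h], 4 * num_units)
print(concat.shape)   # (32, 256) == (batch_size, 4*num_units)
'''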



      # 4 * self._num_units is the output dimension: the linear pre-activations
      # of all four gates concatenated, each gate contributing self._num_units.
      concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      # Split concat along axis 1 into the four gates' linear outputs, then
      # apply the gate logic to the previous state (c, h) and the current
      # input to compute the new cell state c and the new hidden output h.
      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

      new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
               self._activation(j))
      new_h = self._activation(new_c) * sigmoid(o)

      if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
      else:
        new_state = array_ops.concat([new_c, new_h], 1)
      return new_h, new_state
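
In equation form, the update implemented above is (with $\odot$ denoting element-wise multiplication and $\sigma$ the sigmoid):

\begin{aligned}
[i,\; j,\; f,\; o] &= [x_t,\, h_{t-1}]\, W + b \\
c_t &= c_{t-1} \odot \sigma(f + \text{forget\_bias}) + \sigma(i) \odot \tanh(j) \\
h_t &= \tanh(c_t) \odot \sigma(o)
\end{aligned}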
# In BasicLSTMCell, c, h, and (under the assumption noted above) the input all
# share the same dimension, so the weights and biases of every gate in the LSTM
# cell can be stored in one large W matrix and one bias vector b.
# The four gate outputs i, j, f, o are each tensors of dimension num_units; each
# sigmoid output is likewise a num_units tensor, which is multiplied
# element-wise with c -- the operation is applied to each component of c, so the
# result keeps the same dimension as c.

'''
import tensorflow as tf

a = tf.constant([[1, 2, 3], [1, 2, 3]])
b = tf.constant([[2, 3, 4]])
print(a.get_shape())   # (2, 3)
print(b.get_shape())   # (1, 3)
c = a * b              # element-wise multiplication with broadcasting
c = tf.Print(c, [c])   # also log c when the graph runs
with tf.Session() as sess:
    print(sess.run(c))
'''

The final value of c is

[[ 2  6 12]
 [ 2  6 12]]

Each value of b acts on the corresponding component of a, i.e. an element-wise multiplication broadcast across the rows of a.
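
The same broadcasting behaviour can be checked without a session using NumPy (an equivalent sketch for illustration):

'''
import numpy as np

a = np.array([[1, 2, 3], [1, 2, 3]])
b = np.array([[2, 3, 4]])
print(a * b)   # [[ 2  6 12]
               #  [ 2  6 12]]
'''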

