[TensorFlow笔记] BasicRNNCell源码浅析

BasicRNNCell是抽象类RNNCell的一个最简单的实现。

class BasicRNNCell(RNNCell):

    def __init__(self, num_units, activation=None, reuse=None):
        super(BasicRNNCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation or math_ops.tanh
        self._linear = None

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def call(self, inputs, state):
        if self._linear is None:
            self._linear = _Linear([inputs, state], self._num_units, True)

        output = self._activation(self._linear([inputs, state]))
        return output, output

实现了下面的运算:

这里写图片描述

用公式表示就是:

ht=tanh(Wk[xt,ht1]+b) h t = t a n h ( W k [ x t , h t − 1 ] + b )

有时这个公式会写成下面这个形式(W叫Kernel,U叫Recurrent Kernel,图形可参考这里):

ht=tanh(Wxt+Uht1+b) h t = t a n h ( W x t + U h t − 1 + b )

结果是一样的。

从源代码里可以看到,state_size和output_size都跟num_units都是同一个数字,call函数返回两个一模一样的向量。

简单调用例子:

import tensorflow as tf
import numpy as np

batch_size = 3
input_dim = 2
output_dim = 4

inputs = tf.placeholder(dtype=tf.float32, shape=(batch_size, input_dim))
previous_state = tf.random_normal(shape=(batch_size, output_dim))

cell = tf.contrib.rnn.BasicRNNCell(num_units=output_dim)
output, state = cell(inputs, previous_state)

X = np.ones(shape=(batch_size, input_dim))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    o, s = sess.run([output, state], feed_dict={inputs: X})

    print(X)
    print(previous_state.eval())
    print(o)
    print(s)

Out:
# Input:
[[ 1.  1.]
 [ 1.  1.]
 [ 1.  1.]]

# Previous State:
[[ 0.29562142  1.88447475 -0.71091568 -1.03161728]
 [-0.32763469 -0.4521957  -0.33536151  0.06760707]
 [-0.04627729  0.04288582 -0.62693876  0.70083541]]

# Output = State
[[ 0.1948889   0.77429289  0.41136274  0.42551333]
 [-0.99327117 -0.68583459  0.97010344 -0.56064779]
 [-0.98540735 -0.51250875  0.92181391 -0.98040372]]
[[ 0.1948889   0.77429289  0.41136274  0.42551333]
 [-0.99327117 -0.68583459  0.97010344 -0.56064779]
 [-0.98540735 -0.51250875  0.92181391 -0.98040372]]

额外

  1. 想要查看内部 Wk W k b b <script type="math/tex" id="MathJax-Element-4">b</script>的值可参考这里
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

手撕机

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值