RNN
原理解析
RNN的一层结构图如下:
- W: state 权重
- U: 输入权重
- V: 输出权重
- xt: t个step的输入
- ht: rnn cell 隐层, 也有的叫st (状态层)
- ot: 最终的输出
- yt: 经过soft max之后的分类结果
数学关系:
关于维度:
* ht维度: 是隐层的数量,也是自定义的, shape=[hn✖️1]
* W维度: shape=[hn, hn],才能保证W✖️ht-1的维度与ht的维度一样[hn✖️1]
* xt维度: 是embeding时候自定义的,shape=[xn✖️1]
* U维度: shape=[hn, xn],保证U✖️xt 与 ht的维度是一样的 [hn, 1]
* ot: 与xt的输出维度是一样的, 所以V必须是[xn, hn]维度.理论上也可以不一样,对于多层的RNN来说,只要保证最后一层ot与xt的输出是一样的就可以了。所以V其实可以是任意的[vn, hn]这样就可以了,在最后一层变成[xn, hn]就可以了,保证与xt一样的维度就可以了.
rnn cell的实现
通过上面的原理介绍,如果实现一个rnn cell,函数描述如下:
名称 | 类别 | 描述 | 对应变量 |
---|---|---|---|
xt | 输入参数 | 当前时刻的embeding变量 | xt |
ht-1 | 输入参数 | 前一个时刻的隐层变量输出 | ht-1 |
ot | 返回值 | 返回的target, 与x的维度是一样的 | ot |
ht | 返回值 | 当前时刻产生的新的隐层变量 | ht |
在tensorflow对应的类是: tf.nn.rnn_cell.BasicRNNCell.
BasicRNNCell
init参数描述:
- num_units: 隐层单元的数量,这是自己定义的。也就是前面维度中所说的hn, 隐层的单元数是自己定义的.
- activation: 激活函数,双曲正切
- reuse: True, 表示共享变量,多个cell 是共享权重的.
call参数描述:
- inputs: 也就是xt
- state: 也就是前一个ht-1, 注意ht-1的维度要与num_units吻合因为二者本来就是一个.
- ht, ht: 返回的都是ht,与前面描述 中返回的是ot, ht,因为ot也是ht变换得来的ot=C+V✖️ht,所以这里返回是两个ht.
demo:
import tensorflow as tf
import numpy as np
import logging
from tensorflow.python.ops import array_ops
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(name)s:%(levelname)s: %(message)s [%(filename)s:%(lineno)d]"
)
def test_rnn_cell():
num_units = 2
state_size = num_units # state_size也就是ht的size一定要与num_units是一样的
batch_size = 1
input_size = 4
x = array_ops.zeros([batch_size, input_size])
m = array_ops.zeros([batch_size, state_size])
logging.info('x type: ' + str(type(x)) + ': ' + str(x.shape))
logging.info('m type: ' + str(type(m)) + ': ' + str(m.shape))
# m = (array_ops.zeros([batch_size]), array_ops.zeros([batch_size]))
with tf.Session() as sess:
g, out_m = tf.nn.rnn_cell.BasicRNNCell(num_units)(x, m)
sess.run([tf.global_variables_initializer()])
# g_result == out_m_result 二者是同一个
g_result, out_m_result = sess.run([g, out_m],
{x.name: 1 * np.ones([batch_size, input_size]),
m.name: 0.1 * np.ones([batch_size, state_size])})
logging.info('g_result: ' + str(g_result))
logging.info('out_m_result: ' + str(out_m_result))
test_rnn_cell()
[2017-10-13 15:37:41,889] root:INFO: x type: <class 'tensorflow.python.framework.ops.Tensor'>: (1, 4) [<ipython-input-4-0e1f199b1a5b>:10]
[2017-10-13 15:37:41,890] root:INFO: m type: <class 'tensorflow.python.framework.ops.Tensor'>: (1, 2) [<ipython-input-4-0e1f199b1a5b>:11]
[2017-10-13 15:37:42,042] root:INFO: g_result: [[-0.7603032 0.54075867]] [<ipython-input-4-0e1f199b1a5b>:22]
[2017-10-13 15:37:42,043] root:INFO: out_m_result: [[-0.7603032 0.54075867]] [<ipython-input-4-0e1f199b1a5b>:23]