博文配套视频课程:自然语言处理与知识图谱
RNN原理剖析
既然我们已经有了深度神经网络和卷积神经网络,为什么还要循环神经网络?
原因很简单,无论是卷积神经网络,还是人工神经网络,他们的前提假设都是:元素之间是相互独立的,输入与输出也是独立的,比如猫和狗。但是在有些场景顺序是非常重要的。例如:股票交易时间序列,文章上下文等等。需要依赖上下文才能进行更好的推导和理解,这也需要RNN具备一定的记忆能力 (能够记住前面文字)
RNN原理图
网络中很多RNN的图都并不是很直观,并且相互抄窃导致很多图大同小异。我自己画了一个其实理解循环神经网络核心就两点:输入有序且前面词语义会对当前词有影响。
- 下图的神经元其实只有一个,是在不同的时间片段 (t1,t2,t3) 对一句话的三个词 (x1,x2,x3) 依次进行处理。这也是和CNN,DNN最大的区别。
- 每次在处理当前T时刻的特征时,上一个时刻T-1的输出也会成为T时刻的输入。因此你们看到公式为:fn(wt * xt + ut * ht-1)。其中wt,ut为权重,xt为当前时刻特征值,ht-1为上一时刻输出结果。
从Cell讲起
在时间轴上的某一次运算称为一个SimpleRNNCell,它是一个非常底层的概念,需要我们自己管理每次运算的输入、权重、偏置、输出。因此对于理解RNN底层是非常有帮助的。在视频教程中通过查看源码分析SimpleRNNCell参数。不习惯查看源码小伙伴也可以通过查看下面网页获取参数信息
SimpleRNNCell 参数参考:https://www.cnblogs.com/Renyi-Fan/p/13722276.html
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNNCell
# 模拟第T时刻的特征
xt = tf.Variable(np.random.randint(2, 3, size=[1, 1]), dtype= np.float32)
# 1个样本4个特征
print(xt)
print('='*100)
# Cell 一个循环神经元一次运算
# units: 正整数,输出空间的维度, 即隐藏层神经元数量.
# activation: 激活函数,默认是tanh
# use_bias: Boolean, 是否使用偏置向量.
# kernel_initializer: 输入和隐藏层之间的权重参数初始化器
# recurrent_initializer: 循环运算的权重
# bias_initializer: 偏置向量的初始化器.
# use_bias=True tf.keras.initializers.Constant(value=3) 指定值测试
cell = SimpleRNNCell(units=1,activation=None,use_bias=False,kernel_initializer='ones',
recurrent_initializer='ones',bias_initializer='ones')
print(tf.math.tanh(tf.constant([-float("inf"),3,float("inf")])))
# 指定输入的格式,用于初始化网络层中涉及到的权重参数
cell.build(input_shape=[None,1])
print(cell.variables)
print("config",cell.get_config())
第T次运算
# t-1 上一时刻运算的结果
ht_1 = tf.zeros([1,1])
# out,ht = active(xt * wt + ht_1 * ut)
out,ht = cell(xt, ht_1) # 传入第t-1时刻的h
# 得到t时刻输出结果,而这个结果在t+1时刻作为输入
print(f'out:{out},ht:{ht}')
print(id(out),id(ht[0]))
输出结果如下:对于RNN来说 out 与 ht 目前是相同的。到了后续LSTM才会有区分。
1907881944960 1907881944960
第T+1次运算
cell2 = SimpleRNNCell(units=1,activation=None,use_bias=True,kernel_initializer='ones',
recurrent_initializer='ones',bias_initializer='ones')
xt2 = tf.Variable(np.random.randint(2,3,size=[1,1]),dtype= np.float32)
# 此处的ht 就是上一次T运算的输出结果
out2,ht2 = cell2(xt2,ht) # out2,ht2 = active(xt2 * wt2 + ht * ut2)
print(f'out:{out2},ht:{ht2}')