这篇文章主要是对之前一段时间里接触到的 循环神经网络 的相关知识进行一些总结,包括个人觉得初学难理解或者需要注意的问题和如何使用Pytorch的相关函数。由于这些经典结构网上资料很多,所以一些通识不再陈述,偏重代码部分。
1.RNN
很多问题都归结于序列信息的处理,例如 speech recognization,machine translation等等,RNN就是为了解决这类问题的结构,这里的RNN含义为循环神经网络(recurrent neural network)而非递归神经网络(recursive neural network)。序列信息可以看作是不同时间点输入相同格式的数据,那么使用一个结构循环处理不同时间点的数据,那么这也就是RNN网络了,所以很多介绍RNN的地方都会有那张经典的RNN展开的图了:
这一类介绍资源非常多,所以不再赘述。RNN的关键在于它的计算公式:
s t = f ( U ⋅ x t + W ⋅ s t − 1 ) o t = s o f t m a x ( V ⋅ s t ) s_t = f(U\cdot x_t + W\cdot s_{t-1}) \\ o_t = softmax(V\cdot s_t) st=f(U⋅xt+W⋅st−1)ot=softmax(V⋅st)
说明:
- x t x_t xt是某个时刻的输入信息,序列信息可以看作是不同时间的连续输入,所以每个时间点都会输入信息。
- s t s_t st 表示隐藏信息,对于序列信息的处理,很重要的一点就是上文信息会影响到下文信息,所以需要有一个结构来储存之前的所有信息。
- o t o_t ot表示某个时间点的输出信息。
RNN有几个特点:
- 每个时间点都会输出一个隐藏状态,但是显然我们并不需要全部的信息,例如在对文本进行分类的时候,我们往往只是使用最后一个时刻的隐藏状态,然后通过一个分类器即可。
- 权值共享,实际上是一个结构对不同时刻的信息进行处理,所以所有的权重实际上都是相同的。
- RNN也使用BP算法来更新参数,但是与之前的神经网络不同的是,这里的梯度计算需要依赖于之前的所有步,然后将梯度累加,这被称为 BPTT(