一文读懂RNN
RNN的基本结构
- RNN的基本结构如图所示,每个时间步t接收两个输入:
- x t x_{t} xt:当前时间步的输入。
- s t − 1 s_{t-1} st−1:上一个时间步的隐藏状态。
- RNN单元会输出:
- s t s_{t} st:当前时间步的隐藏状态,该状态将传递给下一个时间步。
- o t o_{t} ot:当前时间步的输出(可选)。
隐藏状态的更新
隐藏状态的更新公式是RNN的核心,它决定了如何将当前输入和上一个时间步的隐藏状态结合起来,生成当前时间步的隐藏状态。公式如下:
s t = ϕ ( U t x t + W t s t − 1 + b s ) s_{t} =\phi (U_{t}x_{t}+W_{t}s_{t-1}+b_{s}) st=ϕ(Utxt+Wtst−1+bs)
其中
- U t U_{t} Ut:输入到隐藏状态的权重矩阵。
- W t W_{t} Wt:隐藏状态到隐藏状态的权重矩阵。
- b w b_{w} bw:隐藏层的偏置向量。
- ϕ \phi ϕ:激活函数。
- 推导过程:
- 输入到隐藏状态的线性变换:
W t s t − 1 W_{t}s_{t-1} Wtst−1这一步将当前时间步骤的输入 x t x_{t} xt 通过权重矩阵 U t U_{t} Ut 进行线性变换。 - 隐藏状态到隐藏状态的线性变换:
W t s t − 1 W_{t}s_{t-1} Wtst−1这一步将上一个时间步的隐藏状态 s t − 1 s_{t-1} st−1 通过权重矩阵 W t W_{t} Wt进行线性变换。 - 偏置项的添加:
U t x t + W t s t − 1 + b s U_{t}x_{t}+W_{t}s_{t-1}+b_{s} Utxt+Wtst−1+bs将上述两个线性变换的结果相加,并加上偏置项 - 激活函数的应用:
最后,通过激活函数 s t = ϕ ( U t x t + W t s t − 1 + b s ) s_{t} =\phi (U_{t}x_{t}+W_{t}s_{t-1}+b_{s}) st=ϕ(Utxt+Wtst−1+bs)引入非线性
,得到当前时间步的隐藏状态
- 输入到隐藏状态的线性变换:
输出的计算
如果RNN单元有输出 o t o_{t} ot,那么它通常是由当前的隐藏状态 s t s_{t} st经过线性变换得到的。公式如下:
o t = V t s t + b y o_{t} = V_{t}s_{t}+b_{y} ot=Vtst+by
其中:
- V t V_{t} Vt:隐藏状态到输出的权重矩阵。
- b y b_{y} by:输出层的偏置向量。
- 推导过程:
- 隐藏层到输出的线性变换:
V t s t V_{t}s_{t} Vtst这一步将当前时间步骤的输入 s t s_{t} st 通过权重矩阵 V t V_{t} Vt 进行线性变换。 - 偏置项的添加:
V t s t + b y V_{t}s_{t}+b_{y} Vtst+by将上述线性变换的结果加上偏置项 b y b_{y} by
- 隐藏层到输出的线性变换:
训练过程
在训练过程中,RNN的参数(权重矩阵和偏置向量)是通过反向传播算法(Backpropagation Through Time, BPTT)
根据损失函数对这些参数的梯度进行更新的。BPTT算法将整个序列展开成一个深层的前馈神经网络,然后使用标准的反向传播算法来计算梯度。
4.1 损失函数
假设我们有一个序列数据{x_{1},x_{2},…,x_{T}}和对应的目标序列{o_{1},o_{2},…,o_{T}},损失函数 L L L 可以定义为:
L t = ∑ t = 1 T L t L_{t}= {\textstyle \sum_{t=1}^{T}}L_{t} Lt=∑t=1TLt
其中 L t L_{t} Lt是在时间步 t的损失,通常使用交叉熵损失函数:
L t = ∑ i = 1 C o t ( i ) log o ^ t ( i ) L_{t}= {\textstyle \sum_{i=1}^{C}} o_{t}^{(i)} \log_{}{} \hat{o} _{t}^{(i)} Lt=∑i=1Cot(i)logo^t(i)
4.2 反向传播
为了更新参数,我们需要计算损失函数 L L L 对各参数的梯度。我们使用链式法则来计算这些梯度。
- 输出层的梯度:
ϑ L t ϑ V t = ϑ L t ϑ o t ϑ o t ϑ V t \frac{\vartheta L_{t}}{\vartheta V_{t}} = \frac{\vartheta L_{t}}{\vartheta o_{t}} \frac{\vartheta o_{t}}{\vartheta V_{t}} ϑVtϑLt=ϑotϑLtϑVtϑot
ϑ L t ϑ b y = ϑ L t ϑ o t ϑ o t ϑ b y \frac{\vartheta L_{t}}{\vartheta b_{y}} = \frac{\vartheta L_{t}}{\vartheta o_{t}} \frac{\vartheta o_{t}}{\vartheta b_{y}} ϑbyϑLt=ϑotϑLtϑbyϑot
- 隐藏层的梯度:
ϑ L t ϑ s t