关于 RNN 循环神经网络的反向传播求导
本文是对 RNN 循环神经网络中的每一个神经元进行反向传播求导的数学推导过程,下面还使用 PyTorch
对导数公式进行编程求证。
RNN 神经网络架构
一个普通的 RNN 神经网络如下图所示:
其中 x ⟨ t ⟩ x^{\langle t \rangle} x⟨t⟩ 表示某一个输入数据在 t t t 时刻的输入; a ⟨ t ⟩ a^{\langle t \rangle} a⟨t⟩ 表示神经网络在 t t t 时刻时的hidden state,也就是要传送到 t + 1 t+1 t+1 时刻的值; y ⟨ t ⟩ y^{\langle t \rangle} y⟨t⟩ 则表示在第 t t t 时刻输入数据传入以后产生的预测值,在进行预测或 sampling 时 y ⟨ t ⟩ y^{\langle t \rangle} y⟨t⟩ 通常作为下一时刻即 t + 1 t+1 t+1 时刻的输入,也就是说 x ⟨ t ⟩ = y ^ ⟨ t ⟩ x^{\langle t \rangle}=\hat{y}^{\langle t \rangle} x⟨t⟩=y^⟨t⟩ ;下面对数据的维度进行说明。
- 输入: x ∈ R n x × m × T x x\in\mathbb{R}^{n_x\times m\times T_x} x∈Rnx×m×Tx 其中 n x n_x nx 表示每一个时刻输入向量的长度; m m m 表示数据批量数(batch); T x T_x Tx 表示共有多少个输入的时刻(time step)。
- hidden state: a ∈ R n a × m × T x a\in\mathbb{R}^{n_a\times m\times T_x} a∈Rna×m×Tx 其中 n a n_a na 表示每一个 hidden state 的长度。
- 预测: y ∈ R n y × m × T y y\in\mathbb{R}^{n_y\times m\times T_y} y∈Rny×m×Ty 其中 n y n_y ny 表示预测输出的长度; T y T_y Ty 表示共有多少个输出的时刻(time step)。
RNN 神经元
下图所示的是一个特定的 RNN 神经元:
上图说明了在第 t t t 时刻的神经元中,数据的输入 x ⟨ t ⟩ x^{\langle t \rangle} x⟨t⟩ 和上一层的 hidden state a ⟨ t ⟩ a^{\langle t \rangle} a⟨t⟩ 是如何经过计算得到下一层的 hidden state 和预测输出 y ^ ⟨ t ⟩ \hat{y}^{\langle t \rangle} y^⟨t⟩ 。
下面是对五个参数的维度说明:
- W a a ∈ R n a × n a W_{aa}\in\mathbb{R}^{n_a\times n_a} Waa∈Rna×na
- W a x ∈