习题6-1P 推导RNN反向传播算法BPTT.
看以下ppt更容易理解(请忽略FNN BP):
习题6-2 推导公式(6.40)和公式(6.41)中的梯度.
推导了关于U的之后,关于W和b就会清晰很多。
习题6-3 当使用公式(6.50)作为循环神经网络的状态更新公式时, 分析其可能存在梯度爆炸的原因并给出解决方法.
给出公式(6.50)
𝜹𝒕,𝒌: 第𝑡时刻的损失对第𝑘𝑘步隐 藏神经元的净输入的导数,根据以下ppt内容
可以知道只要求梯度就可能会发生梯度爆炸,因为𝜹𝒕,𝒌的计算是不可避免的。
针对于梯度爆炸的改进方案:
1、通过正则化来约束γ的取值范围,让γ不大于1,最好是再1的周围。
2、通过观测梯度的模的大小来直接对梯度进行约束。
习题6-2P 设计简单RNN模型,分别用Numpy、Pytorch实现反向传播算子,并代入数值测试.
代码使用的是引用中文章里的代码,这里不再贴出只对其反向传播部分进行分析L:
def backward(self, dh):
# 从的栈中弹出当前时间步的输入x、上一个时间步的隐藏状态prev_hidden和下一个时间步的隐藏状态next_hidden。
x = self.x_stack.pop()
prev_hidden = self.prev_hidden_stack.pop()
next_hidden = self.next_hidden_stack.pop()
# 算当前时间步的tanh激活函数的导数d_tanh(当前时间步的梯度dh与上一个时间步的梯度prev_dh的和乘以(1−next_hidden 2))
d_tanh = (dh + self.prev_dh) * (1 - next_hidden ** 2)
self.prev_dh = np.dot(d_tanh, self.weight_hh)
# 计算当前时间步的输入x的梯度dx(d_tanh与输入到隐藏状态的权重矩阵weight_ih的乘积)dx插入到存储输入梯度的列表dx_list的开头,以便后续时间步使用
dx = np.dot(d_tanh, self.weight_ih)
self.dx_list.insert(0, dx)
# 计算当前时间步的输入到隐藏状态权重矩阵weight_ih的梯度dw_ih(d_tanh与输入x的转置的乘积),并将dw_ih存储到dw_ih_stack列表
dw_ih = np.dot(d_tanh.T, x)
self.dw_ih_stack.append(dw_ih)
# 计算当前时间步的隐藏状态到隐藏状态权重矩阵weight_hh的梯度dw_hh(d_tanh与上一个时间步的隐藏状态prev_hidden的转置的乘积)将dw_hh存储到dw_hh_stack列表
dw_hh = np.dot(d_tanh.T, prev_hidden)
self.dw_hh_stack.append(dw_hh)
# 将d_tanh存储到db_ih_stack和db_hh_stack列表,分别存储输入到隐藏状态和隐藏状态到隐藏状态的偏置项的梯度。
self.db_ih_stack.append(d_tanh)
self.db_hh_stack.append(d_tanh)
# 返回每个时间步的输入梯度
return self.dx_list
结果:
numpy_hidden :
[[[ 0.4686 -0.298203 0.741399 -0.446474 0.019391]
[ 0.365172 -0.361254 0.426838 -0.448951 0.331553]
[ 0.589187 -0.188248 0.684941 -0.45859 0.190099]]
[[ 0.146213 -0.306517 0.297109 0.370957 -0.040084]
[-0.009201 -0.365735 0.333659 0.486789 0.061897]
[ 0.030064 -0.282985 0.42643 0.025871 0.026388]]
[[ 0.225432 -0.015057 0.116555 0.080901 0.260097]
[ 0.368327 0.258664 0.357446 0.177961 0.55928 ]
[ 0.103317 -0.029123 0.182535 0.216085 0.264766]]]
tensor_hidden :
[[[ 0.4686 -0.298203 0.741399 -0.446474 0.019391]
[ 0.365172 -0.361254 0.426838 -0.448951 0.331553]
[ 0.589187 -0.188248 0.684941 -0.45859 0.190099]]
[[ 0.146213 -0.306517 0.297109 0.370957 -0.040084]
[-0.009201 -0.365735 0.333659 0.486789 0.061897]
[ 0.030064 -0.282985 0.42643 0.025871 0.026388]]
[[ 0.225432 -0.015057 0.116555 0.080901 0.260097]
[ 0.368327 0.258664 0.357446 0.177961 0.55928 ]
[ 0.103317 -0.029123 0.182535 0.216085 0.264766]]]
------
dx_numpy :
[[[-0.643965 0.215931 -0.476378 0.072387]
[-1.221727 0.221325 -0.757251 0.092991]
[-0.59872 -0.065826 -0.390795 0.037424]]
[[-0.537631 -0.303022 -0.364839 0.214627]
[-0.815198 0.392338 -0.564135 0.217464]
[-0.931365 -0.254144 -0.561227 0.164795]]
[[-1.055966 0.249554 -0.623127 0.009784]
[-0.45858 0.108994 -0.240168 0.117779]
[-0.957469 0.315386 -0.616814 0.205634]]]
dx_tensor :
[[[-0.643965 0.215931 -0.476378 0.072387]
[-1.221727 0.221325 -0.757251 0.092991]
[-0.59872 -0.065826 -0.390795 0.037424]]
[[-0.537631 -0.303022 -0.364839 0.214627]
[-0.815198 0.392338 -0.564135 0.217464]
[-0.931365 -0.254144 -0.561227 0.164795]]
[[-1.055966 0.249554 -0.623127 0.009784]
[-0.45858 0.108994 -0.240168 0.117779]
[-0.957469 0.315386 -0.616814 0.205634]]]
------
dw_ih_numpy :
[[3.918335 2.958509 3.725173 4.157478]
[1.261197 0.812825 1.10621 0.97753 ]
[2.216469 1.718251 2.366936 2.324907]
[3.85458 3.052212 3.643157 3.845696]
[1.806807 1.50062 1.615917 1.521762]]
dw_ih_tensor :
[[3.918335 2.958509 3.725173 4.157478]
[1.261197 0.812825 1.10621 0.97753 ]
[2.216469 1.718251 2.366936 2.324907]
[3.85458 3.052212 3.643157 3.845696]
[1.806807 1.50062 1.615917 1.521762]]
------
dw_hh_numpy :
[[ 2.450078 0.243735 4.269672 0.577224 1.46911 ]
[ 0.421015 0.372353 0.994656 0.962406 0.518992]
[ 1.079054 0.042843 2.12169 0.863083 0.757618]
[ 2.225794 0.188735 3.682347 0.934932 0.955984]
[ 0.660546 -0.321076 1.554888 0.833449 0.605201]]
dw_hh_tensor :
[[ 2.450078 0.243735 4.269672 0.577224 1.46911 ]
[ 0.421015 0.372353 0.994656 0.962406 0.518992]
[ 1.079054 0.042843 2.12169 0.863083 0.757618]
[ 2.225794 0.188735 3.682347 0.934932 0.955984]
[ 0.660546 -0.321076 1.554888 0.833449 0.605201]]
------
db_ih_numpy :
[7.568411 2.175445 4.335336 6.820628 3.51003 ]
db_ih_tensor :
[7.568411 2.175445 4.335336 6.820628 3.51003 ]
------
db_hh_numpy :
[7.568411 2.175445 4.335336 6.820628 3.51003 ]
db_hh_tensor :
[7.568411 2.175445 4.335336 6.820628 3.51003 ]
可以看到两者基本一致。
总结:
本次作业,深入认识了BPTT以及其推导方式,其中不乏一些难以理解的公式,但是通过搜寻资料以及对课件的反复观看,终于还是有点眉目了,要将其视为一个时间序列,这样就比较好理解了,而且在理解了这个之后,梯度爆炸推出来的那个公式也会更好理解为什么会产生γ这样一个式子。不要遇到一些难以认识或者理解的符号就退缩,首先搞清楚他们代表的是什么,理解之后会清晰很多,而且要对输入的形式有一个理解,例如说输入是什么形式的矩阵,再网络中又会有各种各样的矩阵。这些对于推导都很重要。
引用:
HBU-NNDL 作业9:分别使用numpy和pytorch实现BPTT_哦n........yly hbu: 9-CSDN博客