写在前面的一些内容
本次习题来源于 NNDL 作业 9:分别使用 NumPy 和 PyTorch 实现 BPTT。
水平有限,难免有误,如有错漏之处敬请指正。
习题1
推导循环神经网络反向传播算法BPTT.
一些已知的东西(对任意时刻 $t$,记 $\mathcal{L}_t$ 为第 $t$ 时刻的损失,总损失 $\mathcal{L}=\sum_{t=1}^{T}\mathcal{L}_t$):
$$
z_t = U h_{t-1} + W x_t + b,\qquad h_t = f(z_t),\qquad \widehat{y}_t = g(h_t)
$$
展开前几个时刻即
$$
\begin{aligned}
z_1 &= U h_0 + W x_1 + b \\
z_2 &= U h_1 + W x_2 + b \\
z_3 &= U h_2 + W x_3 + b \\
h_1 &= f(z_1) \\
\widehat{y}_1 &= g(h_1)
\end{aligned}
$$
$t=1$ 时($\mathcal{L}_1$ 对 $U$ 的梯度):
$$
\frac{\partial\mathcal{L}_1}{\partial U}= \frac{\partial\mathcal{L}_1}{\partial\widehat{y}_1}\cdot \frac{\partial\widehat{y}_1}{\partial h_1}\cdot \frac{\partial h_1}{\partial z_1}\cdot \frac{\partial z_1}{\partial U}
$$
$t=2$ 时($\mathcal{L}_2$ 对 $U$ 的梯度)。$z_2$ 既直接依赖 $U$,又通过 $h_1$ 间接依赖 $U$,因此有两条路径(式中 $\frac{\partial z_k}{\partial U}$ 均指把 $h_{k-1}$ 视为常数时的直接偏导):
$$
\frac{\partial\mathcal{L}_2}{\partial U}= \frac{\partial\mathcal{L}_2}{\partial\widehat{y}_2}\cdot \frac{\partial\widehat{y}_2}{\partial h_2}\cdot \frac{\partial h_2}{\partial z_2}\cdot \frac{\partial z_2}{\partial U} + \frac{\partial\mathcal{L}_2}{\partial\widehat{y}_2}\cdot \frac{\partial\widehat{y}_2}{\partial h_2}\cdot \frac{\partial h_2}{\partial z_2}\cdot \frac{\partial z_2}{\partial h_1}\cdot \frac{\partial h_1}{\partial z_1}\cdot \frac{\partial z_1}{\partial U}
$$
$t=3$ 时($\mathcal{L}_3$ 对 $U$ 的梯度),共有三条路径:
$$
\begin{aligned}
\frac{\partial\mathcal{L}_3}{\partial U}={}& \frac{\partial\mathcal{L}_3}{\partial\widehat{y}_3}\cdot \frac{\partial\widehat{y}_3}{\partial h_3}\cdot \frac{\partial h_3}{\partial z_3}\cdot \frac{\partial z_3}{\partial U} \\
&+ \frac{\partial\mathcal{L}_3}{\partial\widehat{y}_3}\cdot \frac{\partial\widehat{y}_3}{\partial h_3}\cdot \frac{\partial h_3}{\partial z_3}\cdot \frac{\partial z_3}{\partial h_2}\cdot \frac{\partial h_2}{\partial z_2}\cdot \frac{\partial z_2}{\partial U} \\
&+ \frac{\partial\mathcal{L}_3}{\partial\widehat{y}_3}\cdot \frac{\partial\widehat{y}_3}{\partial h_3}\cdot \frac{\partial h_3}{\partial z_3}\cdot \frac{\partial z_3}{\partial h_2}\cdot \frac{\partial h_2}{\partial z_2}\cdot \frac{\partial z_2}{\partial h_1}\cdot \frac{\partial h_1}{\partial z_1}\cdot \frac{\partial z_1}{\partial U}
\end{aligned}
$$
以此类推,可得
$$
\frac{\partial\mathcal{L}}{\partial U}=\sum_{t=1}^{T}\frac{\partial\mathcal{L}_t}{\partial U}
$$
设
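$\delta_{t,k}=\frac{\partial\mathcal{L}_t}{\partial z_k}$($1\le k\le t$)为第 $t$ 时刻的损失对第 $k$ 时刻净输入 $z_k$ 的梯度。由于 $z_k = U h_{k-1} + W x_k + b$,$z_k$ 对 $U$ 的直接偏导由 $h_{k-1}$ 给出,所以第 $k$ 条路径的贡献为 $\delta_{t,k}h_{k-1}^{\mathsf T}$。又因为 $\frac{\partial z_{k+1}}{\partial h_k}=U$、$\frac{\partial h_k}{\partial z_k}=\operatorname{diag}\big(f'(z_k)\big)$,$\delta_{t,k}$ 可以从 $\delta_{t,t}$ 开始沿时间逐步向前递推:
$$
\delta_{t,k}=\operatorname{diag}\big(f'(z_k)\big)\,U^{\mathsf T}\,\delta_{t,k+1}
$$
则有
$$
\frac{\partial\mathcal{L}_t}{\partial U}=\sum_{k=1}^{t} \delta_{t,k}\,h_{k-1}^{\mathsf T}
$$
进而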
$$
\frac{\partial\mathcal{L}}{\partial U}=\sum_{t=1}^{T} \sum_{k=1}^{t} \delta_{t,k}\,h_{k-1}^{\mathsf T}
$$
同理可得
$$
\begin{aligned}
\frac{\partial\mathcal{L}}{\partial W}&=\sum_{t=1}^{T} \sum_{k=1}^{t} \delta_{t,k}\,x_k^{\mathsf T} \\
\frac{\partial\mathcal{L}}{\partial b}&=\sum_{t=1}^{T} \sum_{k=1}^{t} \delta_{t,k}
\end{aligned}
$$
习题2
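设计简单循环神经网络模型,分别用 NumPy、PyTorch 实现反向传播算子,并代入数值测试。
代码中 `weight_ih`、`weight_hh` 分别对应上文的 $W$、$U$,偏置 $b$ 拆成 `bias_ih + bias_hh` 两部分(与 PyTorch 的 `nn.RNN` 一致)。代码实现如下: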
import torch
import numpy as np
class RNNCell:
    """Single-layer tanh RNN cell with a hand-written BPTT backward pass.

    Forward: ``h_t = tanh(x_t @ W_ih^T + h_{t-1} @ W_hh^T + b_ih + b_hh)``,
    the same parameterisation as ``torch.nn.RNN``.

    Each forward call pushes its input, previous hidden state and new hidden
    state onto internal stacks. ``backward`` must then be called once per
    step in reverse time order; it pops those stacks and records per-step
    gradients in ``dx_list`` / ``dw_*_stack`` / ``db_*_stack``.
    """

    def __init__(self, weight_ih, weight_hh,
                 bias_ih, bias_hh):
        # Parameters, used as given (the demo passes NumPy views of torch weights).
        self.weight_ih, self.weight_hh = weight_ih, weight_hh
        self.bias_ih, self.bias_hh = bias_ih, bias_hh
        # Forward-pass caches, consumed (popped) by backward.
        self.x_stack = []
        self.prev_hidden_stack = []
        self.next_hidden_stack = []
        # Collected gradients; dx_list is kept in forward time order.
        self.dx_list = []
        self.dw_ih_stack, self.dw_hh_stack = [], []
        self.db_ih_stack, self.db_hh_stack = [], []
        # Gradient w.r.t. the hidden state carried from step t+1 back into
        # step t. Zeroed by every forward call; after the final backward step
        # it holds the gradient w.r.t. the initial hidden state.
        self.prev_dh = None

    def __call__(self, x, prev_hidden):
        """Advance one time step and return the new hidden state.

        Presumably ``x`` is (batch, input_size) and ``prev_hidden`` is
        (batch, hidden_size), as in the demo script -- TODO confirm for
        other callers; ``backward`` relies on 2-D arrays.
        """
        self.x_stack.append(x)
        pre_activation = (np.dot(x, self.weight_ih.T)
                          + np.dot(prev_hidden, self.weight_hh.T)
                          + self.bias_ih + self.bias_hh)
        h_new = np.tanh(pre_activation)
        self.prev_hidden_stack.append(prev_hidden)
        self.next_hidden_stack.append(h_new)
        # Reset the carried gradient so a fresh backward sweep starts at zero.
        self.prev_dh = np.zeros(h_new.shape)
        return h_new

    def backward(self, dh):
        """Back-propagate one step; ``dh`` is the upstream dL/dh_t.

        Returns the (shared) ``dx_list`` holding dL/dx for every step
        processed so far, earliest step first.
        """
        x_t = self.x_stack.pop()
        h_prev = self.prev_hidden_stack.pop()
        h_t = self.next_hidden_stack.pop()
        # Total dL/dh_t = upstream gradient + gradient carried from step t+1;
        # tanh'(z_t) = 1 - h_t**2.
        grad_z = (dh + self.prev_dh) * (1 - h_t ** 2)
        # Gradient flowing into h_{t-1} through the recurrent weights.
        self.prev_dh = np.dot(grad_z, self.weight_hh)
        self.dx_list.insert(0, np.dot(grad_z, self.weight_ih))
        self.dw_ih_stack.append(np.dot(grad_z.T, x_t))
        self.dw_hh_stack.append(np.dot(grad_z.T, h_prev))
        # Both biases enter z_t additively, so each receives grad_z unchanged.
        for bias_grads in (self.db_ih_stack, self.db_hh_stack):
            bias_grads.append(grad_z)
        return self.dx_list
if __name__ == '__main__':
    # Compare the NumPy BPTT implementation (RNNCell) against torch.nn.RNN
    # on identical weights, inputs and upstream gradients.
    np.random.seed(123)
    torch.random.manual_seed(123)
    np.set_printoptions(precision=6, suppress=True)

    # Reference model: one-layer tanh RNN, input size 4, hidden size 5,
    # cast to float64 so both implementations run in double precision.
    rnn_PyTorch = torch.nn.RNN(4, 5).double()
    # all_weights[0] is [weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0];
    # the NumPy cell is built from exactly the same parameter values.
    torch_params = rnn_PyTorch.all_weights[0]
    rnn_numpy = RNNCell(torch_params[0].detach().numpy(),
                        torch_params[1].detach().numpy(),
                        torch_params[2].detach().numpy(),
                        torch_params[3].detach().numpy())

    nums = 3  # number of time steps
    # Inputs in torch.nn.RNN's default (seq_len, batch, input_size) layout.
    x_numpy = np.random.random((nums, 3, 4))
    x_tensor = torch.tensor(x_numpy, requires_grad=True)
    # Initial hidden state, (num_layers, batch, hidden_size). It keeps its own
    # name (the original rebound it to the RNN output) so its gradient stays
    # reachable after backward().
    h0_numpy = np.random.random((1, 3, 5))
    h0_tensor = torch.tensor(h0_numpy, requires_grad=True)
    # Upstream gradient dL/dh_t for every step; a constant input to
    # backward(), so it needs no requires_grad.
    dh_numpy = np.random.random((nums, 3, 5))
    dh_tensor = torch.tensor(dh_numpy)

    # Forward pass. PyTorch returns (hidden states of all steps, final state).
    output_tensor, _ = rnn_PyTorch(x_tensor, h0_tensor)
    h_numpy_list = []
    h_numpy = h0_numpy[0]
    for x_t in x_numpy:
        h_numpy = rnn_numpy(x_t, h_numpy)
        h_numpy_list.append(h_numpy)

    # Backward pass with the same upstream gradients; the NumPy cell is
    # unrolled in reverse time order (BPTT).
    output_tensor.backward(dh_tensor)
    for dh_t in dh_numpy[::-1]:
        rnn_numpy.backward(dh_t)

    print("numpy_hidden :\n", np.array(h_numpy_list))
    print("torch_hidden :\n", output_tensor.detach().numpy())
    print("-----------------------------------------------")
    print("dx_numpy :\n", np.array(rnn_numpy.dx_list))
    print("dx_torch :\n", x_tensor.grad.numpy())
    print("------------------------------------------------")
    print("dw_ih_numpy :\n",
          np.sum(rnn_numpy.dw_ih_stack, axis=0))
    print("dw_ih_torch :\n",
          torch_params[0].grad.numpy())
    print("------------------------------------------------")
    print("dw_hh_numpy :\n",
          np.sum(rnn_numpy.dw_hh_stack, axis=0))
    print("dw_hh_torch :\n",
          torch_params[1].grad.numpy())
    print("------------------------------------------------")
    print("db_ih_numpy :\n",
          np.sum(rnn_numpy.db_ih_stack, axis=(0, 1)))
    print("db_ih_torch :\n",
          torch_params[2].grad.numpy())
    print("-----------------------------------------------")
    print("db_hh_numpy :\n",
          np.sum(rnn_numpy.db_hh_stack, axis=(0, 1)))
    print("db_hh_torch :\n",
          torch_params[3].grad.numpy())

    # Explicit numerical comparison instead of eyeballing the dumps above.
    # After the last backward step, rnn_numpy.prev_dh is dL/dh_0.
    checks = {
        "hidden": (np.array(h_numpy_list), output_tensor.detach().numpy()),
        "dx": (np.array(rnn_numpy.dx_list), x_tensor.grad.numpy()),
        "dw_ih": (np.sum(rnn_numpy.dw_ih_stack, axis=0),
                  torch_params[0].grad.numpy()),
        "dw_hh": (np.sum(rnn_numpy.dw_hh_stack, axis=0),
                  torch_params[1].grad.numpy()),
        "db_ih": (np.sum(rnn_numpy.db_ih_stack, axis=(0, 1)),
                  torch_params[2].grad.numpy()),
        "db_hh": (np.sum(rnn_numpy.db_hh_stack, axis=(0, 1)),
                  torch_params[3].grad.numpy()),
        "dh0": (rnn_numpy.prev_dh, h0_tensor.grad[0].numpy()),
    }
    print("------------------------------------------------")
    for name, (numpy_value, torch_value) in checks.items():
        print(f"{name} match:", np.allclose(numpy_value, torch_value))
    if not all(np.allclose(a, b) for a, b in checks.values()):
        raise SystemExit("NumPy BPTT results do not match PyTorch")
代码执行结果:
numpy_hidden :
[[[ 0.4686 -0.298203 0.741399 -0.446474 0.019391]
[ 0.365172 -0.361254 0.426838 -0.448951 0.331553]
[ 0.589187 -0.188248 0.684941 -0.45859 0.190099]]
[[ 0.146213 -0.306517 0.297109 0.370957 -0.040084]
[-0.009201 -0.365735 0.333659 0.486789 0.061897]
[ 0.030064 -0.282985 0.42643 0.025871 0.026388]]
[[ 0.225432 -0.015057 0.116555 0.080901 0.260097]
[ 0.368327 0.258664 0.357446 0.177961 0.55928 ]
[ 0.103317 -0.029123 0.182535 0.216085 0.264766]]]
torch_hidden :
[[[ 0.4686 -0.298203 0.741399 -0.446474 0.019391]
[ 0.365172 -0.361254 0.426838 -0.448951 0.331553]
[ 0.589187 -0.188248 0.684941 -0.45859 0.190099]]
[[ 0.146213 -0.306517 0.297109 0.370957 -0.040084]
[-0.009201 -0.365735 0.333659 0.486789 0.061897]
[ 0.030064 -0.282985 0.42643 0.025871 0.026388]]
[[ 0.225432 -0.015057 0.116555 0.080901 0.260097]
[ 0.368327 0.258664 0.357446 0.177961 0.55928 ]
[ 0.103317 -0.029123 0.182535 0.216085 0.264766]]]
-----------------------------------------------
dx_numpy :
[[[-0.643965 0.215931 -0.476378 0.072387]
[-1.221727 0.221325 -0.757251 0.092991]
[-0.59872 -0.065826 -0.390795 0.037424]]
[[-0.537631 -0.303022 -0.364839 0.214627]
[-0.815198 0.392338 -0.564135 0.217464]
[-0.931365 -0.254144 -0.561227 0.164795]]
[[-1.055966 0.249554 -0.623127 0.009784]
[-0.45858 0.108994 -0.240168 0.117779]
[-0.957469 0.315386 -0.616814 0.205634]]]
dx_torch :
[[[-0.643965 0.215931 -0.476378 0.072387]
[-1.221727 0.221325 -0.757251 0.092991]
[-0.59872 -0.065826 -0.390795 0.037424]]
[[-0.537631 -0.303022 -0.364839 0.214627]
[-0.815198 0.392338 -0.564135 0.217464]
[-0.931365 -0.254144 -0.561227 0.164795]]
[[-1.055966 0.249554 -0.623127 0.009784]
[-0.45858 0.108994 -0.240168 0.117779]
[-0.957469 0.315386 -0.616814 0.205634]]]
------------------------------------------------
dw_ih_numpy :
[[3.918335 2.958509 3.725173 4.157478]
[1.261197 0.812825 1.10621 0.97753 ]
[2.216469 1.718251 2.366936 2.324907]
[3.85458 3.052212 3.643157 3.845696]
[1.806807 1.50062 1.615917 1.521762]]
dw_ih_torch :
[[3.918335 2.958509 3.725173 4.157478]
[1.261197 0.812825 1.10621 0.97753 ]
[2.216469 1.718251 2.366936 2.324907]
[3.85458 3.052212 3.643157 3.845696]
[1.806807 1.50062 1.615917 1.521762]]
------------------------------------------------
dw_hh_numpy :
[[ 2.450078 0.243735 4.269672 0.577224 1.46911 ]
[ 0.421015 0.372353 0.994656 0.962406 0.518992]
[ 1.079054 0.042843 2.12169 0.863083 0.757618]
[ 2.225794 0.188735 3.682347 0.934932 0.955984]
[ 0.660546 -0.321076 1.554888 0.833449 0.605201]]
dw_hh_torch :
[[ 2.450078 0.243735 4.269672 0.577224 1.46911 ]
[ 0.421015 0.372353 0.994656 0.962406 0.518992]
[ 1.079054 0.042843 2.12169 0.863083 0.757618]
[ 2.225794 0.188735 3.682347 0.934932 0.955984]
[ 0.660546 -0.321076 1.554888 0.833449 0.605201]]
------------------------------------------------
db_ih_numpy :
[7.568411 2.175445 4.335336 6.820628 3.51003 ]
db_ih_torch :
[7.568411 2.175445 4.335336 6.820628 3.51003 ]
-----------------------------------------------
db_hh_numpy :
[7.568411 2.175445 4.335336 6.820628 3.51003 ]
db_hh_torch :
[7.568411 2.175445 4.335336 6.820628 3.51003 ]