RNN cell的实现
flyfish
已编译通过
步骤
1 使用tanh激活函数计算隐藏状态
$a^{\langle t \rangle} = \tanh(W_{aa}\, a^{\langle t-1 \rangle} + W_{ax}\, x^{\langle t \rangle} + b_a)$
2 使用新的隐藏状态 $a^{\langle t \rangle}$ 计算预测值,
$\hat{y}^{\langle t \rangle} = \mathrm{softmax}(W_{ya}\, a^{\langle t \rangle} + b_y)$
已提供softmax函数
3 在cache中存储 $(a^{\langle t \rangle},\; a^{\langle t-1 \rangle},\; x^{\langle t \rangle},\; \text{parameters})$
4 返回 $a^{\langle t \rangle},\; y^{\langle t \rangle},\; \text{cache}$
import numpy as np
def softmax(x):
    """Column-wise softmax of x.

    Parameters
    ----------
    x : np.ndarray
        Score matrix of shape (n_classes, m); softmax is taken over
        axis 0, i.e. independently for each of the m columns.

    Returns
    -------
    np.ndarray
        Same shape as x; every column sums to 1.

    Notes
    -----
    Softmax is invariant to adding a constant to a column, so we
    subtract each column's own maximum before exponentiating.  Using
    the per-column max (rather than the global max of the whole
    matrix) keeps every column numerically stable: with a global max,
    a column whose scores are far below it underflows to all zeros
    and produces 0/0 = NaN.
    """
    e_x = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e_x / e_x.sum(axis=0)
def rnn_cell_forward(xt, a_prev, parameters):
    """Run a single forward step of a vanilla RNN cell.

    Implements
        a<t> = tanh(Waa @ a<t-1> + Wax @ x<t> + ba)
        y<t> = softmax(Wya @ a<t> + by)

    Parameters
    ----------
    xt : np.ndarray, shape (n_x, m)
        Input at the current time step for a batch of m examples.
    a_prev : np.ndarray, shape (n_a, m)
        Hidden state carried over from the previous time step.
    parameters : dict
        Weight/bias lookup: "Wax" (n_a, n_x), "Waa" (n_a, n_a),
        "Wya" (n_y, n_a), "ba" (n_a, 1), "by" (n_y, 1).

    Returns
    -------
    a_next : np.ndarray, shape (n_a, m)
        New hidden state.
    yt_pred : np.ndarray, shape (n_y, m)
        Softmax prediction for this step.
    cache : tuple
        (a_next, a_prev, xt, parameters) — kept for backpropagation.
    """
    # Pull the weights and biases out of the parameter dictionary.
    Wax, Waa = parameters["Wax"], parameters["Waa"]
    Wya = parameters["Wya"]
    ba, by = parameters["ba"], parameters["by"]

    # Hidden-state update: mix the previous state with the new input,
    # then squash through tanh.
    a_next = np.tanh(Wax @ xt + Waa @ a_prev + ba)

    # Prediction from the freshly computed hidden state.
    yt_pred = softmax(Wya @ a_next + by)

    # Everything the backward pass will need later.
    cache = (a_next, a_prev, xt, parameters)

    return a_next, yt_pred, cache
# Smoke test: one forward step on fixed random data, then print results.
# NOTE: the draw order below fixes the RNG stream — do not reorder.
np.random.seed(1)

# Dimensions: n_x = 3 input features, n_a = 5 hidden units,
# n_y = 2 output classes, m = 10 examples in the batch.
xt = np.random.randn(3, 10)
a_prev = np.random.randn(5, 10)
Waa = np.random.randn(5, 5)
Wax = np.random.randn(5, 3)
Wya = np.random.randn(2, 5)
ba = np.random.randn(5, 1)
by = np.random.randn(2, 1)

parameters = {
    "Waa": Waa,
    "Wax": Wax,
    "Wya": Wya,
    "ba": ba,
    "by": by,
}

a_next, yt_pred, cache = rnn_cell_forward(xt, a_prev, parameters)

print("a_next = ", a_next)
print("a_next.shape = ", a_next.shape)
print("yt_pred[1] =", yt_pred[1])
print("yt_pred.shape = ", yt_pred.shape)