# Hand-written RNN (manual forward pass and backpropagation through time)
import numpy as np
from utils import *
def rnn_cell_forward(x_t, s_prev, param):
    """Run one forward step of a vanilla RNN cell.

    Parameters
    ----------
    x_t : input at time t, shape (m, 1)
    s_prev : previous hidden state, shape (n, 1)
    param : dict holding weights 'U' (n, m), 'W' (n, n), 'V' (m, n)
        and biases 'ba' (n, 1), 'by' (m, 1)

    Returns
    -------
    (s_next, out_pred, cache)
        The new hidden state, the softmax prediction, and the values
        saved for the backward pass.
    """
    # Hidden-state update: s_t = tanh(U x_t + W s_{t-1} + ba)
    s_next = np.tanh(param['U'] @ x_t + param['W'] @ s_prev + param['ba'])
    # Prediction: y_t = softmax(V s_t + by); softmax comes from utils.
    out_pred = softmax(param['V'] @ s_next + param['by'])
    # Everything the backward pass needs for this time step.
    cache = (s_next, s_prev, x_t, param)
    return s_next, out_pred, cache
def rnn_forward(x, s0, param):
    """Run the RNN forward over all T time steps.

    Parameters
    ----------
    x : inputs, shape (m, 1, T) — one (m, 1) column per time step
    s0 : initial hidden state, shape (n, 1)
    param : weight dict (see rnn_cell_forward)

    Returns
    -------
    (s, y, caches)
        s : all hidden states, shape (n, 1, T)
        y : all predictions, shape (m, 1, T)
        caches : per-step caches for the backward pass
    """
    T = x.shape[2]
    # V maps hidden (n) -> output (m), so its shape fixes both sizes.
    m, n = param["V"].shape
    s = np.zeros((n, 1, T))
    y = np.zeros((m, 1, T))
    caches = []
    s_next = s0
    for t in range(T):
        s_next, out_pred, cache = rnn_cell_forward(x[:, :, t], s_next, param)
        s[:, :, t] = s_next
        y[:, :, t] = out_pred
        caches.append(cache)
    return s, y, caches
def rnn_cell_backward(ds_next, cache):
    """Backward pass through one RNN cell (recurrent path only).

    Parameters
    ----------
    ds_next : gradient of the loss w.r.t. s_next, shape (n, 1)
    cache : (s_next, s_prev, x_t, param) saved by rnn_cell_forward

    Returns
    -------
    dict with gradients "dtanh", "dU", "dW", "dba", "dx_t", "ds_prev".

    NOTE: only the hidden-state path is propagated here; gradients for
    the output weights V/by are not computed by this function.
    """
    s_next, s_prev, x_t, param = cache
    # Only U and W are needed for this path (V and by were previously
    # loaded but never used — removed).
    U = param['U']
    W = param['W']
    # d/dz tanh(z) = 1 - tanh(z)^2, and s_next == tanh(z).
    dtanh = (1 - s_next ** 2) * ds_next
    dU = np.dot(dtanh, x_t.T)
    dW = np.dot(dtanh, s_prev.T)
    dba = np.sum(dtanh, axis=1, keepdims=True)
    dx_t = np.dot(U.T, dtanh)
    ds_prev = np.dot(W.T, dtanh)
    return {"dtanh": dtanh, "dU": dU, "dW": dW,
            "dba": dba, "dx_t": dx_t, "ds_prev": ds_prev}
def rnn_backward(ds, caches):
    """Backpropagation through time over all T steps.

    Parameters
    ----------
    ds : upstream gradients w.r.t. each hidden state, shape (n, 1, T)
    caches : list of per-step caches produced by rnn_forward

    Returns
    -------
    dict with accumulated "dU", "dW", "dba" and per-step input
    gradients "dx" of shape (m, 1, T).
    """
    # Only the first step's input is needed, to read the input size m
    # (the original unpacked s1/s0/param here and never used them).
    x_1 = caches[0][2]
    n, _, T = ds.shape
    m = x_1.shape[0]
    dU = np.zeros((n, m))
    dW = np.zeros((n, n))
    dba = np.zeros((n, 1))
    dx = np.zeros((m, 1, T))
    ds_prevt = np.zeros((n, 1))  # gradient carried back through the state
    for t in reversed(range(T)):
        # Total gradient at step t = upstream ds plus the recurrent
        # contribution flowing back from step t+1.
        gradients = rnn_cell_backward(ds[:, :, t] + ds_prevt, caches[t])
        ds_prevt = gradients["ds_prev"]
        dU += gradients["dU"]
        dW += gradients["dW"]
        dba += gradients["dba"]
        dx[:, :, t] = gradients['dx_t']
    return {"dU": dU, "dW": dW, "dba": dba, "dx": dx}
# Demo: exercise the forward and backward passes on seeded random data.
# NOTE: the order of np.random.randn calls below must not change, or the
# seeded values (and printed output) would differ.
np.random.seed(1)
n = 6   # hidden-state size
m = 2   # input size
T = 7   # number of time steps

x = np.random.randn(m, 1, T)
s0 = np.random.randn(n, 1)
W = np.random.randn(n, n)
U = np.random.randn(n, m)
V = np.random.randn(m, n)
ba = np.random.randn(n, 1)
by = np.random.randn(m, 1)
param = {"U": U, "W": W, "V": V, "ba": ba, "by": by}

s, y, caches = rnn_forward(x, s0, param)
ds = np.random.randn(n, 1, T)
gradients = rnn_backward(ds, caches)

print(gradients)
print("s==", s)
print("s.shape==", s.shape)
print("y==", y)
print("y.shape==", y.shape)