手写RNN

手写RNN

import numpy as np
from utils import *

def rnn_cell_forward(x_t,s_prev,param):
    U=param['U']
    W = param['W']
    V = param['V']
    ba = param['ba']
    by = param['by']
    s_next=np.tanh(np.dot(U,x_t)+np.dot(W,s_prev)+ba)
    out_pred=softmax(np.dot(V,s_next)+by)
    cache=(s_next,s_prev,x_t,param)
    return s_next,out_pred,cache
def rnn_forward(x,s0,param):
    m,_,T=x.shape
    m,n=param["V"].shape
    s_next=s0
    caches = []
    s=np.zeros((n,1,T))
    y=np.zeros((m,1,T))
    for t in range(T):
        s_next,out_pred,cache=rnn_cell_forward(x[:,:,t],s_next,param)
        s[:,:,t]=s_next
        y[:,:,t]=out_pred
        caches.append(cache)
    return s,y,caches
def rnn_cell_backward(ds_next,cache):
    (s_next,s_prev,x_t,param)=cache
    U = param['U']
    W = param['W']
    V = param['V']
    ba = param['ba']
    by = param['by']
    dtanh=(1-s_next**2)*ds_next
    dU=np.dot(dtanh,x_t.T)
    dW=np.dot(dtanh,s_prev.T)
    dba=np.sum(dtanh,axis=1,keepdims=1)
    dx_t=np.dot(U.T,dtanh)
    ds_prev=np.dot(W.T,dtanh)
    ans={"dtanh":dtanh,"dU":dU,"dW":dW,"dba":dba,"dx_t":dx_t,"ds_prev":ds_prev}
    return ans
def rnn_backward(ds,caches):
    (s1,s0,x_1,param)=caches[0]
    n,_,T=ds.shape
    m,_=x_1.shape
    dU=np.zeros((n,m))
    dW=np.zeros((n,n))
    dba=np.zeros((n,1))
    ds_prevt=np.zeros((n,1))
    dx=np.zeros((m,1,T))
    for t in reversed(range(T)):
        gradients=rnn_cell_backward(ds[:,:,t]+ds_prevt,caches[t])
        ds_prevt=gradients["ds_prev"]
        dU+=gradients["dU"]
        dW+=gradients["dW"]
        dba+=gradients["dba"]
        dx[:,:,t]=gradients['dx_t']
    gradients={"dU":dU,"dW":dW,"dba":dba,"dx":dx}
    return gradients


np.random.seed(1)
n=6
m=2
T=7
x=np.random.randn(m,1,T)
s0=np.random.randn(n,1)
W=np.random.randn(n,n)
U=np.random.randn(n,m)
V=np.random.randn(m,n)
ba=np.random.randn(n,1)
by=np.random.randn(m,1)
param={"U":U,"W":W,"V":V,"ba":ba,"by":by}
s,y,caches=rnn_forward(x,s0,param)
ds=np.random.randn(n,1,T)
gradients=rnn_backward(ds,caches)
print(gradients)
print("s==",s)
print("s.shape==",s.shape)
print("y==",y)
print("y.shape==",y.shape)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值