python实现单层单向LSTM

用python的基本语法以及pytorch的一些数据结构实现LSTM,并用torch.nn.LSTM对我自己实现的网络的正确性进行验证,参考自https://www.bilibili.com/video/BV1jm4y1Q7uh/?spm_id_from=333.788

bs,T,i_size,h_size=2,3,4,5

input=torch.rand([bs,T,i_size])

h_0=torch.rand(bs,h_size)
c_0=torch.rand(bs,h_size)

# 用官方的接口运行一下
lstm=nn.LSTM(i_size,h_size,batch_first=True)
output,(h_n,c_n)=lstm(input,(h_0.unsqueeze(0),c_0.unsqueeze(0)))
print(output,(h_n,c_n))

def lstm_forward(input,initial_states,w_ih,w_hh,b_ih,b_hh):
    '''
        input.shape=bs*T*i_size
    '''
    h_0,c_0=initial_states
    bs,T,i_size=input.shape
    h_size=h_0.shape[-1]
    output_size=h_size

    prev_h=h_0 # shape=bs*hidden_size
    prev_c=c_0
    output=torch.zeros(bs,T,output_size)
    batch_w_ih=w_ih.unsqueeze(0).tile(bs,1,1) # shape=bs*(4*h_size)*i_size
    batch_w_hh=w_hh.unsqueeze(0).tile(bs,1,1)

    for t in range(T):
        x=input[:,t,:] # shape=bs*i_size
        w_times_x=torch.bmm(batch_w_ih,x.unsqueeze(-1)).squeeze(-1)
        w_times_h=torch.bmm(batch_w_hh,prev_h.unsqueeze(-1)).squeeze(-1)
        input_gate=torch.sigmoid(w_times_x[:,:h_size]+b_ih[:h_size]+w_times_h[:,:h_size]+b_hh[:h_size])
        forget_gate=torch.sigmoid(w_times_x[:,h_size:2*h_size]+b_ih[h_size:2*h_size]\
            +w_times_h[:,h_size:2*h_size]+b_hh[h_size:2*h_size])
        cell_gate=torch.tanh(w_times_x[:,2*h_size:3*h_size]+b_ih[2*h_size:3*h_size]\
            +w_times_h[:,2*h_size:3*h_size]+b_hh[2*h_size:3*h_size])
        output_gate=torch.sigmoid(w_times_x[:,3*h_size:]+b_ih[3*h_size:]+w_times_h[:,3*h_size:]+b_hh[3*h_size:])
        prev_c=forget_gate*prev_c+input_gate*cell_gate
        prev_h=output_gate*torch.tanh(prev_c)
        output[:,t,:]=prev_h

    return output,(prev_h,prev_c)    

output_custom,(h_n_custom,c_n_custom)=lstm_forward(input,(h_0,c_0),lstm.weight_ih_l0,lstm.weight_hh_l0,lstm.bias_ih_l0,lstm.bias_hh_l0)
# 这个allclose函数是用来判断两个tensor是否在可接受的误差范围内数值一致
print(torch.allclose(output,output_custom))
print(torch.allclose(h_n,h_n_custom))
print(torch.allclose(c_n,c_n_custom))

输出:
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值