LSTM笔记

1、为什么用tanh,不用ReLU?(RNN)

避免梯度爆炸:tanh的输出被限制在(-1, 1)之间,而ReLU的输出没有上界,在循环结构中反复相乘容易导致数值和梯度爆炸。RNN的优点:当前时刻的输出同时利用先前的隐藏状态和当前的输入(先前+当前)。缺点:RNN难以实现长距离依赖。
在这里插入图片描述
2、#RNN LSTM LSTMP 的源码实现
import torch
from torch import nn

# Toy dimensions: batch size, sequence length, input feature size, hidden size.
bs, T, i_size, h_size = 2, 3, 4, 5

# NOTE: `input` shadows the builtin of the same name; the name is kept because
# later statements in this file refer to it.
input = torch.randn(bs, T, i_size)  # input sequence, sampled from a standard normal
c0 = torch.randn(bs, h_size)  # initial cell state
h0 = torch.randn(bs, h_size)  # initial hidden state

# Reference result from the official API. The initial states are given a
# leading axis with unsqueeze(0) because nn.LSTM takes them with an extra
# first dimension in front of (batch, hidden).
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))

# Print the name and shape of every parameter of the layer.
for k, v in lstm_layer.named_parameters():
    print(k, v.shape)

#自己写LSTM模型
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    """Run a hand-written single-layer LSTM over a batch-first sequence.

    Args:
        input: Tensor of shape (bs, T, i_size).
        initial_states: Tuple ``(h0, c0)``, each of shape (bs, h_size).
        w_ih: Input-to-hidden weights of shape (4 * h_size, i_size); the four
            row blocks are the input, forget, cell and output gates, in that
            order.
        w_hh: Hidden-to-hidden weights of shape (4 * h_size, h_size), with the
            same gate order as ``w_ih``.
        b_ih: Input-to-hidden bias of shape (4 * h_size,).
        b_hh: Hidden-to-hidden bias of shape (4 * h_size,).

    Returns:
        ``(output, (h_n, c_n))`` where ``output`` has shape (bs, T, h_size) and
        holds the hidden state of every time step, and ``h_n`` / ``c_n`` are
        the hidden and cell state after the last step, each of shape
        (bs, h_size). The states are returned hidden-first, which is the order
        in which the initial states are passed in.
    """
    h0, c0 = initial_states
    bs, T, _ = input.shape
    h_size = w_ih.shape[0] // 4

    prev_h = h0
    prev_c = c0

    # Allocate the output sequence with the dtype and device of the input.
    output = input.new_zeros(bs, T, h_size)

    for t in range(T):
        x = input[:, t, :]  # input vector of the current time step, (bs, i_size)

        # Pre-activations of all four gates at once, shape (bs, 4 * h_size).
        # A plain matmul against the shared weights replaces tiling the
        # weights per batch element and calling bmm.
        gates = x @ w_ih.t() + b_ih + prev_h @ w_hh.t() + b_hh

        # Split into the input (i), forget (f), cell (g) and output (o) gates.
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=1)
        i_t = torch.sigmoid(i_gate)
        f_t = torch.sigmoid(f_gate)
        g_t = torch.tanh(g_gate)
        o_t = torch.sigmoid(o_gate)

        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)

# Run the hand-written LSTM on the same input and initial states, using the
# weights and biases of the official layer.
output_custom, (h_final_custom, c_final_custom) = lstm_forward(
    input,
    (h0, c0),
    lstm_layer.weight_ih_l0,
    lstm_layer.weight_hh_l0,
    lstm_layer.bias_ih_l0,
    lstm_layer.bias_hh_l0,
)

print(output_custom)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值