验证LSTM内部实现流程,加深对LSTM的印象

1、LSTM结构图(多层)

在这里插入图片描述

2、LSTM 单个cell结构

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3、底层计算验证

f ( f o r g e t ) = σ ( W i f x + b i f + W h f h + b h f ) 遗忘门 i ( i n p u t ) = σ ( W i i x + b i i + W h i h + b h i ) g = tanh ⁡ ( W i g x + b i g + W h g h + b h g ) 输入门 o ( o u t p u t ) = σ ( W i o x + b i o + W h o h + b h o ) 输出门 c ′ = f ∗ c + i ∗ g h ′ = o ∗ tanh ⁡ ( c ′ ) \begin{array}{ll} f_{(forget)} = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \quad 遗忘门\\ i_{(input)} = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \quad\quad\quad 输入门\\ o_{(output)} = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \quad 输出门\\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array} f(forget)=σ(Wifx+bif+Whfh+bhf)遗忘门i(input)=σ(Wiix+bii+Whih+bhi)g=tanh(Wigx+big+Whgh+bhg)输入门o(output)=σ(Wiox+bio+Whoh+bho)输出门c=fc+igh=otanh(c)

# 验证经过一个cell的计算
import torch
import torch.nn as nn

# 1. 设置特征
feature_size = 4
batch_size = 1
hidden_size = 10

x = torch.randn(batch_size, feature_size)

# 2. 利用torch自带的lstmcell计算一个节点的ht,ct
lstm = nn.LSTMCell(input_size=feature_size, hidden_size=hidden_size, bias=False)

h0 = torch.zeros(size=(batch_size, hidden_size))
c0 = torch.zeros(size=(batch_size, hidden_size))
ht, ct = lstm(x, (h0, c0))

print(f'调用LSTMCell模块计算{ht}')
print(f'调用LSTMCell模块计算{ct}')

# 3. 手动计算一个lstmcell输出ho,co
# 理论上lstm应该包含4个本次输入x权重矩阵(wii,wif,wig,wio)和4个上次输出权重矩阵(whi,whf,whg,who)共8个矩阵,但torch里面把4个进行了合并,简化计算
wih = lstm.weight_ih   # shape=(10*4,4)
whh = lstm.weight_hh   # shape=(10*4,10)

# 3.1 将上一步h与这一步x进行合并,后拆分成各个门的输入
ht_1 = torch.mm(input=h0, mat2=torch.t(whh))
xt = torch.mm(input=x, mat2=torch.t(wih))
hx = torch.add(ht_1, xt).reshape(-1, hidden_size)
i, f, g, o = hx[0], hx[1], hx[2], hx[3]
# 3.2 忘记门计算
c1 = torch.multiply(input=c0, other=torch.sigmoid(f))
# 3.2 输入门计算
c2 = torch.add(input=c1, other=torch.multiply(input=torch.sigmoid(i), other=torch.tanh(g)))
# 3.4 输出门计算
co = c2
ho = torch.multiply(input=torch.tanh(c2), other=torch.sigmoid(o))

print(f'手动根据结构图计算{ho}')
print(f'手动根据结构图计算{co}')


'''
调用LSTMCell模块计算tensor([[-0.0115, -0.0040,  0.0376, -0.0131,  0.0128, -0.0104,  0.0382, -0.0359,
         -0.0498,  0.0463]], grad_fn=<MulBackward0>)
调用LSTMCell模块计算tensor([[-0.0244, -0.0077,  0.0661, -0.0289,  0.0255, -0.0201,  0.0829, -0.0789,
         -0.1051,  0.0888]], grad_fn=<AddBackward0>)
         
         
手动根据结构图计算tensor([[-0.0115, -0.0040,  0.0376, -0.0131,  0.0128, -0.0104,  0.0382, -0.0359,
         -0.0498,  0.0463]], grad_fn=<MulBackward0>)
手动根据结构图计算tensor([[-0.0244, -0.0077,  0.0661, -0.0289,  0.0255, -0.0201,  0.0829, -0.0789,
         -0.1051,  0.0888]], grad_fn=<AddBackward0>)

'''
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值