PyTorch_构建一个LSTM网络单元

今天用PyTorch参考《Python深度学习基于PyTorch》搭建了一个LSTM网络单元,在这里做一下笔记。

1.LSTM的原理

LSTM是RNN(循环神经网络)的变体,全名为长短期记忆网络(Long Short Term Memory networks)。
它的精髓在于引入了细胞状态这样一个概念,不同于RNN只考虑最近的状态,LSTM的细胞状态会决定哪些状态应该被留下来,哪些状态应该被遗忘。
具体与RNN的区别可参考这篇博文:LSTM与RNN的比较
先放一张LSTM网络的模型图:

在这里插入图片描述
如上图所示,可以看到这是一个网络,我们单拿出其中一个单元来进行分析,可见每一个单元都包含一系列运算,那么这些运算的意义是什么呢?下面我们来一一解释每个单元的具体内容。

(1)遗忘门
在这里插入图片描述
ht-1 :前一个时刻的Cell的输出
xt : 当前时刻的输入
注意:中括号的意思是将ht-1与xt拼接起来,后面出现公式同理

遗忘门主要来判断上一状态中的输出对现状态的影响大小,遗忘门的输出要通过一个Sigmoid函数,Sigmoid函数的输出范围是0~1,相当于得到一个权重,后面与Ct-1相乘,以此得到上一状态输出对现状态的影响。

(2)输入门
在这里插入图片描述
输入门中会得到一个临界的细胞状态(Ct^),表示此状态下的备选输出,与it作用后就得到此次状态需要输出的内容。

在这里插入图片描述
由以上两个门就可以输出更新后的细胞状态Ct,输出公式如上图所示,需要注意这里的“ * ”符号为哈达玛乘积,就是对应矩阵元素相乘。

(3)输出门
在这里插入图片描述
输出门具体运算过程如上图所示。这样就得到了这个时刻的输出,把这个输出再传入下一状态即可。

2.代码实现

初始化:

import torch
import torch.nn as nn

搭建一个LSTM单元:

class LSTMCell(nn.Module):
    def __init__(self,input_size,hidden_size,cell_size,output_size):
        super(LSTMCell,self).__init__()
        self.hidden_size = hidden_size
        self.cell_size = cell_size
        #设定门输入输出数据的大小尺寸
        self.gate = nn.Linear(input_size+hidden_size,cell_size)
        self.output = nn.Linear(hidden_size,output_size)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        #分类器-输出
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self,input,hidden,cell):
        #拼接数据,后置的0/1 确定横向(1)还是竖向(0)拼接 
        combined = torch.cat((input,hidden),1)
        #根据LSTM一个单元的网络图得出三个门,并进行运算
        f_gate = self.sigmoid(self.gate(combined))
        i_gate = self.sigmoid(self.gate(combined))
        #z_state看作为Cell的中间状态
        z_state = self.tanh(self.gate(combined))
        o_gate = self.sigmoid(self.gate(combined))
        #注意这下面的乘为哈达玛乘积,矩阵对应元素相乘
        cell = torch.add(torch.mul(f_gate,cell),torch.mul(i_gate,z_state))
        hidden = torch.mul(self.tanh(cell),o_gate)
        output = self.output(hidden)
        output = self.softmax(output)
        return output,hidden,cell
    
    def initHidden(self):
        return torch.zeros(1,self.hidden_size)
    
    def initCell(self):
        return torch.zeros(1,self.cell_size)

实例化LSTMCell,并传入输入、隐含状态等进行验证:

lstmcell = LSTMCell(input_size=10,hidden_size=20,cell_size=20,output_size=10)
input = torch.randn(32,10)
h_0 = torch.randn(32,20)
c_0 = torch.randn(32,20)
output,hn,cn = lstmcell(input,h_0,c_0)
print(output.size(),hn.size(),cn.size())

输出结果:
torch.Size([32, 10]) torch.Size([32, 20]) torch.Size([32, 20])

end
(以上图片来源于网络,若侵权请联系删除)

  • 7
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值