LSTM/GRU详细代码解析+完整代码实现

本文详细介绍了LSTM和GRU两种门控循环神经网络的原理,并提供了PyTorch实现的代码示例。这两种模型常用于时间序列预测任务,如交通流量预测,通过结合历史数据提取结构和时序特征。文章还展示了基于GRU的完整代码以及训练过程中的损失曲线,最后给出了预测结果的对比图。
摘要由CSDN通过智能技术生成

LSTM和GRU目前被广泛的应用在各种预测场景中,并与卷积神经网络CNN或者图神经网络GCN这里等相结合,对数据的结构特征和时序特征进行提取,从而预测下一时刻的数据。在这里整理一下详细的LSTM/GRU的代码,并基于heatmap热力图实现对结果的展示。

一、GRU

GRU的公式如下图所示:

 其代码部分:

class GRU(torch.nn.Module):
    def __init__(self, hidden_size, output_size, num_layers):
        super().__init__()
        self.input_size = 1
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.num_directions = 1
        self.gru = torch.nn.GRU(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.linear = torch.nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input_seq):
        # input(batch_size, seq_len, input_size)
        batch_size, seq_len = input_seq.shape[0], input_seq.shape[1]
        h_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device)
        # output(batch_size, seq_len, num_directions * hidden_size)
        output, _ = self.gru(input_seq, (h_0))
        pred = self.linear(output)
        pred = pred[:, -1, :]
        return pred

这里主要对里面主要的五个参数进行介绍: 

 input_size:输入节点特征的维度。这里需要注意的是,如果你输入的是节点的交通流量数据,一般只使用一个值表示,那么你的input_size为1;若是想基于该节点在t时刻的多个特征,如:流量、速度、车辆数这三个指标对交通流量进行预测,这是input_size=3。

hidden_size

  • 27
    点赞
  • 273
    收藏
    觉得还不错? 一键收藏
  • 26
    评论
评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值