pytorch下从头搭建LSTM模型(全网最简洁,非第三方封装LSTM)

前不久项目使用到了一个LSTM模型,让GPT写了一个结果用的是pytorch的封装好的模型,网上大多数博客也是这样为了博取流量这样搞出来的模型完全用不了,所以我根据提出LSTM模型的Understanding LSTM Networks这篇文章一步一步搭建一个LSTM模型,文章链接:Understanding LSTM Networks -- colah's blog

注意,本篇博客只适合对LSTM模型有基础了解的同学,不了解的先看Understanding LSTM Networks文章,我的代码完全基于该文章写的,所以一定要先看这篇文章,这非常重要!!!

接下来直接给代码:

class RiceLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RiceLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.Wf = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wi = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wo = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wc = nn.Linear(input_size + hidden_size, hidden_size)#12 -> 10(以下标注均为隐藏层大小10,特征向量2的情况)
        self.output_layer = nn.Linear(hidden_size, output_size)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
    def forward(self, input):
        batch_size = input.size(0)
        seq_len = input.size(1)
        hidden_state = torch.zeros(batch_size, self.hidden_size, dtype=torch.float32)#(1,10)
        cell_state = torch.zeros(batch_size, self.hidden_size, dtype=torch.float32)
        outputs = []

        for i in range(seq_len):
            combined = torch.cat((input[:, i, :], hidden_state), dim=1)#(1,2)(时间步,特征)+(1,10)=(1,12)
            f_t = self.sigmoid(self.Wf(combined))#(1,10)
            i_t = self.sigmoid(self.Wi(combined))
            o_t = self.sigmoid(self.Wo(combined))
            c_hat_t = self.tanh(self.Wc(combined))
            # cell_state = f_t * cell_state + (1-f_t) * c_hat_t#(1,10)
            cell_state = f_t * cell_state + i_t * c_hat_t  # (1,10)
            hidden_state = o_t * self.tanh(cell_state)#(1,10)
            outputs.append(hidden_state.unsqueeze(1))#(1,+1,10)

        outputs = torch.cat(outputs, dim=1)#80个outputs (1,80,10)
        final_output = self.output_layer(outputs)#(1,80,1)
        return final_output, (hidden_state, cell_state)#hidden_state、cell_state大小没变

上述模型是一个最简单的LSTM模型,代码看上去头大是吧,首先要看懂代码结构(看不懂的去看我反复提及的那篇文章),然后特征向量是影响因素,我对水稻的产量进行了一个预测,开始只用了两个影响因素温度和湿度测试一下,这两个就是特征向量,他们各随机产生100个数据,二八划分数据集,所以后面维度有80,所有的步骤上面我已经给好注释了,看不懂的多练,注释括号里面的是张量维度,上面注释掉的“cell_state = f_t * cell_state + (1-f_t) * c_hat_t”是Understanding LSTM Networks文章中的变体模型,叫什么窥视孔连接?但是我的项目实测效果没有区别,有兴趣的同学可以看看,强烈建议大家去看看Understanding LSTM Networks这篇文章,网上所有关于LSTM的经典图片都出自这里。文章到此完事,下面是一下效果展示(模型不存在任何问题,完全是根据LSTM标准实现的,问题是我的数据集太少了)

最后给一下效果图(这个是7个特征向量下对1个输出结果的预测效果,由于我的数据集少,仅仅150条,特征向量又多,所以不太准确,这是神经网络的通病,但是比其他封装好的LSTM强太多了,如果数据集少的情况下减少特征向量,结果会准确很多,第一次用两个特征向量的结果我丢失了,至少有90%的准确度,不信可以自己试试):

这是代码生产的两个特征向量的计算图(后面的我也看不懂哈哈哈哈)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

焚詩作薪

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值