前不久项目使用到了一个LSTM模型,让GPT写了一个结果用的是pytorch的封装好的模型,网上大多数博客也是这样为了博取流量这样搞出来的模型完全用不了,所以我根据提出LSTM模型的Understanding LSTM Networks这篇文章一步一步搭建一个LSTM模型,文章链接:Understanding LSTM Networks -- colah's blog
注意,本篇博客只适合对LSTM模型有基础了解的同学,不了解的先看Understanding LSTM Networks文章,我的代码完全基于该文章写的,所以一定要先看这篇文章,这非常重要!!!
接下来直接给代码:
class RiceLSTM(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RiceLSTM, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.Wf = nn.Linear(input_size + hidden_size, hidden_size)
self.Wi = nn.Linear(input_size + hidden_size, hidden_size)
self.Wo = nn.Linear(input_size + hidden_size, hidden_size)
self.Wc = nn.Linear(input_size + hidden_size, hidden_size)#12 -> 10(以下标注均为隐藏层大小10,特征向量2的情况)
self.output_layer = nn.Linear(hidden_size, output_size)
self.tanh = nn.Tanh()
self.sigmoid = nn.Sigmoid()
def forward(self, input):
batch_size = input.size(0)
seq_len = input.size(1)
hidden_state = torch.zeros(batch_size, self.hidden_size, dtype=torch.float32)#(1,10)
cell_state = torch.zeros(batch_size, self.hidden_size, dtype=torch.float32)
outputs = []
for i in range(seq_len):
combined = torch.cat((input[:, i, :], hidden_state), dim=1)#(1,2)(时间步,特征)+(1,10)=(1,12)
f_t = self.sigmoid(self.Wf(combined))#(1,10)
i_t = self.sigmoid(self.Wi(combined))
o_t = self.sigmoid(self.Wo(combined))
c_hat_t = self.tanh(self.Wc(combined))
# cell_state = f_t * cell_state + (1-f_t) * c_hat_t#(1,10)
cell_state = f_t * cell_state + i_t * c_hat_t # (1,10)
hidden_state = o_t * self.tanh(cell_state)#(1,10)
outputs.append(hidden_state.unsqueeze(1))#(1,+1,10)
outputs = torch.cat(outputs, dim=1)#80个outputs (1,80,10)
final_output = self.output_layer(outputs)#(1,80,1)
return final_output, (hidden_state, cell_state)#hidden_state、cell_state大小没变
上述模型是一个最简单的LSTM模型,代码看上去头大是吧,首先要看懂代码结构(看不懂的去看我反复提及的那篇文章),然后特征向量是影响因素,我对水稻的产量进行了一个预测,开始只用了两个影响因素温度和湿度测试一下,这两个就是特征向量,他们各随机产生100个数据,二八划分数据集,所以后面维度有80,所有的步骤上面我已经给好注释了,看不懂的多练,注释括号里面的是张量维度,上面注释掉的“cell_state = f_t * cell_state + (1-f_t) * c_hat_t”是Understanding LSTM Networks文章中的变体模型,叫什么窥视孔连接?但是我的项目实测效果没有区别,有兴趣的同学可以看看,强烈建议大家去看看Understanding LSTM Networks这篇文章,网上所有关于LSTM的经典图片都出自这里。文章到此完事,下面是一下效果展示(模型不存在任何问题,完全是根据LSTM标准实现的,问题是我的数据集太少了)
最后给一下效果图(这个是7个特征向量下对1个输出结果的预测效果,由于我的数据集少,仅仅150条,特征向量又多,所以不太准确,这是神经网络的通病,但是比其他封装好的LSTM强太多了,如果数据集少的情况下减少特征向量,结果会准确很多,第一次用两个特征向量的结果我丢失了,至少有90%的准确度,不信可以自己试试):
这是代码生产的两个特征向量的计算图(后面的我也看不懂哈哈哈哈)