pytorch实现LSTM回归代码分享

最近正在学习RNN相关的知识,并尝试使用LSTM网络实现回归分析。刚刚接触RNN相关网络的上手难度比较大,首先从CSDN上寻找相关的代码并没有找到比较满意的。这几天终于把LSTM相关网络调试通过现在把我的代码及数据集开源,供大家学习参考。

LSTM回归算法代码分享

LSTM简介

参考相关博客链接: LSTM这一篇就够了.,在这里不再介绍相关的理论知识。我对这个网络的理解:如果某个信号的时间相关性强,那么RNN相关网络的训练效果应该会比较好。

数据集介绍

数据集来源于相关课题,课题的内容不便透漏,在这里展示数据集的相关图像数据集图像展示
数据集的横坐标表示时间,纵坐标表示振幅。可以看出数据的波动性很大,并且根据实际的物理背景很容易得知第n+1时刻的输出与第n时刻的输出之间有着密切联系,故选定LSTM这个长短时间记忆网络。

前期经过DNN网络训练过,DNN的网络结构为一个输入层(1 Net),三个隐藏层(10* 50 *10 Net),一个输出层(1 Net)。网络结构及训练结果如下图所示:
网络结构与训练结果
左图展示的为DNN网络结构,右图展示数据拟合效果。蓝色代表实际数据,红色线条代表对数据的拟合情况。可以看出红色线条对蓝色的数据点能够大体拟合但是丢掉了很多震荡数据。这个可能是由于调参不当或者训练次数不够导致。通过调参应该也能达到更好的一个结果。

使用LSTM网络也进行相关的实验,搭建了一个很简单的单输入单输出网络,训练结果如图所示:
RNN训练结果
刚开始的时候会有些许振荡现象,后来数据能够基本符合图片一展示的数据集。数据集图片一与训练出的结果对比图展示如下:

将第三张图片部分截取进行对比,可以看出,经过训练的LSTM网络能够很好的预测实验数据。保留了数据中的波动情况。

代码展示

"""
user:liujie
time:2020.10.07
"""
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
#引入相关的文件及数据集,数据集的数据来源于csv文件
import pandas as pd                             #导入pandas包
#csv文件读取
data = pd.read_csv("patientdata.csv")
#读取的文件首先进行列表化并转置。随后转存为float64的格式,默认格式为flaot32
data = np.transpose(np.array(data)).astype(np.float32)
x_data = data[0, :3000]                         #数据切片,x_data表示自变量
y_data = data[1, :3000]                         #数据切片,y_data表示因变量
# 设置超参数
input_size = 1                                  #定义超参数输入层,输入数据为1维
output_size = 1                                 #定义超参数输出层,输出数据为1维
num_layers = 1                                  #定义超参数rnn的层数,层数为1层
hidden_size = 32                                #定义超参数rnn的循环神经元个数,个数为32个
learning_rate = 0.02                            #定义超参数学习率
train_step = 1000                                #定义训练的批次,3000个数据共训练1000次,
time_step = 3                                  #定义每次训练的样本个数每次传入3个样本
h_state = None                                  #初始化隐藏层状态
use_gpu = torch.cuda.is_available()             #使用GPU加速训练
class RNN(nn.Module):
    """搭建rnn网络"""
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(RNN, self).__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,)                  #传入四个参数,这四个参数是rnn()函数中必须要有的
        self.output_layer = nn.Linear(in_features=hidden_size, out_features=output_size)
    def forward(self, x, h_state):
        # x (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # rnn_out (batch, time_step, hidden_size)
        rnn_out, h_state = self.rnn(x, h_state)     #h_state是之前的隐层状态
        out = []
        for time in range(rnn_out.size(1)):
            every_time_out = rnn_out[:, time, :]    #相当于获取每个时间点上的输出,然后过输出层
            out.append(self.output_layer(every_time_out))
        return torch.stack(out, dim=1), h_state     #torch.stack扩成[1, output_size, 1]
# 显示由csv提供的样本数据图
plt.figure(1)
plt.plot(x_data, y_data, 'r-', label='target (Ca)')
plt.legend(loc='best')
plt.show()
#对CLASS RNN进行实例化时向其中传入四个参数
rnn = RNN(input_size, hidden_size, num_layers, output_size)
# 设置优化器和损失函数
#使用adam优化器进行优化,输入待优化参数rnn.parameters,优化学习率为learning_rate
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
loss_function = nn.MSELoss()                #损失函数设为常用的MES均方根误差函数
plt.figure(2)                               #新建一张空白图片2
plt.ion()
# 按照以下的过程进行参数的训练
for step in range(train_step):
    start, end = step*time_step, (step+1)*time_step#
    steps = np.linspace(start, end, (end-start), dtype=np.float32)#该参数仅仅用于画图过程中使用
    x_np = x_data[start:end]        #按照批次大小从样本中切片出若干个数据,用作RNN网络的输入
    y_np = y_data[start:end]        #按照批次大小从样本中切片出若干个数据,用作与神经网络训练的结果对比求取损失
    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
    pridect, h_state = rnn.forward(x, h_state)
    h_state = h_state.detach()     # 重要!!! 需要将该时刻隐藏层的状态作为下一时刻rnn的输入

    loss = loss_function(pridect, y)#求解损失值,该损失值用于后续参数的优化
    optimizer.zero_grad()           #优化器的梯度清零,这一步必须要做

    loss.backward()                 #调用反向传播网络对损失值求反向传播,优化该网络
    optimizer.step()                #调用优化器对rnn中所有有关参数进行优化处理

    plt.plot(steps, pridect.detach().numpy().flatten(), 'b-')
    plt.draw()
    plt.pause(0.05)
    plt.ioff()
    plt.show()



这个代码使用的python3.6+pytorch开发,进行了比较详细的注释,希望能够抛砖引玉,有问题还请大家不吝赐教

  • 9
    点赞
  • 89
    收藏
    觉得还不错? 一键收藏
  • 12
    评论
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值