pytorch股票预测

#导包
import numpy as np
import torch
from torch.autograd import Variable
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

#声明随机种子
torch.manual_seed(777)

#读取数据
data=np.loadtxt('../data-02-stock_daily.csv',delimiter=',')

#顺序颠倒
data=data[::-1]

#归一化
min=MinMaxScaler()
data=min.fit_transform(data)

x=data
y=data[:,-1:]

# 7天为一组
length=7
num_class=5

data_x=[]
data_y=[]

for i in range(0,len(y)-length):
    x1=x[i:i+length]
    y1=y[i+length]
    data_x.append(x1)
    data_y.append(y1)

x_train,x_test,y_train,y_test=train_test_split(data_x,data_y,test_size=0.3,shuffle=False)

x_train=Variable(torch.Tensor(x_train).float())
y_train=Variable(torch.Tensor(y_train).float())
x_test=torch.Tensor(x_test).float()
y_test=torch.Tensor(y_test).float()

class LSTM(torch.nn.Module):
    def __init__(self, input_size):
        super(LSTM, self).__init__()
        self.h_size=input_size
        self.lstm=torch.nn.LSTM(input_size=self.h_size,hidden_size=self.h_size,num_layers=2,batch_first=True)
        self.fc=torch.nn.Linear(self.h_size,1)

    def forward(self,x):
        _,(h,_)=self.lstm(x)
        # 取lstm最后一层的最后一个时间步的输出,并展开成二维
        h=h[-1].view(-1,self.h_size)
        out=self.fc(h)
        return out

model=LSTM(num_class)

loss=torch.nn.MSELoss() # 股票预测是回归算法
optim=torch.optim.Adam(model.parameters(),lr=0.01)

for ei in range(200):
    optim.zero_grad()

    h=model(x_train)

    cost=loss(h,y_train)

    cost.backward()

    optim.step()

    if ei % 50 == 0:
        print(ei,'损失值:',cost.data.numpy())

#真实值与预测值画图
h_test=model(x_test)
plt.plot(h_test.data.numpy(),c='r')
plt.plot(y_test,c='b')
plt.show()
  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值