基于深度学习的股票预测(完整版,有代码)


在这里插入图片描述

数据获取

采用tushare的数据接口(不知道tushare的筒子们自行百度一下,简而言之其免费提供各类金融数据 , 助力智能投资与创新型投资。)
python可以直接使用pip安装tushare

!pip install tushare

Collecting tushare
  Downloading https://files.pythonhosted.org/packages/17/76/dc6784a1c07ec040e748c8e552a92e8f4bdc9f3e0e42c65699efcfee032b/tushare-1.2.62-py3-none-any.whl (214kB)
     |████████████████████████████████| 215kB 6.5MB/s 
Collecting simplejson>=3.16.0
  Downloading https://files.pythonhosted.org/packages/a8/04/377418ac1e530ce2a196b54c6552c018fdf1fe776718053efb1f216bffcd/simplejson-3.17.2-cp37-cp37m-manylinux2010_x86_64.whl (128kB)
     |████████████████████████████████| 133kB 30.0MB/s 
Collecting websocket-client>=0.57.0
  Downloading https://files.pythonhosted.org/packages/85/ee/7aa724dc2dbed9b028f463eada5482770c13b7381a0c79457d12b3b62de2/websocket_client-1.0.1-py2.py3-none-any.whl (68kB)
     |████████████████████████████████| 71kB 9.0MB/s 
Requirement already satisfied: requests>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from tushare) (2.23.0)
Requirement already satisfied: lxml>=3.8.0 in /usr/local/lib/python3.7/dist-packages (from tushare) (4.2.6)
Requirement already satisfied: bs4>=0.0.1 in /usr/local/lib/python3.7/dist-packages (from tushare) (0.0.1)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2.0.0->tushare) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2.0.0->tushare) (2020.12.5)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2.0.0->tushare) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2.0.0->tushare) (1.24.3)
Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.7/dist-packages (from bs4>=0.0.1->tushare) (4.6.3)
Installing collected packages: simplejson, websocket-client, tushare
Successfully installed simplejson-3.17.2 tushare-1.2.62 websocket-client-1.0.1

在tushare平台进行账号注册,获取自己的接口token码

import tushare as ts
import pandas as pd
import os
import time
import glob

pro = ts.pro_api('your token')
# ----------------------下载某只股票数据------------------- #
# code:股票编码 日期格式:2019-05-21 filename:写到要存放数据的根目录即可如D:\data\
# length是筛选股票长度,默认值为False,既不做筛选,可人为指定长度,如200,既少于200天的股票不保存
def get_stock_data(code, date1, date2, filename, length=-1):
    df = pro.daily(ts_code=code, start_date=date1, end_date=date2)
    df1 = pd.DataFrame(df)
    df1 = df1[['trade_date','open', 'high', 'close', 'low', 'vol', 'pct_chg']]
    df1 = df1.sort_values(by='trade_date')
    print('共有%s天数据' % len(df1))
    if(len(df1)<length):
        path = code+ '.csv'
        df1.to_csv(os.path.join(filename, path))
# ------------------------更新股票数据------------------------ #
# 将股票数据从本地文件的最后日期更新至当日
# filename:具体到文件名如d:\data\000001.csv
def update_stock_data(filename):
    (filepath, tempfilename) = os.path.split(filename)
    (stock_code, extension) = os.path.splitext(tempfilename)
    f = open(filename, 'r')
    df = pd.read_csv(f)
    print('股票{}文件中的最新日期为:{}'.format(stock_code, df.iloc[-1, 1]))
    data_now = time.strftime('%Y%m%d', time.localtime(time.time()))
    print('更新日期至:%s' % data_now)
    nf = pro.daily(ts_code=stock_code, start_date=str(df.iloc[-1, 1]), end_date=data_now)
    nf = nf.sort_values(by='trade_date')
    nf = nf.iloc[1:]
    print('共有%s天数据' % len(nf))
    nf = pd.DataFrame(nf)
    nf = nf[['trade_date','open', 'high', 'close', 'low', 'vol', 'pct_chg']]
    nf.to_csv(filename, mode='a', header=False)
    f.close()
# ------------------------获取股票长度----------------------- #
# 辅助函数
def get_data_len(file_path):
	with open(file_path) as f:
		df = pd.read_csv(f)
		return len(df)
# --------------------------文件合并------------------------- #
# 将多个文件合并为一个文件,在文件末尾添加
# filename是需要合并的文件夹,tfile是存放合并后文件的文件夹
def merge_stock_data(filename, tfile):
	csv_list = glob.glob(filename + '*.csv')
	print(u'共发现%s个CSV文件' % len(csv_list))
	f = open(csv_list[0])
	df = pd.read_csv(f)
	for i in range(1, len(csv_list)):
		f1 = open(csv_list[i], 'rb')
		df1 = pd.read_csv(f1)
		df = pd.concat([df, df1])
	df.to_csv(tfile+'train_mix.csv', index=None)

通过循环的方式获取沪深股票编码从000001SZ-000999SZ的每日线行情
分别存入对应的csv文件中

for i in range(1,1000):
    name='{:0>6d}'.format(i)+'.SZ'
    get_stock_data(name, '20150101', '20210529', '/content/gdrive/MyDrive/代码/股票/data/原始数据//',200)

数据转换

# The data path is at:
name='000995.SZ'
PATH = '/content/gdrive/MyDrive/代码/股票/data/原始数据/'+name+'.csv'
STEP = 60

from torch.utils.data import Dataset, DataLoader

class StockData (Dataset):
    def __init__ (self, path:str, step:int = 30):
        self.path = path
        self.step = step

        data = pd.read_csv(path).values[1:,2:]
        print(len(data))

        self.len = len(data)

        self.y_max = data[:,3].max()
        self.y_min = data[:,3].min()

        # visualise the data
        data = self.normalise(data)+1e-5

        #print(data[:-1,:])
        self.X = torch.tensor(data[:-1,:].astype(np.float32))
        self.y = torch.tensor(data[1:,3].astype(np.float32))

        plt.plot(self.y)
        plt.show()


    def __getitem__ (self, index):
        return self.X[index:index+self.step], self.y[index+self.step]

    def __len__ (self):
        return self.len-1-self.step

    def normalise (self, data):
        data = data.T
        for i in range(len(data)):
            data_min = data[i].min()
            data_max = data[i].max()
            data[i] = (data[i] - data_min) / (data_max - data_min)
        return data.T


stock_data = StockData(path = PATH, step = STEP)
data = DataLoader(dataset = stock_data, batch_size = 5, shuffle = False)
print(len(data))

在这里插入图片描述

LSTM模型搭建

# now let us write a LSTM model
import torch.nn as nn
import torch.nn.functional as F

class Net (nn.Module):
	
	def __init__ (self, input_size = 6, hidden_size = 20, output_size = 1, layers = 3):
		super().__init__()
		self.lstm = nn.LSTM(input_size, hidden_size, layers, batch_first = True, bidirectional = True)
		self.linear = nn.Linear(hidden_size*2, output_size)
		self.function = torch.sigmoid

	def forward(self, X):
		X, hidden = self.lstm(X, None)
		X = X[:,-1,:]
		X = self.linear(X)
		X = self.function(X)
		return X

net = Net()

训练模型

取数据的前80%进行训练

# now Train the model
import torch.optim as op

# LR = 0.01
criteria = nn.MSELoss()
optimiser = op.Adam(net.parameters())
EPCHO = 10


for epcho in range(EPCHO):
	for i, Xy in enumerate (data):

		if i == len(data)*0.8: break

		X = Xy[0]
		y = Xy[1]

		predict = net(X)
		optimiser.zero_grad()
		loss = criteria(predict,y)
		loss.backward()
		optimiser.step()

	print('Epcho: {}.......... loss is {}'.format(epcho,loss))

torch.save(net, '/content/gdrive/MyDrive/代码/股票/model/stock_predict'+name+'.pkl')

预测结果

net = torch.load('/content/gdrive/MyDrive/代码/股票/model/stock_predict'+name+'.pkl')
predict = np.array([])
actual = np.array([])

torch.no_grad() 


for X,y in data:
	predict = np.append(predict, net(X).data[0,0])
	actual = np.append(actual, y.data[0])


plt.plot(predict, label = 'prediction')
plt.plot(actual, label = 'actual')
plt.vlines(len(data)*0.8,0,1,color="red")#竖线
plt.title('stock_predict step = '+ str(STEP))
save_name='/content/gdrive/MyDrive/代码/股票/picture/'+name+'_prediction.png'
plt.legend()
plt.savefig(save_name,dpi = 600)
plt.rcParams['figure.figsize'] = 20, 10
plt.show()

在这里插入图片描述

  • 18
    点赞
  • 292
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 16
    评论
股票预测是金融领域中的一个重要问题,其目的是通过分析历史市场数据来预测未来的股票价格趋势。传统的股票预测方法通常基于统计模型或机器学习算法,但它们通常不能处理非线性关系和高维数据,因此难以获得准确的预测结果。 随着深度学习的发展,越来越多的研究者开始使用深度学习技术来解决股票预测问题。深度学习是一种基于神经网络的机器学习方法,它可以处理非线性关系和高维数据,并且可以自动提取特征。因此,深度学习股票预测中具有广泛的应用前景。 目前,常用的深度学习模型包括循环神经网络(RNN)、长短期记忆网络(LSTM)、卷积神经网络(CNN)和深度置信网络(DBN)等。这些模型可以通过对历史市场数据进行训练来预测未来的股票价格趋势。 在进行深度学习股票预测时,需要注意以下几点: 1. 数据预处理:数据预处理是深度学习股票预测的重要步骤。在进行数据预处理时,需要对数据进行归一化、平滑处理和特征提取等操作,以便更好地为模型提供输入数据。 2. 模型选择:不同的深度学习模型适用于不同的股票预测问题。在选择模型时,需要考虑输入数据的特点、时间序列数据的长度和预测的时间跨度等因素。 3. 参数优化:深度学习模型中有很多参数需要进行优化,包括学习率、批量大小、迭代次数等。优化这些参数可以提高模型的预测性能。 4. 模型评估:模型评估是深度学习股票预测的最后一步。在评估模型时,可以使用交叉验证、均方误差、平均绝对误差等指标来评估模型的预测性能。 总之,深度学习股票预测中具有广泛的应用前景。通过对历史市场数据进行训练,深度学习模型可以预测未来的股票价格趋势,从而帮助投资者做出更加明智的投资决策。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

取个名字真难啊啊

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值