股价预测-LSTM网络

# coding=utf-8

import os
import time
import math
import numpy as np
import pandas as pd
import tushare as ts
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dropout,Dense,LSTM
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error,mean_absolute_error

INTnpseed=19740425

################################################################################
### 爬取指定股票从指定时间至当前最新的K线数据,输入股票代码、K线种类、开始时间、结束时间,保存到文件,无输出
### 保存到csv文件的数据实例:
###  ,      date, open,close, high,  low,   volume,  code
### 0,2010-01-04, 7.43,7.538,7.707,7.345, 66162.85,600888
### 1,2010-01-05,7.604,8.291,8.291, 7.58,227146.79,600888
### 2,2010-01-06,8.437,9.119,9.119,8.375,318151.07,600888
################################################################################
def FUNCgetstockdata(STRcode,STRtype,STRbegindate,STRenddate):
    FILEhandle=ts.get_k_data(STRcode,ktype=STRtype,start=STRbegindate,end=STRenddate)
    FILEname="./"+STRcode+".csv"
    FILEhandle.to_csv(FILEname)

################################################################################
### 获取当天日期的字符串,无输入,返回10位长字符串
################################################################################
def FUNCgetyyyymmdd():
    STRtoday=time.strftime("%Y%m%d",time.gmtime())                              # 获取当前年月日,如:20200512
    STRyyyymmdd=STRtoday[0:4]+"-"+STRtoday[4:6]+"-"+STRtoday[6:]                # 构造日期字符串,如:2020-05-12
    return STRyyyymmdd

################################################################################
### 从文件读取指定股票的K线数据,输入股票代码,输出列表数据
### 循环神经网络,输入训练数据集,测试数据集
################################################################################
def FUNClstmnet(STRcode):
    FUNCgetstockdata(STRcode,"D","2000-01-01",FUNCgetyyyymmdd())                # 爬取数据
    FILEname="./"+STRcode+".csv"
    STOCKLIST=pd.read_csv(FILEname)                                             # 读取数据
    train_set=STOCKLIST.iloc[0:int(len(STOCKLIST)*0.7),2:3]                     # 取前70%记录的第2列,即开盘价
    test_set =STOCKLIST.iloc[int(len(STOCKLIST)*0.7): ,2:3]                     # 取后30%记录的第2列,即开盘价
    sc=MinMaxScaler(feature_range=(0,1))                                        # 进行归一化处理
    ### fit(): Method calculates the parameters μ and σ and saves them as internal objects.
    ### 解释:简单来说,就是求得训练集X的均值,方差,最大值,最小值,这些训练集X固有的属性。
    ### transform(): Method using these calculated parameters apply the transformation to a particular dataset.
    ### 解释:在fit的基础上,进行标准化,降维,归一化等操作(看具体用的是哪个工具,如PCA,StandardScaler等)。
    ### fit_transform(): joins the fit() and transform() method for transformation of dataset.
    ### 解释:fit_transform是fit和transform的组合,既包括了训练又包含了转换。
    ### transform()和fit_transform()二者的功能都是对数据进行某种统一处理(比如标准化~N(0,1),将数据缩放(映射)到某个固定区间,归一化,正则化等)
    ### fit_transform(trainData)对部分数据先拟合fit,找到该part的整体指标,如均值、方差、最大值最小值等等(根据具体转换的目的),然后对该trainData进行转换transform,从而实现数据的标准化、归一化等等。
    ### transform()和fit_transform()的运行结果是一样的,运行结果一模一样不代表这两个函数可以互相替换,绝对不可以
    ### transform函数是一定可以替换为fit_transform函数的,fit_transform函数不能替换为transform函数
    train_set_scaled=sc.fit_transform(train_set)
    ###         >>> train_set       >>> train_set_scaled
    ###                  open         
    ###           0     1.547         array([[0.02520901],
    ###           1     1.678                [0.03356947],
    ###           2     1.607                [0.02903823],
    ###           ...     ...                ...,         
    ###           3356  4.105                [0.18846129],
    ###           3357  4.203                [0.19471568],
    ###           3358  4.128                [0.18992916]])
    test_set        =sc.transform(test_set)
    ###          >>> test_set        >>> test_set
    ###                  open
    ###           3359  4.270         array([[0.19899164],
    ###           3360  4.233                [0.19663029],
    ###           3361  4.113                [0.18897186],
    ###           ...     ...                ...,
    ###           4796  4.510                [0.21430851],
    ###           4797  4.510                [0.21430851],
    ###           4798  4.670                [0.22451975]]) 
    x_train=[]
    y_train=[]
    x_test =[]
    y_test =[]
    for i in range(60,len(train_set_scaled)):
        x_train.append(train_set_scaled[i-60:i,0])
        y_train.append(train_set_scaled[i,0])
    np.random.seed(INTnpseed)
    np.random.shuffle(x_train)
    np.random.seed(INTnpseed)
    np.random.shuffle(y_train)
    np.random.seed(INTnpseed)
    x_train,y_train=np.array(x_train),np.array(y_train)
    x_train=np.reshape(x_train,(x_train.shape[0],60,1))
    for i in range(60,len(test_set)):
        x_test.append(test_set[i-60:i,0])
        y_test.append(test_set[i,0])
    x_test,y_test=np.array(x_test),np.array(y_test)
    x_test=np.reshape(x_test,(x_test.shape[0],60,1))  

    model=tf.keras.Sequential([
        LSTM(80,return_sequences=True),
        Dropout(0.2),
        LSTM(100),
        Dropout(0.2),
        Dense(1)])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.01),
        loss="mean_squared_error",
        metrics=["sparse_categorical_accuracy"])
    checkpoint_save_path="./checkpoint.stock.lstm/stock.lstm.ckpt"
    if os.path.exists(checkpoint_save_path+".index"):
        print("--------------------加载模型--------------------")
        model.load_weights(checkpoint_save_path)
    cp_callback=tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_save_path,
        save_weights_only=True,
        save_best_only=True,
        monitor="val_loss")
    history=model.fit(
        x_train,
        y_train,
        batch_size=128,
        epochs=100,
        validation_data=(x_test,y_test),
        validation_freq=1,
        callbacks=[cp_callback])
    model.summary()
    print(model.trainable_variables)
    file=open('./weights.stock.lstm.txt', 'w')
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')
    file.close()

    ### 显示训练集和验证集的acc和loss曲线
    acc     =history.history['sparse_categorical_accuracy']
    loss    =history.history['loss']
    val_loss=history.history['val_loss']

    plt.rcParams['font.sans-serif'] = ['simhei']

    plt.plot(acc,     label=u'训练准确率')
    plt.plot(loss,    label=u'训练错误率')
    plt.plot(val_loss,label=u'验证错误率')
    plt.title(u'LSTM网络训练和验证错误率')
    plt.legend()
    plt.show()
    
    predicted_stock_price=model.predict(x_test)
    predicted_stock_price=sc.inverse_transform(predicted_stock_price)           # 反归一化
    real_stock_price=sc.inverse_transform(test_set[60:])

    plt.plot(real_stock_price,color="red",label=u"实际股价")
    plt.plot(predicted_stock_price,color="blue",label=u"预测股价")
    plt.title(u"LSTM网络股价预测")
    plt.xlabel(u"时间")
    plt.ylabel(u"股价")
    plt.legend()
    plt.show()

    mse =mean_squared_error(predicted_stock_price,real_stock_price)
    rmse=math.sqrt(mean_squared_error(predicted_stock_price,real_stock_price))
    mae =mean_absolute_error(predicted_stock_price,real_stock_price)
    print("均方误差:{:>0.6f},均方根误差:{:>0.6f},平均绝对误差:{:>0.6f}".format(mse,rmse,mae))

################################################################################
### 使用已训练好的循环神经网络,输入前60天股票价格,预测最新估价
################################################################################
def FUNCuselstmnet(STRcode):
    FUNCgetstockdata(STRcode,"D","2000-01-01",FUNCgetyyyymmdd())                # 爬取数据
    FILEname="./"+STRcode+".csv"
    STOCKLIST=pd.read_csv(FILEname)                                             # 读取数据
    train_set=STOCKLIST.iloc[len(STOCKLIST)-60:,2:3]                            # 取前60天开盘价
    sc=MinMaxScaler(feature_range=(0,1))                                        # 进行归一化处理
    train_set_scaled=sc.fit_transform(train_set)
    x_train=[]
    for i in range(60,len(train_set_scaled)+1):
        x_train.append(train_set_scaled[i-60:i,0])
    x_train=np.array(x_train)
    x_train=np.reshape(x_train,(x_train.shape[0],60,1))
    model=tf.keras.Sequential([
        LSTM(80,return_sequences=True),
        Dropout(0.2),
        LSTM(100),
        Dropout(0.2),
        Dense(1)])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(0.01),
        loss="mean_squared_error",
        metrics=["sparse_categorical_accuracy"])
    checkpoint_save_path="./checkpoint.stock.lstm/stock.lstm.ckpt"
    if os.path.exists(checkpoint_save_path+".index"):
        print("--------------------加载模型--------------------")
        model.load_weights(checkpoint_save_path)
    cp_callback=tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_save_path,
        save_weights_only=True,
        save_best_only=True,
        monitor="val_loss")
    predicted_stock_price=model.predict(x_train)
    predicted_stock_price=sc.inverse_transform(predicted_stock_price)           # 反归一化
    print("LSTM网络预测"+STRcode+"股价为:"+str(predicted_stock_price))
    time.sleep(5)

while True:
    os.system("clear")
    print("股票预测程序LSTM网络>>>>>>>>")
    print("        L600888............训练600888新疆众和数据")
    print("        P600888............预测600888新疆众和估价")
    print("        L600289............训练600289亿阳信通数据")
    print("        P600289............预测600289亿阳信通估价")
    print("        L002208............训练002208合肥城建数据")
    print("        P002208............预测002208合肥城建估价")
    print("        Quit...............退出系统")
    STRinput=input("        >>>>>>>>请输入选择项:")
    STRinput=STRinput.upper()                                                   # 将输入项转换为大写
    if STRinput=="L600888":
        FUNClstmnet("600888")
    elif STRinput=="P600888":
        FUNCuselstmnet("600888")
    elif STRinput=="L600289":
        FUNClstmnet("600289")
    elif STRinput=="P600289":
        FUNCuselstmnet("600289")
    elif STRinput=="L002208":
        FUNClstmnet("002208")
    elif STRinput=="P002208":
        FUNCuselstmnet("002208")
    elif STRinput=="QUIT":
        break
    else:
        continue


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值