基于 LSTM 的股票收盘价预测(Python)

最初导入 MinMaxScaler 时会报错 “from . import _arpack ImportError: DLL load failed: 找不到指定的程序。”,更新 sklearn 后该错误消失;随后又出现 “AttributeError: module 'numpy' has no attribute 'testing'”,将 numpy 卸载后重新安装(先执行 pip uninstall numpy,再执行 pip install numpy)即可解决。

#import datetime

import pandas as pd

import numpy as np

#from numpy import row_stack,column_stack

import tushare as ts

#import matplotlib

import matplotlib.pyplot as plt

#from matplotlib.pylab import date2num

#from matplotlib.dates import DateFormatter, WeekdayLocator, DayLocator, MONDAY,YEARLY

#from matplotlib.finance import quotes_historical_yahoo_ohlc, candlestick_ohlc

from sklearn.preprocessing import MinMaxScaler

#https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#sphx-glr-auto-examples-preprocessing-plot-all-scaling-py

from keras.models import Sequential

from keras.layers import LSTM, Dense, Activation

# Fetch historical quotes for ticker '601857' between the two dates via tushare.
df=ts.get_hist_data('601857',start='2016-06-15',end='2018-01-12')

# Keep only the four price columns.
dd=df[['open','high','low','close']]

#print(dd.values.shape[0])

# Sort rows by index (presumably trade dates, i.e. oldest first -- confirm
# against what tushare returns).
dd1=dd .sort_index()

# Flattened 1-D copy of all four price columns; not referenced again in the
# visible code.
dd2=dd1.values.flatten()

# Single-column frame of closing prices; this is what load_data() is given
# in the __main__ block below.
dd3=pd.DataFrame(dd1['close'])

def load_data(df, sequence_length=10, split=0.8):
    """Scale a series and cut it into sliding-window train/test sets.

    Every sample is a window of ``sequence_length`` consecutive scaled rows
    (the model input) followed by the next scaled row (the target).  Samples
    keep their original order; the first ``split`` fraction of them becomes
    the training set and the remainder the test set.

    Args:
        df: 2-D array-like of shape (n_rows, n_features), for example a
            one-column DataFrame of closing prices.
        sequence_length: Number of rows in each input window.
        split: Fraction of the samples assigned to the training set.

    Returns:
        A tuple ``(train_x, train_y, test_x, test_y, scaler)``.  The x arrays
        have shape (n_samples, sequence_length, n_features), the y arrays
        have shape (n_samples, n_features), and ``scaler`` is the fitted
        ``MinMaxScaler`` needed to map values back to the original scale.

    Raises:
        ValueError: If ``df`` has too few rows to build a single sample.
    """
    data_all = np.array(df).astype(float)

    # One window holds the inputs plus the one-step-ahead target.
    window = sequence_length + 1
    # A series of N rows contains N - window + 1 full windows.  The previous
    # loop bound stopped one short and silently dropped the newest sample.
    n_samples = len(data_all) - window + 1
    if n_samples < 1:
        raise ValueError(
            'need at least %d rows to build one sample, got %d'
            % (window, len(data_all))
        )

    # NOTE: the scaler is fitted on the whole series, so the test rows also
    # influence the scaling.  Kept as is because callers use the returned
    # scaler to inverse-transform both predictions and test targets.
    scaler = MinMaxScaler()
    data_all = scaler.fit_transform(data_all)

    reshaped_data = np.array(
        [data_all[start:start + window] for start in range(n_samples)]
    ).astype('float64')

    # Both x and y are slices of the scaled data, so the targets are
    # normalised too; use ``scaler.inverse_transform`` to recover prices.
    x = reshaped_data[:, :-1]
    y = reshaped_data[:, -1]

    split_boundary = int(reshaped_data.shape[0] * split)
    train_x = x[:split_boundary]
    test_x = x[split_boundary:]
    train_y = y[:split_boundary]
    test_y = y[split_boundary:]

    return train_x, train_y, test_x, test_y, scaler

def build_model():
    """Build and compile the stacked-LSTM regression model.

    Layers, in order: an LSTM with 6 units that returns the full sequence,
    an LSTM with 100 units that returns only its last output, a Dense layer
    with a single unit, and a linear activation.  The model is compiled with
    mean-squared-error loss and the 'rmsprop' optimizer.

    Returns:
        The compiled ``Sequential`` model.  It takes input of shape
        (n_samples, time_steps, 1).
    """
    model = Sequential()

    # input_shape=(None, 1): any number of time steps, one feature per step.
    # The unit count is passed positionally and the input is described with
    # ``input_shape`` because the old ``input_dim=`` / ``output_dim=``
    # keywords are not accepted by newer Keras releases, while this spelling
    # is understood by old and new versions alike.
    model.add(LSTM(6, input_shape=(None, 1), return_sequences=True))

    print(model.layers)

    # Only the last time step feeds the regression head.
    model.add(LSTM(100, return_sequences=False))

    # Single output value: the next (scaled) closing price.
    model.add(Dense(1))
    model.add(Activation('linear'))

    model.compile(loss='mse', optimizer='rmsprop')
    return model

def train_model(train_x, train_y, test_x, test_y):
    """Train a fresh model, predict on the test inputs and plot the result.

    Training can be stopped early with Ctrl-C; prediction then runs with
    whatever weights the model has at that point.

    Args:
        train_x: Training inputs passed to ``model.fit``.
        train_y: Training targets passed to ``model.fit``.
        test_x: Inputs passed to ``model.predict``.
        test_y: True targets for ``test_x``; printed, plotted and returned
            unchanged.

    Returns:
        A tuple ``(predict, test_y)`` where ``predict`` is the model output
        flattened to a 1-D array.
    """
    model = build_model()

    batch_size = 512
    epochs = 300
    try:
        # Batch size and epoch count are passed positionally: the epoch
        # keyword is spelled ``nb_epoch`` in Keras 1 and ``epochs`` in
        # Keras 2+, but it is the fourth positional parameter in both.
        model.fit(train_x, train_y, batch_size, epochs, validation_split=0.1)
    except KeyboardInterrupt:
        # Previously the handler printed ``predict`` before it had been
        # assigned, turning Ctrl-C into an unbound-name error.
        print('Training interrupted; predicting with the current weights.')

    predict = model.predict(test_x)
    predict = np.reshape(predict, (predict.size, ))

    print(predict)
    print(test_y)

    # Plotting is best effort: a display/back-end problem must not discard
    # the predictions that were just computed.
    try:
        plt.figure(1)
        plt.plot(predict, 'r:')
        plt.plot(test_y, 'g-')
        plt.legend(['predict', 'true'])
    except Exception as e:
        print(e)

    return predict, test_y

if __name__ == '__main__':
    # Build the windowed, scaled data sets from the closing-price frame.
    train_x, train_y, test_x, test_y, scaler = load_data(
        dd3, sequence_length=10, split=0.8
    )

    # The LSTM takes (n_samples, time_steps, 1): force a trailing axis of 1.
    train_x = train_x.reshape(train_x.shape[0], train_x.shape[1], 1)
    test_x = test_x.reshape(test_x.shape[0], test_x.shape[1], 1)

    predict_y, test_y = train_model(train_x, train_y, test_x, test_y)

    # Map scaled values back to prices; the scaler wants 2-D input, so the
    # flat prediction vector becomes a single column first.
    predict_y = scaler.inverse_transform(predict_y.reshape(-1, 1))
    test_y = scaler.inverse_transform(test_y)

    # Second figure: predictions against true values in price units.
    fig2 = plt.figure(2)
    plt.plot(predict_y, 'g:')
    plt.plot(test_y, 'r-')
    plt.show()

参考资料:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值