线性回归预测股票价格
本文将演示如何使用Python中的scikit-learn
库来预测股票价格。我们将使用线性回归模型,并基于一家公司的历史股票价格来预测未来的股票价格。
数据准备
首先,我们需要导入一些必要的库,并加载股票价格的数据。这里使用的数据是000001
的股票价格数据。
from datetime import datetime as dt
import numpy as np
import pandas as pd
df = pd.read_csv('000001.csv')
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')
df.sort_index(inplace=True, ascending=True)
df.dropna(axis=0, inplace=True)
在这段代码中,我们首先读取CSV文件中的数据,然后将date
列转换为日期类型,并将其设置为索引。接着,我们按日期对数据进行排序,并删除存在缺失值的行。
我们可以看到数据的时间范围如下:
min_date = df.index.min()
max_date = df.index.max()
print(min_date)
print(max_date)
print(max_date - min_date)
输出结果:
2021-01-11 00:00:00
2023-07-10 00:00:00
910 days 00:00:00
这表示我们的数据集包含从2021年1月11日到2023年7月10日期间的股票价格数据,共计910天。
数据可视化
在进行预测之前,我们可以先通过绘制k线图来观察股票价格的变化趋势。我们可以使用plotly
库来绘制k线图。
from plotly import tools
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode()
trace = go.Candlestick(x=df.index, open=df['open'], high=df['high'], low=df['low'], close=df['close'])
data = [trace]
iplot(data, filename='simple_candlestick')
线性回归模型
接下来,我们将开始构建线性回归模型。首先,我们需要创建一个新的列label
,它将包含我们要预测的值,即5天后的收盘价。
from sklearn.linear_model import LinearRegression
from sklearn import preprocessing
num = 5
df['label'] = df['close'].shift(-num)
然后,我们将丢弃label
、price_change
、和p_change
列,因为这些列与收盘价密切相关,可能会引起数据的线性相关性。
Data = df.drop(['price_change', 'p_change', 'label'], axis=1)
Data.tail()
再构造数据集:
X = Data.values
X = preprocessing.scale(X)
X = X[:-num]
df.dropna(inplace=True)
Target = df.label
Y = Target.values
print(np.shape(X), np.shape(Y))
# 创建训练集和测试集,先减去5天的数据,然后将数据分为训练集和测试集
X_train = X[:int(0.8 * len(X))]
X_test = X[int(0.8 * len(X)):]
Y_train = Y[:int(0.8 * (len(Y)))]
Y_test = Y[int(0.8 * len(Y)):int(len(Y))]
print(np.shape(X_train), np.shape(X_test), np.shape(Y_train), np.shape(Y_test))
创建线性回归模型并进行拟合
# 创建线性回归模型
model = LinearRegression()
model.fit(X_train, Y_train)
评估模型效果:
# 评估模型R^2
score = model.score(X_test, Y_test)
print(score)
预测后五天的股票收盘价:
# 预测
X_predict = X[-num:]
forecast = model.predict(X_predict)
print(forecast)
数据可视化:
# 画出预测值和真实值的对比图,观察预测值和真实值的差异
import matplotlib.pyplot as plt
plt.plot(forecast, label='predict')
plt.plot(Y[-num:], label='true')
# 以xxxx-xx-xx的形式显示日期
plt.xticks(np.arange(0, num, 1), [dt.strftime(x, '%Y-%m-%d') for x in df.index[-num:]])
plt.legend()
plt.show()
模型参数可视化:
# 画出每个特征的系数,观察每个特征对预测值的影响
coef = pd.DataFrame(model.coef_, index=Data.columns, columns=['coef'])
coef.plot(kind='bar')
plt.show()
for idx, col_name in enumerate(Data.columns):
print("The coefficient for {} is {}".format(col_name, model.coef_[idx]))