股票预测模型的改良

import numpy as np
import pandas as pd
from sklearn import svm
from sklearn.linear_model import LogisticRegression
import tushare as ts
from sklearn import cross_validation
data=ts.get_k_data('600000',start='2007-01-01',end='2018-04-13')
print(data.head())
print(data.tail())
         date   open  close   high    low     volume    code
0  2007-01-04  3.702  3.670  4.009  3.413  508894.23  600000
1  2007-01-05  3.670  3.562  3.670  3.338  357055.86  600000
2  2007-01-08  3.557  3.634  3.708  3.525  254888.76  600000
3  2007-01-09  3.596  3.897  3.916  3.591  329619.46  600000
4  2007-01-10  3.904  4.041  4.152  3.904  352768.36  600000
            date   open  close   high    low    volume    code
2686  2018-04-09  11.53  11.50  11.59  11.49  167224.0  600000
2687  2018-04-10  11.52  11.77  11.79  11.51  287482.0  600000
2688  2018-04-11  11.79  11.91  12.02  11.75  312985.0  600000
2689  2018-04-12  11.91  11.78  11.96  11.76  188242.0  600000
2690  2018-04-13  11.83  11.69  11.89  11.69  140948.0  600000
data_SZ_index=ts.get_k_data('000001',index=True,start='2007-01-01',end='2018-04-13')
print(data_SZ_index.head())
print(data_SZ_index.tail())
         date     open    close     high      low       volume      code
0  2007-01-04  2728.19  2715.72  2847.61  2684.82  120156000.0  sh000001
1  2007-01-05  2668.58  2641.33  2685.80  2617.02  106156000.0  sh000001
2  2007-01-08  2621.07  2707.20  2708.44  2620.62  106813000.0  sh000001
3  2007-01-09  2711.05  2807.80  2809.39  2691.36  110751000.0  sh000001
4  2007-01-10  2838.11  2825.58  2841.74  2770.99  111769000.0  sh000001
            date     open    close     high      low       volume      code
2738  2018-04-09  3125.44  3138.29  3146.09  3110.30  139608621.0  sh000001
2739  2018-04-10  3144.26  3190.32  3190.65  3139.08  168201359.0  sh000001
2740  2018-04-11  3197.37  3208.08  3220.85  3191.59  175867197.0  sh000001
2741  2018-04-12  3203.28  3180.16  3205.25  3177.05  148231313.0  sh000001
2742  2018-04-13  3192.04  3159.05  3197.90  3155.51  127552310.0  sh000001
num_data=len(data)
num_SZ_index=len(data_SZ_index)
print(num_data,num_SZ_index)#股票会有停牌天数,但指数不会
(2691, 2743)
from datetime import datetime
data['date'] = [datetime.strptime(x,'%Y-%m-%d') for x in data['date']]
data['date'].head()
0   2007-01-04
1   2007-01-05
2   2007-01-08
3   2007-01-09
4   2007-01-10
Name: date, dtype: datetime64[ns]
data_SZ_index['date'] = [datetime.strptime(x,'%Y-%m-%d') for x in data_SZ_index['date']]
subdata_SZ_index=data_SZ_index[data_SZ_index['date'].isin(data['date'])]#数据对齐
sub_index_open=subdata_SZ_index['open'].values  #z做了对齐之后,丢失一部分大盘数据对index有影响,直接取数据部分
sub_index_close=subdata_SZ_index['close'].values
col_index=[]
y=[]
data_open=data['open']
data_close=data['close']

for i in xrange(2691):
    if sub_index_close[i]>=sub_index_open[i]:
        col_index.append(1)
    else:
        col_index.append(0)
    if data_close[i]>=data_open[i]:
        y.append(1)
    else:
        y.append(0)
x_data=data[['open','close','high','low','volume']].as_matrix()
x=np.c_[x_data,col_index]#将大盘指数的涨跌合并到特征值中
data_shape=x.shape
data_rows=data_shape[0]
data_cols=data_shape[1]
data_col_max=x.max(axis=0)
data_col_min=x.min(axis=0)
print(data_col_max,data_col_min)
(array([  1.36700000e+01,   1.37600000e+01,   1.40200000e+01,
         1.35300000e+01,   1.19802410e+07,   1.00000000e+00]), array([  2.42500000e+00,   2.47000000e+00,   2.65400000e+00,
         2.41600000e+00,   2.89912100e+04,   0.00000000e+00]))
for i in xrange(0, data_rows, 1):#将输入数组归一化
    for j in xrange(0, data_cols, 1):
        x[i][j] = \
            (x[i][j] - data_col_min[j]) / \
            (data_col_max[j] - data_col_min[j])
print(x[0:2])
[[ 0.11356158  0.10628875  0.1192152   0.08970668  0.04015505  0.        ]
 [ 0.11071587  0.09672276  0.08938941  0.08295843  0.02745024  0.        ]]
y=y[1:2691]
x=x[0:2690]
clf1 = svm.SVC(kernel='rbf')
clf2 = LogisticRegression()
result1 = []
result2 = []
for i in range(5):
    # x和y的验证集和测试集,切分80-20%的测试集
    x_train, x_test, y_train, y_test = \
        cross_validation.train_test_split(x, y, test_size=0.2)
    # 训练数据进行训练
    clf1.fit(x_train, y_train)
    # 将预测数据和测试集的验证数据比对
    result1.append(np.mean(y_test == clf1.predict(x_test)))
    clf2.fit(x_train, y_train)
    result2.append(np.mean(y_test == clf2.predict(x_test)))
print("svm classifier accuacy:")
print(result1)
print("LogisticRegression classifier accuacy:")
print(result2)
svm classifier accuacy:
[0.53345724907063197, 0.53717472118959109, 0.55018587360594795, 0.51115241635687736, 0.54460966542750933]
LogisticRegression classifier accuacy:
[0.55390334572490707, 0.55576208178438657, 0.53717472118959109, 0.52416356877323422, 0.53903345724907059]

这篇代码主要改进是将大盘的涨跌加入到了模型的参数中,并且是通过股票昨天的情况来预测明天的涨跌,发现SVM和逻辑回归效果差不多

阅读更多
文章标签: 量化交易 SVM
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

关闭
关闭
关闭