ml_svc_预测股票

 

目标:

根据 2019-01-01 至 2019-07-30 沪市指数(上证指数)的收盘价, 使用 SVR 回归预测 2019-07-31(或 2019-08-01)的收盘价。

拟合结果(LibSVM 训练日志; 最后一行 2933.0099… 为模型对 2019-07-31 的预测收盘价):

[LibSVM]..........................*...........*
optimization finished, #iter = 10450
obj = -1700429.608042, rho = -2906.668575
nSV = 141, nBSV = 52
SVR(C=1000.0, cache_size=1000, coef0=0.0, degree=3, epsilon=0.1, gamma=0.1,
  kernel='rbf', max_iter=-1, shrinking=True, tol=0.001, verbose=True)
2933.009977517439

拟合效果

实际K线图

 

代码

import os
import numpy as np
import pandas as pd
from scipy import stats
import matplotlib.pyplot as plt
from datetime import datetime as dt
from sklearn import preprocessing
from sklearn.svm import SVC, SVR
import plotly.offline as of
import plotly.graph_objs as go
import tushare as ts


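# Dependencies: pip install tushare (the quote source used by get_stock_data below).
# The original notes also list: pip install ciso8601, pip install stockai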
def get_stock_data(stock_num, start):
    """
    Download daily (ktype='D') quote data via tushare's get_hist_data.

    Fields the author lists for the returned frame (presumably tushare's
    get_hist_data schema -- not verified here):
        date: trading date (the rest of this script reads it from the
              index as 'YYYY-MM-DD' strings)
        open: opening price
        high: highest price
        close: closing price
        low: lowest price
        volume: trading volume
        price_change: price change
        p_change: percent change
        ma5 / ma10 / ma20: 5 / 10 / 20-day average price
        v_ma5 / v_ma10 / v_ma20: 5 / 10 / 20-day average volume
    :param stock_num: stock code or index alias forwarded to tushare
        (this script uses 'sh'; 'sz', 'cyb' and '002839' are mentioned)
    :param start: first date to fetch, 'YYYY-MM-DD'
    :return: df
    """
    daily_kline = ts.get_hist_data(stock_num, start=start, ktype='D')
    return daily_kline


def draw_kchart(df, filename):
    """
    Render an OHLC (K-line) chart of ``df`` to an HTML file with plotly.

    Before plotting, prints the earliest and latest index dates and the
    timedelta between them.
    :param df: frame indexed by 'YYYY-MM-DD' date strings with
        'open', 'high', 'low' and 'close' columns
    :param filename: output path handed to plotly.offline.plot
    """
    date_fmt = "%Y-%m-%d"
    first_day = df.index.min()
    last_day = df.index.max()
    print("First date is", first_day)
    print("Last date is", last_day)
    span = dt.strptime(last_day, date_fmt) - dt.strptime(first_day, date_fmt)
    print(span)
    ohlc = go.Ohlc(
        x=df.index,
        open=df['open'],
        high=df['high'],
        low=df['low'],
        close=df['close'],
    )
    of.plot([ohlc], filename=filename)


def stock_etl(df):
    """
    Clean a quote frame in place and put it in chronological order.

    Drops every row that has any missing value, then sorts the remaining
    rows ascending by ``date`` (an index level or column name, as accepted
    by ``DataFrame.sort_values``).
    :param df: quote frame; it is modified in place
    :return: the same ``df`` object
    """
    df.dropna(axis='index', how='any', inplace=True)
    df.sort_values(['date'], ascending=True, inplace=True)
    return df


def get_data(df, base_date='2019-01-01'):
    """
    Turn a quote frame into SVR training lists: day offsets and closes.

    Each index date is converted to the number of whole days since
    ``base_date``; the caller's frame is not modified.
    :param df: frame indexed by 'YYYY-MM-DD' date strings with a 'close' column
    :param base_date: origin date, 'YYYY-MM-DD' (default '2019-01-01', the
        value this function always used before it became a parameter)
    :return: [day_offsets, closes] -- two equal-length lists, offsets as ints
    """
    date_fmt = '%Y-%m-%d'
    # Parse the origin once rather than once per row.
    origin = dt.strptime(base_date, date_fmt)
    offsets = [(dt.strptime(day, date_fmt) - origin).days for day in df.index]
    return [offsets, df['close'].tolist()]


def predict_prices(dates, prices, x):
    """
    Fit an RBF-kernel SVR of price on day offset, plot it, and predict ``x``.

    Prints the fitted estimator and the prediction for the first entry of
    ``x``, draws the training points plus the fitted curve, and then calls
    ``plt.show()``.
    :param dates: day offsets of the training samples (1-D sequence)
    :param prices: closing prices aligned with ``dates``
    :param x: day offsets to predict (1-D sequence, at least one value)
    :return: numpy array of predicted prices, one per entry of ``x``
        (previously nothing was returned; callers ignoring it are unaffected)
    """
    # SVR expects a 2-D (n_samples, n_features) design matrix.
    dates = np.reshape(dates, (len(dates), 1))
    x = np.reshape(x, (len(x), 1))

    # Only the RBF kernel is fitted. In the author's runs the linear and
    # polynomial kernels with these settings took ~40 min / >4 h to fit.
    # NOTE(review): probably because neither the day offsets nor the prices
    # are scaled; standardising them would change the predictions, so it is
    # deliberately not done here.
    svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1, verbose=True, cache_size=1000)

    plt.scatter(dates, prices, c='k', label='Data')

    svr_rbf.fit(dates, prices)
    print(svr_rbf)
    predicted = svr_rbf.predict(x)
    print(predicted[0])
    plt.plot(dates, svr_rbf.predict(dates), c='b', label='svr_rbf')

    plt.xlabel('date')
    plt.ylabel('Price')
    plt.grid(True)
    plt.legend()
    plt.show()
    return predicted


if __name__ == "__main__":
    # Model price as a function of time: download index quotes, chart them,
    # then predict the close on TARGET_DATE with an RBF SVR.
    #
    # Aliases accepted by get_stock_data (tushare):
    #   'sh'  - Shanghai Composite index K-line data
    #   'sz'  - Shenzhen Component index K-line data
    #   'cyb' - ChiNext index K-line data
    # A single stock also works, e.g. '002839' (Zhangjiagang Bank).
    START_DATE = '2019-01-01'
    TARGET_DATE = '2019-07-31'
    # get_data() measures day offsets from 2019-01-01 by default; the
    # prediction target must use the same origin.
    OFFSET_ORIGIN = '2019-01-01'
    DATE_FMT = '%Y-%m-%d'

    df = stock_etl(get_stock_data('sh', START_DATE))

    cur_path = os.path.abspath(os.path.dirname(__file__))
    draw_kchart(df, os.path.join(cur_path, 'simple_ohlc.html'))

    dates, prices = get_data(df)
    print(dates)
    print(prices)

    target_offset = (dt.strptime(TARGET_DATE, DATE_FMT) - dt.strptime(OFFSET_ORIGIN, DATE_FMT)).days
    predict_prices(dates, prices, [target_offset])

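遇到问题: 使用多项式核函数时, 占用约 60% 的 CPU 跑了 4 个小时仍未拟合出结果;

                使用线性核函数时, 跑了 40 分钟才拟合出结果;

                使用高斯核函数(RBF)时, 1 分钟就拟合出了结果。

(可能的原因: 输入特征(距 2019-01-01 的天数)和目标值(收盘价)都没有做标准化, 且 C=1e3 较大, 线性/多项式核的 libsvm 求解器因此收敛极慢; 可以尝试先用 sklearn.preprocessing 对数据缩放后再训练。)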

源码:

https://github.com/clark99/learnMachinelearning/blob/master/sklearn/demo/svm/

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值