keras 时间序列数据预测与结果分析

数据来源

使用tushare的接口获取股票历史数据.

# 安装 tushare
pip3 install tushare -i https://pypi.tuna.tsinghua.edu.cn/simple 

需要先在 tushare 官网注册并完善个人资料,获取 token 后才能使用.

# 设置 token
>>> ts.set_token('*****************')

tushare返回的是panda.dataframe格式的数据.
使用时,先将数据下载到 `../data/` 文件夹下,避免每次使用都要调用网络接口.
要更新数据,将 `../data/` 文件夹清空即可.

文件结构

  • project
    • data # 存放数据
    • market # 代码和模型
      • model # 存放模型
      • load_tools.py # 加载模型,数据
      • get_tools.py # 提供数据
      • get_samples.py # 构建模型需要的数据样本
      • new_generator.py # 数据样本生成器
      • dotrain.py # 训练模型
      • evaluate_model.py # 衡量模型各种性能
      • history_predict.py # 显示模型的历史预测曲线
      • serch_predict.py # 搜索目标时间预测的结果
      • run.py # 执行今日预测
      • run.bat # 执行今日预测

代码

load_tools.py

import os

import keras
import pandas as pd
import tushare as ts
from keras import backend as K


# 模型衡量标准
# 正样本中有多少被识别为正样本
def recall(y_true, y_pred):
    """Recall at threshold 0.5: share of actual positives predicted positive.

    Assumes ``y_true`` is 0/1. A hit is an element where
    ``round(clip(y_true * y_pred, 0, 1))`` is 1. ``K.epsilon()`` keeps the
    ratio finite when the batch contains no positives.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    positives = K.sum(y_true)
    return K.sum(hits) / (positives + K.epsilon())


def recall1(y_true, y_pred):
    """Recall with the positive threshold raised to roughly 0.7.

    The scores are shifted down by 0.2 before ``recall`` applies its 0.5
    rounding, so only ``y_pred`` above about 0.7 counts as a positive call.
    """
    return recall(y_true, y_pred - 0.2)


def recall2(y_true, y_pred):
    """Recall with the positive threshold raised to roughly 0.9.

    The scores are shifted down by 0.4 before ``recall`` applies its 0.5
    rounding, so only ``y_pred`` above about 0.9 counts as a positive call.
    """
    return recall(y_true, y_pred - 0.4)


# 负样本中有多少被识别为负样本
def n_recall(y_true, y_pred):
    """Negative-class recall at 0.5: share of actual negatives predicted negative.

    Assumes ``y_true`` is 0/1. A negative sample is a hit when
    ``round(clip((1 - y_true) * (1 - y_pred), 0, 1))`` is 1.
    """
    neg_true = 1 - y_true
    neg_hits = K.round(K.clip(neg_true * (1 - y_pred), 0, 1))
    return K.sum(neg_hits) / (K.sum(neg_true) + K.epsilon())


def n_recall1(y_true, y_pred):
    """Negative-class recall with a stricter cut: ``y_pred`` below roughly 0.3.

    The negative score ``1 - y_pred`` is shifted down by 0.2 before rounding.
    """
    neg_true = 1 - y_true
    neg_score = (1 - y_pred) - 0.2
    neg_hits = K.round(K.clip(neg_true * neg_score, 0, 1))
    return K.sum(neg_hits) / (K.sum(neg_true) + K.epsilon())


def n_recall2(y_true, y_pred):
    """Negative-class recall with a strict cut: ``y_pred`` below roughly 0.1.

    The negative score ``1 - y_pred`` is shifted down by 0.4 before rounding.
    """
    neg_true = 1 - y_true
    neg_score = (1 - y_pred) - 0.4
    neg_hits = K.round(K.clip(neg_true * neg_score, 0, 1))
    return K.sum(neg_hits) / (K.sum(neg_true) + K.epsilon())


# 识别为正样本中有多少是正样本
def precision(y_true, y_pred):
    """Precision at threshold 0.5: share of positive predictions that are positive.

    A positive call is ``round(clip(y_pred, 0, 1)) == 1``. ``K.epsilon()``
    keeps the ratio finite when nothing is predicted positive.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    flagged = K.round(K.clip(y_pred, 0, 1))
    return K.sum(hits) / (K.sum(flagged) + K.epsilon())


def precision1(y_true, y_pred):
    """Precision with the positive threshold raised to roughly 0.7.

    The scores are shifted down by 0.2 before ``precision`` applies its 0.5
    rounding to both the hits and the positive calls.
    """
    return precision(y_true, y_pred - 0.2)


def precision2(y_true, y_pred):
    """Precision with the positive threshold raised to roughly 0.9.

    The scores are shifted down by 0.4 before ``precision`` applies its 0.5
    rounding to both the hits and the positive calls.
    """
    return precision(y_true, y_pred - 0.4)


# 识别为负样本中有多少是负样本
def n_precision(y_true, y_pred):
    """Negative-class precision at 0.5: share of negative calls that are negative."""
    neg_score = 1 - y_pred
    neg_hits = K.round(K.clip((1 - y_true) * neg_score, 0, 1))
    flagged_neg = K.round(K.clip(neg_score, 0, 1))
    return K.sum(neg_hits) / (K.sum(flagged_neg) + K.epsilon())


def n_precision1(y_true, y_pred):
    """Negative-class precision when a negative call needs ``y_pred`` below roughly 0.3."""
    neg_score = (1 - y_pred) - 0.2
    neg_hits = K.round(K.clip((1 - y_true) * neg_score, 0, 1))
    flagged_neg = K.round(K.clip(neg_score, 0, 1))
    return K.sum(neg_hits) / (K.sum(flagged_neg) + K.epsilon())


def n_precision2(y_true, y_pred):
    """Negative-class precision when a negative call needs ``y_pred`` below roughly 0.1."""
    neg_score = (1 - y_pred) - 0.4
    neg_hits = K.round(K.clip((1 - y_true) * neg_score, 0, 1))
    flagged_neg = K.round(K.clip(neg_score, 0, 1))
    return K.sum(neg_hits) / (K.sum(flagged_neg) + K.epsilon())


# 预测结果中有多少是正样本
def prate(y_true, y_pred):
    """Fraction of samples predicted positive at 0.5 (``y_true`` is ignored)."""
    predicted_labels = K.round(K.clip(y_pred, 0, 1))
    return K.mean(predicted_labels)


# 实际中有多少是正样本
def trate(y_true, y_pred):
    """Fraction of samples whose true label is positive (``y_pred`` is ignored)."""
    true_labels = K.round(K.clip(y_true, 0, 1))
    return K.mean(true_labels)


# 加载模型
def load_model(model_name='./model/cnn960to1080b.model', lookback=61, shape=5):
    """Load a saved Keras model and recompile it with this module's metrics.

    Args:
        model_name: path of the saved model file.
        lookback: value registered as ``'lookback'`` in ``custom_objects``
            (presumably referenced by a custom layer of the saved model - TODO confirm).
        shape: value registered as ``'shape'`` in ``custom_objects``.

    Returns:
        The loaded model, recompiled with RMSprop, binary cross-entropy and
        the full metric list: positive metrics, negative-class metrics, then
        ``trate`` and ``prate``, in that order.
    """
    # Single source of truth for the metrics. Every custom metric must be
    # resolvable by name when deserializing, so the negative-class metrics
    # (previously missing from custom_objects although compiled below) are
    # registered too.
    metric_fns = [recall, precision, recall1, precision1, recall2, precision2,
                  n_recall, n_precision, n_recall1, n_precision1, n_recall2, n_precision2,
                  trate, prate]
    dependencies = {fn.__name__: fn for fn in metric_fns}
    dependencies['lookback'] = lookback
    dependencies['shape'] = shape

    model = keras.models.load_model(model_name, custom_objects=dependencies)
    model.compile(optimizer=keras.optimizers.RMSprop(),
                  loss=keras.losses.binary_crossentropy,
                  metrics=metric_fns)
    return model


# 获取一支股票的历史数据
def load_data(ts_code):
    """Return the forward-adjusted (qfq) daily history of one stock, oldest row first.

    The CSV cache ``../data/<ts_code>.csv`` is used when it exists. Otherwise
    the data is downloaded with ``ts.pro_bar(adj='qfq')`` and written there.

    Returns:
        A DataFrame with NaN rows dropped, rows reversed relative to the cached
        file and the index reset to 0..n-1, or None when the download returned
        nothing.
    """
    data_dir = '../data/'
    file_path = data_dir + ts_code + '.csv'
    if not os.path.exists(file_path):
        # Assumes the tushare token was configured beforehand (ts.set_token).
        df = ts.pro_bar(ts_code=ts_code, adj='qfq')
        if df is None:
            print('can not get data')
            return
        # DataFrame.to_csv does not create missing directories; on a first
        # run ../data/ may not exist yet.
        os.makedirs(data_dir, exist_ok=True)
        df.to_csv(file_path, index=False)
    df = pd.read_csv(file_path)
    # Columns: ts_code, trade_date, open, high, low, close, pre_close, change,
    # pct_chg, vol, amount (amount in thousand CNY), adj_factor
    df.dropna(inplace=True)
    # The cached file is presumably newest-first (tushare order - verify);
    # reversing the index gives chronological order.
    df = df.sort_index(ascending=False)
    df.reset_index(drop=True, inplace=True)
    return df


# 加载股票列表
def load_code_list(market='SSE'):
    """Return a flat array with the ``ts_code`` of every listed stock on *market*.

    The list is cached in ``../data/code_list_<market>.csv``. When the file is
    missing it is fetched with ``pro.stock_basic(exchange=market, list_status='L')``
    and then written to that path.
    """
    file_dir = '../data/' + 'code_list_' + market + '.csv'
    if os.path.exists(file_dir):
        code_list = pd.read_csv(file_dir)
    else:
        # NOTE(review): the API token is hard-coded (masked here); consider
        # ts.set_token() or an environment variable instead.
        pro = ts.pro_api('*****************************')
        code_list = pro.stock_basic(exchange=market, list_status='L', fields='ts_code')  # ,symbol,name,market,list_date
        # to_csv does not create the cache directory on a first run.
        os.makedirs(os.path.dirname(file_dir), exist_ok=True)
        code_list.to_csv(file_dir, index=False)
    return code_list[['ts_code']].values.flatten()


# 根据模式输出
def print_verbose(verbose, text):
    """Print *text* only when *verbose* is truthy."""
    if not verbose:
        return
    print(text)

get_tools.py

import numpy as np
from load_tools import *
from matplotlib import pyplot as plt


# 加载所需数据
def _check_stock(code_name, frame, normalize):
    """Return ``[features, mean, std]`` for one stock, or None when it is unusable."""
    if frame is None or frame.empty:
        return None
    # Close must stay in column 0: the targets are derived from column 0.
    features = frame[['close', 'open', 'high', 'low', 'amount']].values
    if np.isnan(features).any():
        print('nan in %s' % code_name)
        return None
    if normalize:
        mean = features.mean(axis=0)
        std = features.std(axis=0)
    else:
        mean = [0]
        std = [1]
    features -= mean
    # A constant close column carries no signal (and would divide by zero).
    if features.std(axis=0)[0] == 0:
        print('std is 0 in %s' % code_name)
        return None
    features /= std
    return [features, mean, std]


def init(market_list, normalize=False):
    """Load and sanity-check the price history of every stock on each market.

    Args:
        market_list: exchange codes accepted by ``load_code_list`` (e.g. 'SSE').
        normalize: when True, each stock's feature columns are standardised
            with that stock's own full-history mean and std.

    Returns:
        Four lists indexed ``[market][stock]``: ts_codes, raw DataFrames (or
        None), ``[features, mean, std]`` (or None) and trade-date lists (or
        None). The last two are None when the frame is missing/empty, contains
        NaN, or has a constant close column.
    """
    market_names, market_datas = [], []
    market_datas_normal, market_datas_date = [], []
    for market in market_list:
        print('Load ', market)
        names, frames, normals, dates = [], [], [], []
        market_names.append(names)
        market_datas.append(frames)
        market_datas_normal.append(normals)
        market_datas_date.append(dates)

        print('正在加载数据进入内存')
        for code_name in load_code_list(market=market):
            print(code_name)
            names.append(code_name)
            frames.append(load_data(code_name))
        print('数据加载完毕')

        print('正在检查数据')
        for code_name, frame in zip(names, frames):
            checked = _check_stock(code_name, frame, normalize)
            if checked is None:
                normals.append(None)
                dates.append(None)
            else:
                normals.append(checked)
                dates.append(frame['trade_date'].tolist())
        print('数据检查完成')
    return market_names, market_datas, market_datas_normal, market_datas_date


# Exchanges to load: 'SSE' (Shanghai) and 'SZSE' (Shenzhen).
market_list = ['SSE', 'SZSE']
# Runs at import time: loads (downloading when not cached) every stock of every
# market; the second argument enables per-stock normalization.
market_names, market_datas, market_datas_normal, market_datas_date = init(market_list, True)
# Reference list of trade dates (used e.g. by evaluate_total_time). It is taken
# from the first stock whose data actually loaded. market_datas[0][0] alone can
# be None when its download failed, which crashed the import with a TypeError.
date_list = next(
    (frame['trade_date'].values.tolist()
     for frames in market_datas
     for frame in frames
     if frame is not None and not frame.empty),
    [],
)


def get_data(ts_code):
    """Return the raw DataFrame loaded for *ts_code*, or None if it is unknown or missing."""
    for names, frames in zip(market_names, market_datas):
        if ts_code in names:
            return frames[names.index(ts_code)]
    return None


def get_data_normal(ts_code):
    """Return ``[features, mean, std]`` for *ts_code*, or None if unknown or rejected by init."""
    for names, normals in zip(market_names, market_datas_normal):
        if ts_code in names:
            return normals[names.index(ts_code)]
    return None


def get_data_date(ts_code):
    """Return the trade-date list of *ts_code*, or None if unknown or rejected by init."""
    for names, dates in zip(market_names, market_datas_date):
        if ts_code in names:
            return dates[names.index(ts_code)]
    return None


def get_code_list(market='SSE'):
    """Return the ts_codes of *market*, or of all loaded markets when market == 'ALL'.

    For a single market the stored list itself is returned (not a copy);
    'ALL' returns a new concatenated list.
    """
    if market != 'ALL':
        return market_names[market_list.index(market)]
    return [code for names in market_names for code in names]


# 训练历史可视化
def show_train_history(train_history, train_metrics, validation_metrics):
    """Plot a training metric and its validation counterpart per epoch on the current axes."""
    history = train_history.history
    for key in (train_metrics, validation_metrics):
        plt.plot(history[key])
    plt.ylabel(train_metrics)
    plt.xlabel('Epoch')
    plt.legend(['train', 'validation'], loc='upper left')


# 显示训练过程
def plot_history(history):
    """Draw a 2x2 grid of training curves, save it to ./model/auto_save.jpg and show it.

    Edit ``panels`` to match the metrics the trained model actually logs.
    """
    panels = [
        ('loss', 'val_loss'),
        ('recall', 'val_recall'),
        ('precision', 'val_precision'),
        ('precision2', 'val_precision2'),
    ]
    plt.figure(figsize=(12, 8))
    for position, (train_key, val_key) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        show_train_history(history, train_key, val_key)
    plt.savefig('./model/auto_save.jpg')
    plt.show()


# 通过时间搜索index
def date2index(ts_code, start_date, end_date, lookback, delay, verbose=0):
    """Translate a trade-date range of one stock into a half-open row range.

    Args:
        ts_code: stock code.
        start_date, end_date: trade dates, as int (``20191108``) or digit
            string (``'20191108'``), or ``''`` for "first usable row" /
            "last usable row".
        lookback: window length; ``start`` is at least ``lookback - 1`` so
            every row has a full window before it.
        delay: rows needed after each row for its target (``row + delay``
            must exist).
        verbose: print diagnostics when truthy.

    Returns:
        ``(start, end)`` with ``start < end`` and ``end`` exclusive (so
        ``end_date`` itself is excluded), or None when the stock has no dates,
        a date is not found, or the range is empty.
    """
    dl = get_data_date(ts_code)
    if not dl:
        return

    def locate(date):
        # Trade dates read back from CSV are ints, but callers such as run.py
        # pass strings from time.strftime; accept both spellings.
        candidates = [date]
        if isinstance(date, str) and date.isdigit():
            candidates.append(int(date))
        elif isinstance(date, (int, np.integer)):
            candidates.append(str(date))
        for candidate in candidates:
            try:
                return dl.index(candidate)
            except ValueError:
                continue
        return None

    start = lookback - 1
    if start_date != '':
        position = locate(start_date)
        if position is None:
            print_verbose(verbose, 'can not find date')
            return
        start = max(start, position)
    end = len(dl) - delay
    if end_date != '':
        position = locate(end_date)
        if position is None:
            print_verbose(verbose, 'can not find date')
            return
        end = min(end, position)
    if start >= end:
        print_verbose(verbose, 'data range too small, may be date too close to boundary.')
        return
    return start, end

L36:从数据中筛选部分列用于训练,这里选了['close', 'open', 'high', 'low', 'amount'],分别是收盘价、开盘价、最高价、最低价、成交额(千元).由于后面的预测目标都是根据每日收盘价(代码中的第 0 列)计算的,所以收盘价必须放在这个列表的第一位.
L44:若normalize=True,这里对每支股票单独进行标准化(减均值、除以标准差).由于均值和标准差是用该股票的全部历史数据(包括验证和测试时间段)计算的,模型可能借此间接获知该股票未来的价格范围(数据泄露),但后来发现影响好像不大.
L64:market_list 这里进行设置要使用的交易所的数据 ‘SSE’:上海交易所,‘SZSE’:深圳交易所
L65:开始加载数据.init的第二个参数选择是否进行正规化.第一次运行可能需要差不多一个小时来下载数据.
L108:plot_history()根据训练的类型来调整这个函数里各项的值以显示需要的学习曲线.

get_samples.py

import random

import numpy as np

from get_tools import *


# 给出一支股票某段时间的Samples
def get_samples(ts_code='600004.SH', date=20191108, lookback=61, duiring=20, verbose=1, normalize=True):
    """Build model inputs for the *duiring* trading days ending on *date*.

    Each sample is the ``lookback`` rows of normalized features ending on
    (and including) its day. ``normalize`` is accepted but unused.

    Returns:
        Array of shape ``(duiring, lookback, n_features)``, or None when the
        stock or date is unavailable or a window would start before row 0.
    """
    data_normal = get_data_normal(ts_code)
    if data_normal is None:
        print_verbose(verbose, 'can not find date normal')
        return
    features = data_normal[0]

    se_index = date2index(ts_code, date, '', lookback, 0, verbose)
    if not se_index:
        print_verbose(verbose, 'can not get date')
        return
    # Exclusive window end for the target day itself.
    last = se_index[0] + 1
    window_ends = np.arange(last - duiring + 1, last + 1)
    samples = np.zeros((len(window_ends), lookback, features.shape[-1]))
    for slot, stop in enumerate(window_ends):
        begin = stop - lookback
        if begin < 0:
            print_verbose(verbose, 'date range too small in %s' % ts_code)
            return
        samples[slot] = features[begin:stop]
    return samples


# 给出一支股票某段时间的Samples和Targets
def get_samples_targets(ts_code='600004.SH', start_date='', end_date='', lookback=61, delay=1, uprate=0.0, mod='', rand=False, verbose=1):
    """Build (samples, targets) for the usable days of one stock in a date range.

    Args:
        mod: target type. 'delta' gives close[t+delay] - close[t] (price
            change); 'rate' gives close[t+delay] / close[t] - 1 (relative
            change); anything else gives 1 if close[t+delay] >
            close[t] * (1 + uprate), else 0.
        rand: when True, only one randomly chosen day of the range is built.

    Closes are de-normalized with the stock's stored mean/std before targets
    are computed. Returns None when the stock or date range is unusable.
    """
    data_normal = get_data_normal(ts_code)
    if data_normal is None:
        print_verbose(verbose, 'can not find date normal')
        return
    features, mean, std = data_normal[0], data_normal[1], data_normal[2]

    se_index = date2index(ts_code, start_date, end_date, lookback, delay, verbose)  # 0.08
    if se_index is None:
        return
    first, stop = se_index
    if rand:
        first = random.randint(first, stop - 1)
        stop = first + 1

    days = np.arange(first, stop)
    samples = np.zeros((len(days), lookback, features.shape[-1]))
    targets = np.zeros((len(days),))

    def close_at(row):
        # Undo the per-stock normalization of the close column (column 0).
        return features[row][0] * std[0] + mean[0]

    for slot, day in enumerate(days):
        samples[slot] = features[day - (lookback - 1):day + 1]
        now, later = close_at(day), close_at(day + delay)
        if mod == 'delta':
            targets[slot] = later - now
        elif mod == 'rate':
            targets[slot] = later / now - 1
        else:
            targets[slot] = 1 if later > now * (1 + uprate) else 0
    return samples, targets


# 计算每只股票一段时间内的sample大小
def count_samples_weight(market, start_date='', end_date='', lookback=61, delay=1, verbose=1):
    """Count how many samples each stock of *market* can supply in a date range.

    Returns:
        ``(names, weight)``: parallel lists of ts_codes and their usable row
        counts (``end - start`` from ``date2index``). Stocks without data or
        without a valid range are left out.
    """
    names, weight = [], []
    for code_name in get_code_list(market=market):
        print_verbose(verbose, code_name)
        if get_data(code_name) is None:
            print_verbose(verbose, 'can not find data')
            continue
        se_index = date2index(code_name, start_date, end_date, lookback, delay, verbose)
        if se_index is None:
            continue
        first, stop = se_index
        names.append(code_name)
        weight.append(stop - first)
    return names, weight

get_samples_targets()中的mod参数控制target的生成方式:预测涨跌值(收盘价之差)用mod='delta',预测涨跌幅(收盘价变化比例)用mod='rate',预测是否上涨用mod=''.

new_generator.py

import os
import random

import tushare as ts
import numpy as np
import pandas as pd

from get_tools import *
from get_samples import *


def new_generator(market='SSE', batch_size=1024, shape=4, start_date='', end_date='', lookback=61, delay=1,
                  uprate=0.0):
    """Yield endless random batches of (samples, binary targets) for Keras.

    A stock is drawn with probability proportional to the number of samples
    it can supply in the date range, then one random day of that stock is used.

    Args:
        shape: feature columns per time step; must equal the feature count of
            the loaded data (5 with the column selection in get_tools.init).

    Yields:
        ``(samples, targets)`` of shapes ``(batch_size, lookback, shape)`` and
        ``(batch_size,)``. The generator ends without yielding when no stock
        has data in the range.
    """
    import itertools

    names, weights = count_samples_weight(market, start_date=start_date, end_date=end_date, lookback=lookback,
                                          delay=delay, verbose=0)
    if not names:
        print('can not get data, maybe date wrong')
        return
    # random.choices rebuilds cumulative weights on every call when given plain
    # weights; precomputing them gives the same draws without the O(n) rebuild.
    cum_weights = list(itertools.accumulate(weights))
    while True:
        # Fresh buffers per batch: Keras' *_generator methods prefetch batches
        # into a queue, so reusing one array would overwrite queued batches.
        samples = np.zeros((batch_size, lookback, shape))
        targets = np.zeros((batch_size,))
        for i in range(batch_size):
            name = random.choices(names, cum_weights=cum_weights)[0]
            sample, target = get_samples_targets(ts_code=name, start_date=start_date, end_date=end_date,
                                                 lookback=lookback, delay=delay, uprate=uprate, rand=True, mod='')
            samples[i] = sample[0]
            targets[i] = target[0]
        yield samples, targets

evaluate_model.py

import os
import random

import tushare as ts
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

import tensorflow as tf
import keras
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop
from keras import backend as K
import keras.backend.tensorflow_backend as KTF

from get_samples import get_samples_targets
from get_tools import *
from make_generators import make_generators
from new_generator import new_generator


# 衡量模型对一支股票的准确率
def evaluate_old(model, ts_code='600004.SH'):
    """Legacy: print ``evaluate_generator`` results for one stock via ``make_generators``.

    NOTE(review): make_generators is not among the listed project files; it is
    presumably a legacy module.
    """
    print(ts_code)
    generators = make_generators(ts_code, shuffle=False, batch_size='auto')
    if generators is None:
        return
    print(model.evaluate_generator(generators[0], steps=1))


# 根据正输出的价格差衡量模型(结果为按照模型交易每个交易日平均价格变化,只适用于delay=1)
def evaluate_delta(model, ts_code='600004.SH', start_date='', end_date='', lookback=61, delay=1, base_line=0.5, verbose=1):
    """Average price change on the days the model signals, for one stock.

    The signal is ``np.round(score - base_line + 0.5)``, i.e. roughly
    "score above base_line". For a base_line below 0.5 (other than 0.0), the
    signal is inverted ("score below base_line", an expected drop). Only
    meaningful for delay=1.

    Returns:
        ``sum(signal * delta) / sum(signal)``; NaN (with a numpy warning) when
        no day is signalled, or None when no samples can be built.
    """
    batch = get_samples_targets(ts_code=ts_code, start_date=start_date, end_date=end_date,
                                lookback=lookback, delay=delay, mod='delta', verbose=verbose)
    if batch is None:
        return
    samples, deltas = batch
    scores = model.predict(samples).T[0]
    signal = np.round(scores - base_line + 0.5)
    if base_line < 0.5 and base_line != 0.0:
        signal = 1 - signal
    return sum(signal * deltas) / sum(signal)


# 衡量模型对一支股票的准确率
def evaluate(model, ts_code='600004.SH', start_date='', end_date='', lookback=61, delay=1):
    """Print *ts_code* and return ``model.evaluate`` on its binary up/down targets.

    Returns None when no samples can be built for the stock/date range.
    """
    print(ts_code)
    batch = get_samples_targets(ts_code=ts_code, start_date=start_date, end_date=end_date,
                                lookback=lookback, delay=delay)
    if batch is None:
        return
    samples, targets = batch
    return model.evaluate(samples, targets, batch_size=9999, verbose=0)


# 衡量模型对所有股票的准确率
def evaluate_total(model, market='ALL', steps=10, shape=5, start_date='', end_date='', lookback=61, delay=1, uprate=0.0):
    """Evaluate *model* on random samples drawn across all stocks of *market*.

    Each step draws one batch whose size equals the number of stocks on
    *market*.

    Returns:
        The ``evaluate_generator`` result (loss followed by the compiled
        metrics), or None when no stock has data in the date range.
    """
    generator = new_generator(market=market,
                              shape=shape,
                              start_date=start_date,
                              end_date=end_date,
                              lookback=lookback,
                              delay=delay,
                              uprate=uprate,
                              batch_size=len(get_code_list(market)))
    # Probe once. When there is no data, new_generator returns immediately, so
    # a bare next() raised StopIteration instead of reaching `return None`.
    if next(generator, None) is None:
        return
    return model.evaluate_generator(generator, steps=steps)


# 批量衡量模型对每支股票的准确率
def evaluate_all(model, market='SSE', start_date='', end_date='', lookback=61, delay=1):
    """Print the ``evaluate`` result of every stock on *market*, one stock at a time."""
    for code_name in list(get_code_list(market=market)):
        print(evaluate(model=model, ts_code=code_name, start_date=start_date, end_date=end_date,
                       lookback=lookback, delay=delay))


# 批量衡量模型对每支股票的delta
def evaluate_all_delta(model, market='SSE', start_date='', end_date='', lookback=61, delay=1, base_line=0.5, verbose=0):
    """Run ``evaluate_delta`` on every stock of *market* and print the mean.

    Returns:
        List of per-stock average price changes. Stocks without samples (None)
        or without any signalled day (NaN) are excluded.
    """
    sum_list = []
    for code_name in get_code_list(market=market)[:]:
        print_verbose(verbose, code_name)
        result = evaluate_delta(model=model, ts_code=code_name, start_date=start_date, end_date=end_date,
                                lookback=lookback, delay=delay, base_line=base_line, verbose=verbose)
        print_verbose(verbose, result)
        # `if result` silently dropped legitimate 0.0 averages and biased the mean.
        if result is not None and not np.isnan(result):
            sum_list.append(result)
    print("平均:", np.average(sum_list))
    return sum_list


# 按时间衡量模型准确度
def evaluate_total_time(model, date_step=61, steps=3, start_date='20170103', end_date='', lookback=61, delay=1, uprate=0.0):
    """Evaluate *model* on consecutive windows of ``date_step`` trading days and plot the metrics.

    Windows are taken from ``date_list`` between *start_date* and *end_date*
    ('' means the first / last date). Assuming the model was compiled by
    ``load_model``, result index 1/3/5 is recall and 2/4/6 precision at
    0.5/0.7/0.9, 13 is trate and 14 is prate.

    Returns:
        ``(dates, results)`` for the windows that produced a result, or None
        when a boundary date is not in ``date_list``.
    """
    def position(date, default):
        # '' selects the default; otherwise the date must exist in date_list.
        if date == '':
            return default
        key = int(date)
        return date_list.index(key) if key in date_list else None

    start = position(start_date, 0)
    if start is None:
        print('can not find date')
        return
    end = position(end_date, len(date_list))
    if end is None:
        print('can not find date')
        return

    dates, results = [], []
    for i in range(start, end, date_step):
        j = i + date_step
        if j >= end:
            continue
        date = '%s : %s' % (date_list[i], date_list[j])
        print(date)
        result = evaluate_total(model, market='ALL', steps=steps, start_date=date_list[i],
                                end_date=date_list[j], lookback=lookback, delay=delay, uprate=uprate)
        if result:
            dates.append(date)
            results.append(result)
            print(result)

    # (result index, legend label, colour). Indices 7-12 hold the
    # negative-class metrics if those curves are ever needed.
    curves = [
        (2, 'acc5', 'green'),
        (4, 'acc7', 'blue'),
        (6, 'acc9', 'red'),
        (1, 'rec5', 'lightgreen'),
        (3, 'rec7', 'lightblue'),
        (5, 'rec9', 'pink'),
        (13, 'Trate', 'black'),
        (14, 'Prate', 'brown'),
    ]
    for index, label, colour in curves:
        plt.plot([row[index] for row in results], label=label, c=colour)
    plt.legend()
    plt.show()
    return dates, results

大部分衡量方法只对进行预测是否上涨(即mod=’’)的模型有效

衡量对所有股票的准确率
evaluate_total()

In [1]: evaluate_total(model)
Out[1]:
[0.6048318326473237,	# loss
 0.61862553358078,		# 0.5 recall
 0.6655296385288239,	# 0.5 precision
 0.26063627749681473,	# 0.7 recall
 0.8474650800228118,	# 0.7 precision
 0.08491439446806907,	# 0.9 recall
 0.9620700001716613,	# 0.9 precision
 0.6881303012371063,	# 负结果 0.5 recall
 0.6426444888114929,	# 负结果 0.5 precision
 0.2039469599723816,	# 负结果 0.3 recall
 0.8001361966133118,	# 负结果 0.3 precision
 0.03344998843967915,	# 负结果 0.1 recall
 0.945061206817627,		# 负结果 0.1 precision
 0.5008360147476196,	# 实际为正的比例 (trate)
 0.4655305534601212]	# 预测得正的比例 (prate)

对每支股票按照模型的输出交易,衡量平均每个交易日的价格变化
evaluate_all_delta()

In [11]: sum_list = evaluate_all_delta(model, market='ALL', start_date=20191108, end_date='', base_line=0.9, delay=1)
平均: 0.01869930191972075

按时间衡量模型准确度
evaluate_total_time()
date_step:每个时间段持续时间
steps:计算次数,次数越多结果越准确

In [13]: evaluate_total_time(model, date_step=61, steps=3, start_date=20170103, end_date='', lookback=61, delay=1, uprate=0.0)
20170103 : 20170407
[0.5714327494303385, 0.5911784172058105, 0.7350501616795858, 0.29650260011355084, 0.8807578881581625, 0.09752903630336125, 0.9589986602465311, 0.7945913275082906, 0.6683776577313741, 0.3276803294817607, 0.8117450277010599, 0.03380454579989115, 0.9170262813568115, 0.49095123012860614, 0.3948471049467723]
20170407 : 20170706
[0.5693944891293844, 0.5984461506207784, 0.7022876739501953, 0.29884132742881775, 0.861210823059082, 0.11071240405241649, 0.9637072682380676, 0.7712886929512024, 0.6806734005610148, 0.34076804916063946, 0.8298511107762655, 0.0571845310429732, 0.940330425898234, 0.47428010900815326, 0.4042078951994578]
20170706 : 20170929
[0.5798394282658895, 0.5932405988375345, 0.7093638777732849, 0.2722257872422536, 0.8895139296849569, 0.08448692659536998, 0.9772547682126363, 0.7707486947377523, 0.6676571170488993, 0.27164021134376526, 0.8189897139867147, 0.03643824098010858, 0.9544040362040201, 0.48542391260464984, 0.40599090854326886]
20170929 : 20180102
[0.5729995369911194, 0.6039112210273743, 0.7030234138170878, 0.27464069922765094, 0.8634665409723917, 0.07280133416255315, 0.9295691649119059, 0.7785913745562235, 0.6936554312705994, 0.29994242389996845, 0.8393307328224182, 0.0484745018184185, 0.9419203003247579, 0.46465187271436054, 0.3991263310114543]
20180102 : 20180404
[0.734302838643392, 0.4271387755870819, 0.5242762168248495, 0.09093193213144939, 0.5486705501874288, 0.005700886559983094, 0.6018797159194946, 0.6124776403109232, 0.5167023837566376, 0.10401579737663269, 0.5127503971258799, 0.0012460101085404556, 0.3690476218859355, 0.500044584274292, 0.40723901987075806]
20180404 : 20180705
[0.7454104423522949, 0.46056893467903137, 0.4625197152296702, 0.12217340618371964, 0.5129115482171377, 0.007561440734813611, 0.4466666678587596, 0.5656497875849406, 0.5636967420578003, 0.10787152250607808, 0.5407769481341044, 0.003071665569829444, 0.5476190447807312, 0.44806989034016925, 0.4461085796356201]
20180705 : 20181008
[0.7507510582605997, 0.45997997124989826, 0.4787188569704692, 0.12299211571613948, 0.5070395072301229, 0.009159183750549952, 0.44206081827481586, 0.5611742933591207, 0.5425703724225363, 0.10632317264874776, 0.5173460145791372, 0.0021679462709774575, 0.49470900495847064, 0.4669697781403859, 0.44869394103686017]
20181008 : 20190103
[0.7226652503013611, 0.49212220311164856, 0.5187315940856934, 0.12433483948310216, 0.5550123651822408, 0.008488856721669436, 0.5555555522441864, 0.5768180886904398, 0.5506282846132914, 0.10334843893845876, 0.5724086364110311, 0.0015320322709158063, 0.5301587382952372, 0.4810555378595988, 0.45636088649431866]
20190103 : 20190408
[0.7271912296613058, 0.4668257534503937, 0.5723404884338379, 0.07485490789016087, 0.5707581837972006, 0.0043184501118958, 0.6057692368825277, 0.560244639714559, 0.45459697643915814, 0.054048859824736915, 0.44462979833285016, 0.0008052353902409474, 0.8333332538604736, 0.5576357245445251, 0.4548453191916148]
20190408 : 20190708
[0.767724335193634, 0.5531678597132365, 0.47463998198509216, 0.16808140774567923, 0.47606175144513446, 0.01771405277152856, 0.48153934876124066, 0.4719281991322835, 0.550476054350535, 0.08642558256785075, 0.5370267828305563, 0.004146686522290111, 0.7166666587193807, 0.46304715673128766, 0.5397165020306905]
20190708 : 20191009
[0.7437284390131632, 0.42941097418467206, 0.4692676564057668, 0.08559017876784007, 0.45823829372723895, 0.00663522615407904, 0.5171670218308767, 0.5693001747131348, 0.5294030706087748, 0.09079131732384364, 0.5481707056363424, 0.0008423991190890471, 0.4833333343267441, 0.4700900415579478, 0.43015066782633465]

在这里插入图片描述

history_predict.py

import os
import random
import tushare as ts
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import keras
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop
from keras import backend as K
import keras.backend.tensorflow_backend as KTF

from evaluate_model import evaluate
from get_samples import get_samples, get_data, get_samples_targets


# 显示历史预测曲线
def history_predict(model, ts_code='600004.SH', date=20191128, delay=1, during=244, mod='simple'):
    """Plot the model's daily predictions for the *during* trading days ending on *date*.

    Prints the model's evaluation on the stock's whole history and the row of
    the train/validation split date (20180102) when present. The left axis
    shows either the close change ``close[t+delay] - close[t]`` (default
    'simple') or, for mod 'complex'/'c', the closes of day t and t+delay. The
    right axis (0..1) shows the prediction plus 0.5/0.7/0.9 reference lines.
    """
    df = get_data(ts_code)
    if df is None:
        print('can not find data')
        return
    print(evaluate(model, ts_code=ts_code))
    data = get_samples(ts_code=ts_code, date=date, duiring=during)
    if data is None:
        return
    result = model.predict(data)

    # Compare dates as strings so int (as read from CSV) and str arguments
    # both match; the str '20180102' never matched the int column before,
    # which raised IndexError.
    trade_dates = df['trade_date'].astype(str)
    split_rows = df.index[trade_dates == '20180102']
    if len(split_rows):
        print('数据分割日:', split_rows[0])
    today = df.index[trade_dates == str(date)][0] + 1
    x_range = range(today - during, today)

    plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese labels
    plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    if mod != 'complex' and mod != 'c':
        change = (df['close'].shift(-delay) - df['close'])[today - during:today]
        # Series.max skips the NaN that shift() leaves on the newest rows.
        axis_max = change.abs().max()
        plt.ylim(-axis_max, axis_max)
        ax1.plot(change, c='b', label='目标时间后涨跌幅')
        ax1.set_ylabel('目标时间涨跌幅')
    else:
        ax1.plot(df['close'][today - during:today], c='g', label='今天')
        ax1.plot(df['close'].shift(-delay)[today - during:today], c='b', label='目标时间')
        ax1.set_ylabel('走势')
    ax2 = ax1.twinx()
    plt.ylim(0, 1)
    for level in (0.5, 0.7, 0.9):
        ax2.plot(x_range, result * 0 + level, c='r')
    ax2.plot(x_range, result, c='y', label='预测值')
    ax2.set_ylabel('预测值')
    # One combined legend for both axes.
    handles1, labels1 = ax1.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    plt.legend(handles1 + handles2, labels1 + labels2, loc='upper right')
    plt.title(ts_code)
    plt.show()

In [87]: history_predict(model, ts_code='600004.SH', date=20191128, delay=1, during=244, mod='simple')
600004.SH
[0.5915805101394653, 0.623115599155426, 0.6919642686843872, 0.22914573550224304, 0.8923678994178772, 0.07638190686702728, 0.9870129823684692, 0.7153171896934509, 0.649040699005127, 0.14698298275470734, 0.8357771039009094, 0.024755029007792473, 0.9599999785423279, 0.5064902305603027, 0.45609569549560547]
数据分割日: 3967

在这里插入图片描述

serch_predict.py

import os
import random

import tushare as ts
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

import tensorflow as tf
import keras
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop
from keras import backend as K
import keras.backend.tensorflow_backend as KTF

from get_samples import get_samples
from get_tools import *


# 搜索预测值高的股票
def search_predict(model, date=20191128, market='SSE', duiring=1, baseline=0.9, verbose=1):
    """Find stocks whose prediction exceeds *baseline* on the days up to *date*.

    For every stock of *market*, the *duiring* trading days ending on *date*
    are predicted, and stocks with any score above *baseline* are collected.

    Returns:
        ``(result_code, result_pred, rate_pred)``: matching ts_codes, their
        prediction arrays, and the per-day fraction of evaluated stocks whose
        rounded prediction is 1. rate_pred is None when no stock could be
        evaluated (e.g. *date* is not a trading day in the loaded data).
    """
    sum_pred = None
    sum_count = 0
    result_code = []
    result_pred = []
    for code_name in get_code_list(market=market)[:]:
        print_verbose(verbose, code_name)
        samples = get_samples(ts_code=code_name, date=date, duiring=duiring, verbose=verbose)
        if samples is None:
            continue
        pred = model.predict(samples)
        rounded = np.round(pred)
        sum_pred = rounded if sum_pred is None else sum_pred + rounded
        sum_count += 1
        # np.any also handles predictions with more than one output column.
        if np.any(pred > baseline):
            result_code.append(code_name)
            result_pred.append(pred)
            print_verbose(verbose, '%s*****************************************' % code_name)
            print_verbose(verbose, pred)
    if sum_count == 0:
        # Previously `None / 0` raised TypeError here.
        print('no stock could be evaluated for date %s' % date)
        rate_pred = None
    else:
        rate_pred = sum_pred / sum_count
    print('rate_pred:\n%s' % rate_pred)
    for code, pred in zip(result_code, result_pred):
        print('%s*****************************************\n%s' % (code, pred))
    return result_code, result_pred, rate_pred

搜索在 date 及之前共 duiring 个交易日内,预测值超过 baseline 的股票.

In [7]: results=search_predict(model, date=20191108, market='SSE', duiring=1, baseline=0.9, verbose=0)
rate_pred:
[[0.36472148]]
600127.SH*****************************************
[[0.9194063]]
603018.SH*****************************************
[[0.9124184]]

run.py

import os
import shutil
import time

# Reset the ../data/ cache so fresh data is fetched on this run. The wipe only
# happens when a previous run left code_list_SSE.csv behind.
# NOTE(review): this stays above the project imports below, presumably because
# those modules read ../data/ at import time; confirm before reordering.
data_dir = '../data/'
if os.path.exists(data_dir + "code_list_SSE.csv"):
    shutil.rmtree(data_dir)
    os.mkdir(data_dir)

from serch_predict import *
from evaluate_model import *
from history_predict import *

# Load the model, letting TensorFlow grow GPU memory on demand.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tf.keras.backend.set_session(tf.Session(config=config))
model = load_model(model_name='./model/binary/ATT140to740.model')
# model = load_model(model_name='./model/1y1d/cudnnGRU/cudnnGRU210to340conv.model')

# Run the search for today's date across all markets.
# NOTE(review): `today` is a 'YYYYMMDD' string while search_predict's default
# date is an int; confirm get_samples accepts both forms.
today = time.strftime("%Y%m%d")
print('today : %s' % today)
result_code, result_pred, rate_pred = search_predict(model, date=today, duiring=1, market='ALL', verbose=1, baseline=0.9)

# Save the results. `with` guarantees the file is closed even if a write fails.
with open("result.txt", "w", encoding="utf-8") as f:
    f.write('rate_pred:\n%s\n' % rate_pred)
    for code_name, pred in zip(result_code, result_pred):
        f.write('%s*****************************************\n%s\n' % (code_name, pred))

run.bat

REM Run today's prediction via run.py. The -i flag keeps the Python interpreter
REM open after the script finishes, so the loaded model and results stay inspectable.
python -i run.py

运行 run.bat 会清空 ../data/ 中的缓存数据(仅当其中已存在 code_list_SSE.csv 时才清空)并重新加载最新的数据, 以今天为基准预测目标时间(今天 + delay)的结果, 并将预测值大于 0.9 的股票保存到 result.txt 文件中.

dotrain.py

import os
import random
import time

import tushare as ts
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from keras.layers import *

import tensorflow as tf
import keras
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop
from keras import backend as K
import keras.backend.tensorflow_backend as KTF

from get_tools import *
from new_generator import new_generator

# Allocate GPU memory on demand instead of reserving the whole card at start-up.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tf.keras.backend.set_session(tf.Session(config=config))

# Data generators. Training data ends at train_val_date; validation data runs
# from train_val_date to val_test_date (later dates are used as the test set).
batch_size = 1024
shape = 5          # features per time step (model input is (lookback, shape))
lookback = 61      # time steps per sample; a year is about 244 trading days
delay = 1          # prediction horizon passed to new_generator (1 -> next day)
uprate = 0.0       # passed through to new_generator; see its definition
train_val_date = 20180102
val_test_date = 20191108

shared_gen_args = dict(market='ALL', batch_size=batch_size, shape=shape,
                       lookback=lookback, delay=delay, uprate=uprate)
generator = new_generator(start_date='', end_date=train_val_date, **shared_gen_args)
val_generator = new_generator(start_date=train_val_date, end_date=val_test_date,
                              **shared_gen_args)

# Model definition.
# The CNN, GRU and ResNet variants below are kept commented out for reference;
# the active model is the attention variant that follows. To try one of them,
# uncomment it and comment out the attention block.
# *************************************** CNN ***********************************
# model = Sequential()
# kernel_size = 4
# dropout_rate = 0.3
# model.add(layers.Conv1D(8, kernel_size=kernel_size, strides=2, padding='same',
#                         input_shape=(lookback, shape)))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(16, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(32, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(64, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(128, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(256, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Conv1D(512, kernel_size=kernel_size, strides=2, padding='same'))
# model.add(layers.BatchNormalization())
# model.add(layers.LeakyReLU())
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Flatten())
# model.add(layers.Dense(1, activation='sigmoid'))
# model.compile(optimizer=keras.optimizers.Adam(),  # lr=1e-4, epsilon=1e-8, decay=1e-4),
#               loss=keras.losses.binary_crossentropy,
#               metrics=[recall, precision, recall2, precision2, trate, prate])
# ************************************* G R U ******************************************
# dropout_rate = 0.5
# model = Sequential()
# # model.add(layers.BatchNormalization())
# model.add(layers.GRU(256,
#                      dropout=0.1,
#                      recurrent_dropout=0.5,
#                      input_shape=(None, shape)))
# model.add(layers.Dense(64, activation='relu'))
# model.add(layers.Dropout(dropout_rate))
# model.add(layers.Dense(1, activation='sigmoid'))
# model.compile(optimizer=keras.optimizers.RMSprop(1e-4),
#               loss=keras.losses.binary_crossentropy,
#               metrics=[recall, precision, recall2, precision2, trate, prate])
# *************************************** ResNet ***************************************
# def ResBlock(x, num_filters, resampling=None, kernel_size=3):
#     def BatchActivation(x):
#         x = BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
#         x = Activation('relu')(x)
#         return x
#
#     def Conv(x, resampling=resampling):
#         weight_decay = 1e-4
#         if resampling is None:
#             x = Conv1D(num_filters, kernel_size=kernel_size, padding='same',
#                        kernel_initializer="he_normal",
#                        kernel_regularizer=regularizers.l2(weight_decay))(x)
#         elif resampling == 'up':
#             # NOTE(review): x here is a 3-D tensor (it comes from Conv1D layers
#             # built on Input(shape=(lookback, shape))), but UpSampling2D expects
#             # 4-D input; UpSampling1D is presumably intended. Confirm before
#             # using resampling='up' (only 'down' is used below).
#             x = UpSampling2D()(x)
#             x = Conv1D(num_filters, kernel_size=kernel_size, padding='same',
#                        kernel_initializer="he_normal",
#                        kernel_regularizer=regularizers.l2(weight_decay))(x)
#         elif resampling == 'down':
#             x = Conv1D(num_filters, kernel_size=kernel_size, strides=2, padding='same',
#                        kernel_initializer="he_normal",
#                        kernel_regularizer=regularizers.l2(weight_decay))(x)
#         return x
#
#     a = BatchActivation(x)
#     y = Conv(a, resampling=resampling)
#     y = BatchActivation(y)
#     y = Conv(y, resampling=None)
#     if resampling is not None:
#         x = Conv(a, resampling=resampling)
#     return add([y, x])
#
#
# num_layers = int(np.log2(lookback)) - 3
# max_num_channels = lookback * 8
# weight_decay = 1e-4
#
# x_in = Input(shape=(lookback, shape))
# x = x_in
# for i in range(num_layers + 1):
#     num_channels = max_num_channels // 2 ** (num_layers - i)
#     if i > 0:
#         x = ResBlock(x, num_channels, resampling='down')
#     else:
#         x = Conv1D(num_channels, kernel_size=3, strides=2, padding='same',
#                    kernel_initializer="he_normal",
#                    kernel_regularizer=regularizers.l2(weight_decay))(x)
# x = BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
# x = Activation('relu')(x)
# x = GlobalAveragePooling1D()(x)
# x = Dense(1, activation='sigmoid')(x)
# model = keras.Model(x_in, x)
# model.compile(optimizer=keras.optimizers.Adam(),  # lr=1e-4, epsilon=1e-8, decay=1e-4),
#               loss=keras.losses.binary_crossentropy,
#               metrics=[recall, precision, recall2, precision2, trate, prate])
# # model.summary()
# *********************************** Attention *********************************************
# Active model. A Conv1D/Dense branch produces one value per (time step, feature)
# of the input. That map is multiplied element-wise with the raw input (a simple
# attention-style gating) before a GRU summarizes the sequence.
dropout_rate = 0.3
x_in = Input(shape=(lookback, shape))
x = x_in
# x = BatchNormalization()(x)
c = Conv1D(32, 5, activation='relu')(x)
# NOTE(review): the Conv1D above already applies ReLU, so this LeakyReLU only
# sees non-negative values and acts as the identity. It is left in place so the
# layer stack is unchanged.
c = LeakyReLU()(c)
c = Dropout(dropout_rate)(c)
c = Flatten()(c)
c = Dense(lookback * shape)(c)
c = LeakyReLU()(c)
# Turn the flat (lookback * shape) vector back into (lookback, shape) so it lines
# up with x. This used to be
#     Lambda(lambda k: K.reshape(k, (-1, lookback, shape)))
# which performs the same reshape. However, a Lambda layer is saved as Python
# bytecode that looks up the script globals `lookback` / `shape`, so reloading
# the saved model outside this script can fail. Reshape stores the target shape
# in its own config instead.
c = Reshape((lookback, shape))(c)
m = multiply([x, c])
r = GRU(256)(m)
r = LeakyReLU()(r)
r = Dropout(dropout_rate)(r)
d = Dense(256)(r)
d = LeakyReLU()(d)
d = Dropout(dropout_rate)(d)
# Output: a single sigmoid unit trained with binary cross-entropy.
res = Dense(1, activation='sigmoid')(d)
model = keras.Model(inputs=x_in, outputs=res)
model.compile(optimizer=keras.optimizers.Adam(lr=1e-4),  # lr=1e-4, epsilon=1e-8, decay=1e-4),
              loss=keras.losses.binary_crossentropy,
              metrics=[recall, precision, recall2, precision2, trate, prate]
              )
# Alternative starting point (commented out): load a previously saved model/weights.
# model = load_model('./model/ATTSMALL480bad.model')
# model.load_weights('./model/ATTSMALL480bad.weight')

# Callbacks:
# - keep the checkpoint with the lowest validation loss seen so far;
# - halve the learning rate (down to 1e-8) when the training loss has not
#   improved for 60 epochs.
checkpoint = keras.callbacks.ModelCheckpoint(
    './model/auto_save_best.model',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)
learning_rate_reduction = keras.callbacks.ReduceLROnPlateau(
    monitor='loss',
    factor=0.5,
    patience=60,
    min_lr=1e-8,
    verbose=1,
)
callbacks_list = [checkpoint, learning_rate_reduction]

# Train for 180 epochs of 200 generator batches each (about 1 min/epoch per the
# original note), validating on 10 batches after every epoch.
history = model.fit_generator(
    generator,
    steps_per_epoch=200,
    epochs=180,
    validation_data=val_generator,
    validation_steps=10,
    callbacks=callbacks_list,
    # class_weight=class_weight,
    verbose=1,
)

# Persist the final model (architecture + weights) and, separately, the weights.
model.save('./model/auto_save.model')
model.save_weights('./model/auto_save.weight')
# show_train_history(history, 'loss', 'val_loss')
# plt.savefig('./model/auto_save.jpg')
# plt.show()
plot_history(history)

训练

(以 lookback=61、delay=1 进行训练, 预测第二天是否上涨)

CNN

在这里插入图片描述

GRU

在这里插入图片描述

ResNet

在这里插入图片描述

混合Attention

在这里插入图片描述

效果

(使用 ATT140to740.model 进行测试)

衡量总体准确率

全部数据

In [19]: evaluate_total(model, market='ALL', steps=10, shape=5, start_date='', end_date='', lookback=61, delay=1, uprate=0.0)
Out[19]:
[0.6090528190135955,
 0.6116366863250733,
 0.6690988123416901, # 取0.5时准确率
 0.256537552177906,
 0.8376343965530395,
 0.08675041273236275,
 0.9569046974182129, # 取0.9时准确率
 0.689475154876709,
 0.6336585283279419,
 0.20669595450162886,
 0.7789588630199432,
 0.03284015003591776,
 0.9456658005714417,
 0.50647232234478,
 0.4630115032196045]

交叉验证集数据

In [22]: evaluate_total(model, market='ALL', steps=10, shape=5, start_date=20180102, end_date=20191108, lookback=61, delay=1, uprate=0.0)
Out[22]:
[0.7405495762825012,
 0.46770868003368377,
 0.5000730127096176, # 取0.5时准确率
 0.1060352623462677,
 0.5055912852287292,
 0.009883439354598521,
 0.5355186939239502, # 取0.9时准确率
 0.5606197416782379,
 0.5284954369068146,
 0.09424014165997505,
 0.5357251167297363,
 0.0018147096503525971,
 0.6095238149166107,
 0.4844076007604599,
 0.4531158059835434]

测试集数据

In [23]: evaluate_total(model, market='ALL', steps=10, shape=5, start_date=20191108, end_date="", lookback=61, delay=1, uprate=0.0)
Out[23]:
[0.7380412459373474,
 0.4238017022609711,
 0.47870981097221377, # 取0.5时准确率
 0.09282655492424965,
 0.5137168973684311,
 0.006658474449068308,
 0.5203565269708633, # 取0.9时准确率
 0.5914961218833923,
 0.5369187474250794,
 0.09529061019420623,
 0.5270131379365921,
 0.0015642174810636788,
 0.6009523928165436,
 0.46967103481292727,
 0.41572613418102267]

衡量模型对每只股票的准确率

全部数据

In [26]: evaluate_all(model, market='ALL', start_date='', end_date='', lookback=61, delay=1)
# code
# [loss, 0.5 recall, 0.5 precision, 0.7 recall, 0.7 precision, 0.9 recall, 0.9 precision, 0.5 neg-recall, 0.5 neg-precision, 0.3 neg-recall, 0.3 neg-precision, 0.1 neg-recall, 0.1 neg-precision, Prate, Trate]
600000.SH
[0.6263973116874695, 0.5945805311203003, 0.6121244430541992, 0.15789473056793213, 0.8885630369186401, 0.029702970758080482, 1.0, 0.64207923412323, 0.6250602602958679, 0.10198019444942474, 0.8995633125305176, 0.0089108906686306, 1.0, 0.4871794879436493, 0.47321656346321106]
600004.SH
[0.5915805101394653, 0.623115599155426, 0.6919642686843872, 0.22914573550224304, 0.8923678994178772, 0.07638190686702728, 0.9870129823684692, 0.7153171896934509, 0.649040699005127, 0.14698298275470734, 0.8357771039009094, 0.024755029007792473, 0.9599999785423279, 0.5064902305603027, 0.45609569549560547]
600006.SH
[0.5564819574356079, 0.6736953258514404, 0.7041322588920593, 0.3494991958141327, 0.8655352592468262, 0.13020558655261993, 0.9610894918441772, 0.7370225191116333, 0.7085687518119812, 0.30313417315483093, 0.8398914337158203, 0.05239960923790932, 0.9553571343421936, 0.4815943241119385, 0.46077683568000793]
600007.SH
[0.5916787385940552, 0.6156201958656311, 0.6794366240501404, 0.26033690571784973, 0.886956512928009, 0.09596733003854752, 0.9740932583808899, 0.7126262784004211, 0.652033269405365, 0.2050504982471466, 0.8319672346115112, 0.03181818127632141, 0.9692307710647583, 0.497334361076355, 0.45062199234962463]
600008.SH
[0.5691049098968506, 0.6397739052772522, 0.6939799189567566, 0.3288797438144684, 0.8839778900146484, 0.1320657730102539, 0.9661654233932495, 0.7245358824729919, 0.6731934547424316, 0.2679377794265747, 0.8436018824577332, 0.05017561465501785, 0.9174311757087708, 0.4940340220928192, 0.4554455578327179]
600009.SH
[0.6188881397247314, 0.5797174572944641, 0.6558219194412231, 0.16195762157440186, 0.9093484282493591, 0.03279515728354454, 1.0, 0.6918753385543823, 0.6191129684448242, 0.11854879558086395, 0.8787878751754761, 0.01532958634197712, 0.9375, 0.5031734108924866, 0.44478294253349304]
600010.SH
[0.5921441912651062, 0.6390403509140015, 0.647871732711792, 0.2764449417591095, 0.8311475515365601, 0.1035986915230751, 0.949999988079071, 0.697387158870697, 0.6892018914222717, 0.227078378200531, 0.8400703072547913, 0.0370546318590641, 0.9397590160369873, 0.4656004011631012, 0.45925360918045044]
600011.SH
[0.6160035729408264, 0.5720338821411133, 0.638675332069397, 0.19968220591545105, 0.8285714387893677, 0.05932203307747841, 0.991150438785553, 0.7020965218544006, 0.6405693888664246, 0.19697707891464233, 0.800000011920929, 0.02096538245677948, 0.9555555582046509, 0.47930946946144104, 0.4292967617511749]
600012.SH
[0.5865558981895447, 0.6079456210136414, 0.6793224215507507, 0.26346054673194885, 0.8704662919044495, 0.07736539095640182, 0.9801324605941772, 0.7290226817131042, 0.6632240414619446, 0.2221125364303589, 0.855513334274292, 0.026159921661019325, 0.9636363387107849, 0.4856562614440918, 0.43462806940078735]
600015.SH
[0.6084867715835571, 0.6029331684112549, 0.6479859948158264, 0.19282998144626617, 0.8897243142127991, 0.04888647422194481, 0.9890109896659851, 0.6902927756309509, 0.6477108597755432, 0.15716487169265747, 0.8571428656578064, 0.02003081701695919, 0.9750000238418579, 0.4860084354877472, 0.4522175192832947]
600016.SH
[0.6412468552589417, 0.589085042476654, 0.5941716432571411, 0.12038522958755493, 0.8426966071128845, 0.024077046662569046, 0.9375, 0.6367149949073792, 0.6318312287330627, 0.08647342771291733, 0.8564593195915222, 0.010628019459545612, 1.0, 0.4744859039783478, 0.47042396664619446]
600017.SH
[0.557805061340332, 0.6498655676841736, 0.6931899785995483, 0.36491936445236206, 0.8715890645980835, 0.13508065044879913, 0.9757281541824341, 0.7346559166908264, 0.6946072578430176, 0.3161810338497162, 0.8279221057891846, 0.06695598363876343, 0.9908257126808167, 0.47984522581100464, 0.44985488057136536]
600018.SH
[0.5923075079917908, 0.6194751262664795, 0.6476534008979797, 0.2790055274963379, 0.8541226387023926, 0.10428176820278168, 0.9679487347602844, 0.6970825791358948, 0.6708482503890991, 0.2178770899772644, 0.839712917804718, 0.03414028510451317, 0.9166666865348816, 0.47335731983184814, 0.45276233553886414]
600019.SH
[0.6044419407844543, 0.6025437116622925, 0.6553314328193665, 0.22363540530204773, 0.8737059831619263, 0.07101219147443771, 0.9781022071838379, 0.7085769772529602, 0.659709632396698, 0.18323586881160736, 0.8430493474006653, 0.025828460231423378, 0.9298245906829834, 0.47905558347702026, 0.44046711921691895]
600020.SH
[0.5614857077598572, 0.6450549364089966, 0.6988095045089722, 0.3434065878391266, 0.8680555820465088, 0.14725275337696075, 0.9605734944343567, 0.7481334209442139, 0.6993950605392456, 0.2862120568752289, 0.8468335866928101, 0.056246887892484665, 0.9576271176338196, 0.47531992197036743, 0.4387568533420563]
# ***********************************数据过多,省略部分***********************************
300770.SZ
[0.6937742233276367, 0.5438596606254578, 0.6458333134651184, 0.12280701845884323, 0.699999988079071, 0.0, 0.0, 0.574999988079071, 0.4693877696990967, 0.05000000074505806, 0.5, 0.0, 0.0, 0.5876288414001465, 0.49484536051750183]
300771.SZ
[0.7008504271507263, 0.5964912176132202, 0.6938775777816772, 0.2631579041481018, 0.7142857313156128, 0.0, 0.0, 0.6153846383094788, 0.5106382966041565, 0.12820513546466827, 0.4545454680919647, 0.0, 0.0, 0.59375, 0.5104166865348816]
300772.SZ
[0.8239073753356934, 0.5227272510528564, 0.46000000834465027, 0.15909090638160706, 0.4117647111415863, 0.0, 0.0, 0.4375, 0.5, 0.1666666716337204, 0.800000011920929, 0.0, 0.0, 0.47826087474823, 0.54347825050354]
300773.SZ
[0.8311894536018372, 0.45098039507865906, 0.5, 0.13725490868091583, 0.46666666865348816, 0.019607843831181526, 0.5, 0.4523809552192688, 0.40425533056259155, 0.0714285746216774, 0.3333333432674408, 0.0, 0.0, 0.5483871102333069, 0.49462366104125977]
300775.SZ
[0.8624421954154968, 0.4444444477558136, 0.4444444477558136, 0.1388888955116272, 0.3125, 0.0, 0.0, 0.523809552192688, 0.523809552192688, 0.095238097012043, 0.3636363744735718, 0.0, 0.0, 0.4615384638309479, 0.4615384638309479]
300776.SZ
[0.7996936440467834, 0.39534884691238403, 0.5, 0.23255814611911774, 0.5555555820465088, 0.023255813866853714, 0.9999998807907104, 0.5405405163764954, 0.43478259444236755, 0.1621621549129486, 0.6000000238418579, 0.0, 0.0, 0.5375000238418579, 0.42500001192092896]
300777.SZ
[0.7923781871795654, 0.574999988079071, 0.5476190447807312, 0.20000000298023224, 0.42105263471603394, 0.05000000074505806, 0.4000000059604645, 0.5365853905677795, 0.5641025900840759, 0.24390244483947754, 0.7142857313156128, 0.0, 0.0, 0.4938271641731262, 0.5185185074806213]
300778.SZ
[0.7725061774253845, 0.800000011920929, 0.5161290168762207, 0.2750000059604645, 0.47826087474823, 0.05000000074505806, 1.0, 0.3333333432674408, 0.6521739363670349, 0.02222222276031971, 0.3333333432674408, 0.0, 0.0, 0.47058823704719543, 0.729411780834198]
300779.SZ
[0.8568037152290344, 0.5348837375640869, 0.574999988079071, 0.1860465109348297, 0.380952388048172, 0.023255813866853714, 0.5, 0.5, 0.45945945382118225, 0.20588235557079315, 0.5, 0.0, 0.0, 0.5584415793418884, 0.5194805264472961]
300780.SZ
[0.6919370293617249, 0.7368420958518982, 0.6222222447395325, 0.34210526943206787, 0.8125, 0.0, 0.0, 0.46875, 0.6000000238418579, 0.03125, 0.25, 0.0, 0.0, 0.5428571701049805, 0.6428571343421936]
300781.SZ
[0.7137081623077393, 0.6216216087341309, 0.6388888955116272, 0.2432432472705841, 0.692307710647583, 0.0, 0.0, 0.5517241358757019, 0.5333333611488342, 0.13793103396892548, 0.6666666865348816, 0.0, 0.0, 0.560606062412262, 0.5454545617103577]
300782.SZ
[0.6372032165527344, 0.71875, 0.6764705777168274, 0.09375, 0.5, 0.03125, 0.9999998807907104, 0.5925925970077515, 0.6399999856948853, 0.1111111119389534, 1.0, 0.0, 0.0, 0.5423728823661804, 0.5762711763381958]
300783.SZ
[0.822144091129303, 0.4545454680919647, 0.5, 0.13636364042758942, 0.5, 0.0, 0.0, 0.4736842215061188, 0.4285714328289032, 0.10526315867900848, 0.3333333432674408, 0.0, 0.0, 0.5365853905677795, 0.4878048896789551]
300785.SZ
[0.824234664440155, 0.6190476417541504, 0.5, 0.1428571492433548, 0.27272728085517883, 0.095238097012043, 0.6666666865348816, 0.31578946113586426, 0.4285714328289032, 0.10526315867900848, 0.5, 0.0, 0.0, 0.5249999761581421, 0.6499999761581421]
300786.SZ
[0.888309121131897, 0.6428571343421936, 0.375, 0.4285714328289032, 0.4000000059604645, 0.0714285746216774, 0.3333333432674408, 0.25, 0.5, 0.05000000074505806, 0.9999998807907104, 0.0, 0.0, 0.4117647111415863, 0.7058823704719543]
300787.SZ
[0.8598592281341553, 0.699999988079071, 0.5384615659713745, 0.10000000149011612, 0.3333333432674408, 0.0, 0.0, 0.1428571492433548, 0.25, 0.1428571492433548, 0.9999998807907104, 0.0, 0.0, 0.5882353186607361, 0.7647058963775635]
300788.SZ
[0.7071133255958557, 0.6666666865348816, 0.6666666865348816, 0.1666666716337204, 0.6666666865348816, 0.0, 0.0, 0.6363636255264282, 0.6363636255264282, 0.1818181872367859, 0.800000011920929, 0.0, 0.0, 0.52173912525177, 0.52173912525177]
300789.SZ
[0.9431849718093872, 0.75, 0.5, 0.5, 0.5, 0.0, 0.0, 0.25, 0.5, 0.0, 0.0, 0.0, 0.0, 0.5, 0.75]
300790.SZ
data range too small, may be date too close to boundary.
None
300791.SZ
data range too small, may be date too close to boundary.
None

交叉验证集

In [24]: evaluate_all(model, market='ALL', start_date=20180102, end_date=20191108, lookback=61, delay=1)
# code
# [loss, 0.5 recall, 0.5 precision, 0.7 recall, 0.7 precision, 0.9 recall, 0.9 precision, 0.5 neg-recall, 0.5 neg-precision, 0.3 neg-recall, 0.3 neg-precision, 0.1 neg-recall, 0.1 neg-precision, Prate, Trate]
600000.SH
[0.693852424621582, 0.3611111044883728, 0.49367088079452515, 0.0, 0.0, 0.0, 0.0, 0.6566523313522339, 0.5257731676101685, 0.004291845485568047, 0.5, 0.0, 0.0, 0.48106902837753296, 0.3518930971622467]
600004.SH
[0.7271435856819153, 0.3504672944545746, 0.5, 0.04205607622861862, 0.44999998807907104, 0.004672897048294544, 0.9999998807907104, 0.6808510422706604, 0.5351170301437378, 0.059574469923973083, 0.5185185074806213, 0.0, 0.0, 0.47661471366882324, 0.3340757191181183]
600006.SH
[0.7446303963661194, 0.4333333373069763, 0.45728641748428345, 0.09047619253396988, 0.4523809552192688, 0.009523809887468815, 1.0, 0.5481171607971191, 0.5239999890327454, 0.12970711290836334, 0.6326530575752258, 0.0, 0.0, 0.46770602464675903, 0.44320711493492126]
600007.SH
[0.7179648876190186, 0.36057692766189575, 0.4491018056869507, 0.028846153989434242, 0.6000000238418579, 0.0, 0.0, 0.6182572841644287, 0.5283687710762024, 0.029045643284916878, 0.46666666865348816, 0.0, 0.0, 0.46325168013572693, 0.37193763256073]
600008.SH
[0.7209228873252869, 0.40594059228897095, 0.48235294222831726, 0.0891089141368866, 0.529411792755127, 0.004950494971126318, 0.5, 0.6437246799468994, 0.5698924660682678, 0.11740890890359879, 0.6041666865348816, 0.004048583097755909, 0.9999998807907104, 0.4498886466026306, 0.37861916422843933]
600009.SH
[0.7013201713562012, 0.36771300435066223, 0.5430463552474976, 0.0044843051582574844, 0.25, 0.0, 0.0, 0.6946902871131897, 0.5268456339836121, 0.08407079428434372, 0.6785714030265808, 0.0, 0.0, 0.4966592490673065, 0.33630290627479553]
600010.SH
[0.7444199323654175, 0.44171780347824097, 0.35820895433425903, 0.07975459843873978, 0.3611111044883728, 0.0061349691823124886, 0.5, 0.5342960357666016, 0.6192468404769897, 0.10469313710927963, 0.6304348111152649, 0.0, 0.0, 0.3704545497894287, 0.45681819319725037]
600011.SH
[0.7138883471488953, 0.4285714328289032, 0.5113636255264282, 0.04285714402794838, 0.40909090638160706, 0.0, 0.0, 0.6401673555374146, 0.5604395866394043, 0.10878661274909973, 0.604651153087616, 0.0041841003112494946, 0.9999998807907104, 0.46770602464675903, 0.39198216795921326]
600012.SH
[0.714976966381073, 0.4439024329185486, 0.4715026021003723, 0.07804878056049347, 0.5714285969734192, 0.0, 0.0, 0.5819672346115112, 0.5546875, 0.069672130048275, 0.6538461446762085, 0.004098360426723957, 0.9999998807907104, 0.4565701484680176, 0.42984411120414734]
600015.SH
[0.6911038160324097, 0.3448275923728943, 0.48275861144065857, 0.004926108289510012, 0.5, 0.0, 0.0, 0.6951219439506531, 0.5625, 0.016260161995887756, 0.6666666865348816, 0.0, 0.0, 0.4521158039569855, 0.3229398727416992]
600016.SH
[0.7003122568130493, 0.3400000035762787, 0.4197530746459961, 0.004999999888241291, 0.5, 0.0, 0.0, 0.6224899888038635, 0.5400696992874146, 0.00803212821483612, 0.6666666865348816, 0.0, 0.0, 0.4454343020915985, 0.3608017861843109]
600017.SH
[0.7326346635818481, 0.42487046122550964, 0.4205128252506256, 0.10362694412469864, 0.5555555820465088, 0.005181347019970417, 0.5, 0.55859375, 0.5629921555519104, 0.07421875, 0.5428571701049805, 0.0, 0.0, 0.42984411120414734, 0.43429845571517944]
600018.SH
[0.7232927083969116, 0.3849765360355377, 0.46857142448425293, 0.061032865196466446, 0.5416666865348816, 0.0, 0.0, 0.6059321761131287, 0.5218977928161621, 0.09322033822536469, 0.6285714507102966, 0.0, 0.0, 0.474387526512146, 0.3897550106048584]
# ***********************************数据过多,省略部分***********************************
300722.SZ
[0.7533032298088074, 0.49344977736473083, 0.5736040472984314, 0.14847160875797272, 0.5862069129943848, 0.0043668122962117195, 0.25, 0.5714285969734192, 0.4912280738353729, 0.19387754797935486, 0.5507246255874634, 0.0, 0.0, 0.5388235449790955, 0.46352940797805786]
300723.SZ
[0.7201671600341797, 0.49514561891555786, 0.5125628113746643, 0.16019417345523834, 0.5593220591545105, 0.019417475908994675, 0.5714285969734192, 0.5488371849060059, 0.5315315127372742, 0.13488371670246124, 0.707317054271698, 0.004651162773370743, 0.9999998807907104, 0.489311158657074, 0.4726840853691101]
300724.SZ
can not find date
None
300725.SZ
[0.7402744889259338, 0.5311004519462585, 0.5235849022865295, 0.1818181872367859, 0.5757575631141663, 0.019138755276799202, 0.6666666865348816, 0.521327018737793, 0.5288461446762085, 0.11848340928554535, 0.5813953280448914, 0.0, 0.0, 0.49761903285980225, 0.5047619342803955]
300726.SZ
[0.7371765375137329, 0.5115207433700562, 0.5388349294662476, 0.1751152127981186, 0.5846154093742371, 0.0138248847797513, 1.0, 0.5273631811141968, 0.5, 0.13930347561836243, 0.5283018946647644, 0.004975124262273312, 0.9999998807907104, 0.519138753414154, 0.4928229749202728]
300727.SZ
[0.7675619721412659, 0.44954127073287964, 0.5051546096801758, 0.1376146823167801, 0.5, 0.027522936463356018, 0.8571428656578064, 0.5102040767669678, 0.4545454680919647, 0.13265305757522583, 0.48148149251937866, 0.005102040711790323, 0.9999998807907104, 0.5265700221061707, 0.46859902143478394]
300729.SZ
[0.7517649531364441, 0.5071770548820496, 0.5299999713897705, 0.13875597715377808, 0.5686274766921997, 0.0, 0.0, 0.5323383212089539, 0.5095238089561462, 0.12437810748815536, 0.5208333134651184, 0.0, 0.0, 0.5097560882568359, 0.4878048896789551]
300730.SZ
[0.771763801574707, 0.43192487955093384, 0.5257142782211304, 0.1690140813589096, 0.5538461804389954, 0.0, 0.0, 0.5631579160690308, 0.46929824352264404, 0.17894737422466278, 0.48571428656578064, 0.0, 0.0, 0.5285359621047974, 0.43424317240715027]
300731.SZ
[0.8007940649986267, 0.4545454680919647, 0.43589743971824646, 0.14438502490520477, 0.3913043439388275, 0.010695187374949455, 0.4000000059604645, 0.45812806487083435, 0.4769230782985687, 0.09852216392755508, 0.4878048896789551, 0.004926108289510012, 0.9999998807907104, 0.47948718070983887, 0.5]
300732.SZ
[0.7830197811126709, 0.5056179761886597, 0.5027933120727539, 0.13483145833015442, 0.4444444477558136, 0.00561797758564353, 0.1666666716337204, 0.5082873106002808, 0.5111111402511597, 0.14364640414714813, 0.5531914830207825, 0.01104972418397665, 1.0, 0.4958217144012451, 0.49860724806785583]
300733.SZ
can not find date
None
300735.SZ
[0.7560451626777649, 0.5095238089561462, 0.5431472063064575, 0.13333334028720856, 0.5490196347236633, 0.014285714365541935, 0.75, 0.5, 0.4663212299346924, 0.05000000074505806, 0.3461538553237915, 0.0055555556900799274, 0.9999998807907104, 0.5384615659713745, 0.5051282048225403]
300736.SZ
can not find date
None
300737.SZ
can not find date
None
300738.SZ
can not find date
None

按时间衡量模型准确度

2003-2019,按年衡量

In [27]: evaluate_total_time(model, date_step=244, steps=3, start_date='', end_date='', lookback=61, delay=1, uprate=0.0)
20030114 : 20040203
[0.6568018794059753, 0.46427900592486065, 0.5843748847643534, 0.06550159056981404, 0.773417055606842, 0.0038116557989269495, 0.8251082301139832, 0.7334766785303751, 0.6291241844495138, 0.08584380894899368, 0.7580034534136454, 0.0007994713766189913, 0.6666666666666666, 0.44664348165194195, 0.3549077312151591]
20040203 : 20050127
[0.6508340040842692, 0.535581111907959, 0.5979280471801758, 0.07494993011156718, 0.7752963105837504, 0.004925861178586881, 0.9333333373069763, 0.6983536680539449, 0.6421788732210795, 0.09007209291060765, 0.7887819210688273, 0.00033121334854513407, 0.666666587193807, 0.4558259844779968, 0.40821966528892517]
20050127 : 20060214
[0.667553981145223, 0.5478649338086446, 0.5861708919207255, 0.047291661302248635, 0.8047292033831278, 0.0020023582813640437, 0.8666666746139526, 0.6283081372578939, 0.5911598006884257, 0.03305310135086378, 0.7518921494483948, 0.0, 0.0, 0.4900597333908081, 0.4581438899040222]
20060214 : 20070424
[0.6465277671813965, 0.6215876539548238, 0.6484155257542928, 0.1326626588900884, 0.874542772769928, 0.02713957242667675, 0.9722222288449606, 0.5738389492034912, 0.5453460415204366, 0.059145551174879074, 0.7547684907913208, 0.0022163967757175365, 0.8888888955116272, 0.5583489338556925, 0.535258968671163]
20070424 : 20080428
[0.5526228348414103, 0.7043639024098715, 0.7278452714284261, 0.36831281582514447, 0.8801937301953634, 0.1299465224146843, 0.9723483721415201, 0.7016325195630392, 0.6768803000450134, 0.2890618046124776, 0.8412723143895467, 0.05818237240115801, 0.9391853213310242, 0.5311580499013265, 0.5139520565668741]
20080428 : 20090429
[0.5299248894055685, 0.7118967771530151, 0.7440837621688843, 0.4226831793785095, 0.8908692598342896, 0.18500982224941254, 0.9596697290738424, 0.7404838601748148, 0.7079606850941976, 0.33868083357810974, 0.8527288834253947, 0.051117694626251854, 0.9455873966217041, 0.5146652460098267, 0.492377628882726]
20090429 : 20100517
[0.5744742155075073, 0.7015180389086405, 0.7103689312934875, 0.32112271587053937, 0.8679341475168864, 0.09785450746615727, 0.967859665552775, 0.6708837946256002, 0.6615198651949564, 0.21318497757116953, 0.8435181975364685, 0.025281783193349838, 0.9693877498308817, 0.5349915226300558, 0.5283944010734558]
20100517 : 20110524
[0.5779383579889933, 0.6811196804046631, 0.7198743422826132, 0.3169198234875997, 0.8725708723068237, 0.08273773143688838, 0.9551554322242737, 0.6999802788098654, 0.659738302230835, 0.22190992534160614, 0.8136094411214193, 0.020922282089789707, 0.93113245566686, 0.5309797724088033, 0.5024516383806864]
20110524 : 20120525
[0.563612699508667, 0.6636995673179626, 0.6954499880472819, 0.32321880261103314, 0.8615734577178955, 0.09794096151987712, 0.9653286337852478, 0.735595683256785, 0.7062193155288696, 0.2857717474301656, 0.8460407257080078, 0.040361875047286354, 0.9873376687367758, 0.47641971707344055, 0.45466703176498413]
20120525 : 20130530
[0.5708633462587992, 0.6553651293118795, 0.7170006036758423, 0.28112663825352985, 0.8910457094510397, 0.0848972921570142, 0.9739351073900858, 0.7358709971110026, 0.6764885187149048, 0.24324760834376016, 0.8477358420689901, 0.031711009020606674, 0.9562129179636637, 0.5052153070767721, 0.4617990553379059]
20130530 : 20140605
[0.5766666332880656, 0.6566015680631002, 0.6961189905802408, 0.29123929142951965, 0.8662963509559631, 0.08678068965673447, 0.9734111825625101, 0.7191708286603292, 0.6812556187311808, 0.2289514938990275, 0.8459783395131429, 0.029651952907443047, 0.9576589266459147, 0.494873841603597, 0.46679147084554035]
20140605 : 20150603
[0.5726994872093201, 0.6853102246920267, 0.7366012533505758, 0.3206663131713867, 0.8904303908348083, 0.09374689559141795, 0.9706396659215292, 0.682244082291921, 0.6252208948135376, 0.20458133021990457, 0.8254867792129517, 0.0190351443986098, 0.9315588275591532, 0.5649460554122925, 0.5255416035652161]
20150603 : 20160719
[0.46518025795618695, 0.7420702576637268, 0.7916798988978068, 0.543129583199819, 0.9014248053232828, 0.33435707290967304, 0.9645818471908569, 0.7944092949231466, 0.7452153960863749, 0.4977227846781413, 0.8516929944356283, 0.1906640032927195, 0.9499153892199198, 0.5131496985753378, 0.48087724049886066]
20160719 : 20170720
[0.5732301076253256, 0.5929640928904215, 0.7088420589764913, 0.2925766507784526, 0.8704588413238525, 0.11173826456069946, 0.9588313500086466, 0.7756170431772867, 0.6740103562672933, 0.30973106622695923, 0.8179851770401001, 0.04629548266530037, 0.942859947681427, 0.47953999042510986, 0.40108763178189594]
20170720 : 20180719
[0.6642247041066488, 0.5296896497408549, 0.5801922480265299, 0.18797219296296439, 0.7028467853864034, 0.04001099616289139, 0.8669108748435974, 0.6634118358294169, 0.6162890593210856, 0.1820149372021357, 0.7172070145606995, 0.017918592939774197, 0.9570804635683695, 0.46759383877118427, 0.4268520971139272]
20180719 : 20190722
[0.7431689103444418, 0.48683969179789227, 0.5005723039309183, 0.11705733835697174, 0.507864753405253, 0.00803534360602498, 0.5076609253883362, 0.5368337829907736, 0.5231437285741171, 0.08796707292397817, 0.5264059901237488, 0.0013925166955838602, 0.600000003973643, 0.488187571366628, 0.4747258722782135]

在这里插入图片描述

2016-2019年,按月衡量

In [31]: evaluate_total_time(model, date_step=20, steps=3, start_date=20160104, end_date='', lookback=61, delay=1, uprate=0.0)
20160104 : 20160201
[0.3300749659538269, 0.8333790898323059, 0.8538292646408081, 0.7283818125724792, 0.9206656614939371, 0.5122989416122437, 0.9672505259513855, 0.879724125067393, 0.8622852166493734, 0.7320470015207926, 0.9174177447954813, 0.47305670380592346, 0.9752546151479086, 0.4574306805928548, 0.4464651842912038]
20160201 : 20160331
[0.4164064625898997, 0.7700598041216532, 0.8489522139231364, 0.6232134501139323, 0.9363580544789633, 0.46540990471839905, 0.9835724830627441, 0.833103617032369, 0.7484932343165079, 0.5455527305603027, 0.8428746660550436, 0.17206532756487528, 0.9406526684761047, 0.5491664409637451, 0.4981724222501119]
20160331 : 20160429
[0.5099559426307678, 0.680933674176534, 0.7686257163683573, 0.4594770272572835, 0.8953756093978882, 0.25383887191613513, 0.9734796682993571, 0.8032093644142151, 0.7237701018651327, 0.4142257372538249, 0.8234434127807617, 0.08621150255203247, 0.9528759717941284, 0.4898814260959625, 0.4338949918746948]
20160429 : 20160530
[0.5221199989318848, 0.6397658785184225, 0.736701230208079, 0.38599979877471924, 0.8892502983411154, 0.21077392001946768, 0.9650895595550537, 0.8060749967892965, 0.7252204418182373, 0.4048948486646016, 0.842205802599589, 0.13769917686780295, 0.9546858469645182, 0.45885709921518963, 0.39850228031476337]
20160530 : 20160629
[0.5209031701087952, 0.7052847544352213, 0.7694478432337443, 0.4261919856071472, 0.920309861501058, 0.23533829549948374, 0.9804604848225912, 0.7425735394159952, 0.6740126411120096, 0.2977428336938222, 0.8266976277033488, 0.08406675358613332, 0.9445225795110067, 0.5492556095123291, 0.5034323036670685]
20160629 : 20160727
[0.577679435412089, 0.6042988896369934, 0.7121192812919617, 0.2804385224978129, 0.8790800968805949, 0.11709364255269368, 0.9579868714014689, 0.7689986030260721, 0.6724971532821655, 0.25034789244333905, 0.8153106768925985, 0.033337254698077835, 0.9791463017463684, 0.48622627059618634, 0.41258804003397626]
20160727 : 20160824
[0.554732064406077, 0.6512430508931478, 0.7526447375615438, 0.3729069133599599, 0.8827725251515707, 0.18998457491397858, 0.9538133343060812, 0.7602296868960062, 0.6606735587120056, 0.33815370003382367, 0.7928603291511536, 0.04158638541897138, 0.9380980332692465, 0.5283052325248718, 0.4571632345517476]
20160824 : 20160923
[0.580666204293569, 0.6040971676508585, 0.6967231233914694, 0.2599690556526184, 0.8593732317288717, 0.07136169075965881, 0.9675402045249939, 0.7572202086448669, 0.6743767460187277, 0.30664053559303284, 0.8178765575091044, 0.03992802401383718, 0.978074312210083, 0.4800748825073242, 0.4162432054678599]
20160923 : 20161028
[0.5732341408729553, 0.586691419283549, 0.716315766175588, 0.2835877339045207, 0.8973076740900675, 0.11739646891752879, 0.9615386724472046, 0.7705871065457662, 0.6538620789845785, 0.27918680508931476, 0.8132069309552511, 0.0357852429151535, 0.9409313201904297, 0.4967460036277771, 0.4068824152151744]
20161028 : 20161125
[0.6074507435162863, 0.5327289700508118, 0.7111793756484985, 0.2069567491610845, 0.8885485728581747, 0.07067673156658809, 0.976579487323761, 0.7743290662765503, 0.6136269966761271, 0.21285533905029297, 0.7852751612663269, 0.010242866973082224, 0.9273664951324463, 0.5106534560521444, 0.382544348637263]
20161125 : 20161223
[0.5683644811312357, 0.5833013852437338, 0.6826119621594747, 0.27338143189748126, 0.864722470442454, 0.10352096954981486, 0.9643404086430868, 0.7853506604830424, 0.7042409181594849, 0.33368319272994995, 0.8386533657709757, 0.04345869769652685, 0.9276942610740662, 0.4418293635050456, 0.3775519331296285]
20161223 : 20170123
[0.5523642102877299, 0.6215600172678629, 0.7359422047932943, 0.34820979833602905, 0.8872754772504171, 0.11651075383027394, 0.9595034718513489, 0.797012209892273, 0.6985656420389811, 0.37634897232055664, 0.8233867685000101, 0.0510556697845459, 0.9010580778121948, 0.4760631223519643, 0.40224658449490863]
20170123 : 20170227
[0.5793325304985046, 0.6538892587025961, 0.7454086343447367, 0.34014496207237244, 0.8570442199707031, 0.10472188889980316, 0.9422970414161682, 0.7396511435508728, 0.6470246315002441, 0.2873672346274058, 0.7673331697781881, 0.011376170751949152, 0.9003527363141378, 0.5382009545962015, 0.47205134232838947]
20170227 : 20170327
[0.6039857069651285, 0.5411506295204163, 0.6933780908584595, 0.22313153743743896, 0.8510897159576416, 0.07072568933169048, 0.9354835549990336, 0.7844724853833517, 0.6548874179522196, 0.27066503961881, 0.783754567305247, 0.007463090121746063, 0.8140350977579752, 0.47392351428667706, 0.36988498767217]
20170327 : 20170426
[0.5763227144877116, 0.4942589004834493, 0.6415992180506388, 0.1934088667233785, 0.8342723250389099, 0.05644249667723974, 0.9770851532618204, 0.8163425525029501, 0.7080773909886678, 0.3732957144578298, 0.842408299446106, 0.07203817615906398, 0.9345154364903768, 0.39957208434740704, 0.3077471653620402]
20170426 : 20170525
[0.575218121210734, 0.5653707981109619, 0.6791580518086752, 0.25602561235427856, 0.8325321674346924, 0.06646337856849034, 0.9657052159309387, 0.7783886591593424, 0.6833484768867493, 0.35639803608258563, 0.8446292281150818, 0.06118106345335642, 0.9591080546379089, 0.45350806911786395, 0.3773736258347829]
20170525 : 20170626
[0.5323339104652405, 0.6749416589736938, 0.7707645098368326, 0.40354963143666583, 0.906677226225535, 0.19083259999752045, 0.973891019821167, 0.7609143455823263, 0.6626681685447693, 0.3228462835152944, 0.8150670131047567, 0.07761002580324809, 0.9612560669581095, 0.5436391234397888, 0.4759739637374878]
20170626 : 20170724
[0.5840664903322855, 0.6040328939755758, 0.6892185807228088, 0.2734345992406209, 0.8448772033055624, 0.07662152623136838, 0.9442191123962402, 0.7424192031224569, 0.664755642414093, 0.2839577893416087, 0.8331067562103271, 0.05305617799361547, 0.9661096334457397, 0.48604796330134076, 0.42596060037612915]
20170724 : 20170821
[0.5749452710151672, 0.6443217992782593, 0.7131819923718771, 0.3207562466462453, 0.8860552906990051, 0.10838656624158223, 0.9762597680091858, 0.7355836232503256, 0.6696333686510721, 0.24367683132489523, 0.81400199731191, 0.02395140565931797, 0.9423990249633789, 0.5050370097160339, 0.4562717278798421]
20170821 : 20170918
[0.614362915356954, 0.5368500749270121, 0.6896100242932638, 0.20729578534762064, 0.8752237359682719, 0.053821162631114326, 0.9794501264890035, 0.7633321285247803, 0.6270307302474976, 0.1895776391029358, 0.7742082476615906, 0.003888158050055305, 0.9666666587193807, 0.49505213896433514, 0.38539716601371765]
20170918 : 20171023
[0.5713058908780416, 0.5930356979370117, 0.7139697869618734, 0.2773873209953308, 0.8864696621894836, 0.07482695579528809, 0.9767130414644877, 0.7787680427233378, 0.6725957989692688, 0.296446959177653, 0.8383763829867045, 0.0320691696057717, 0.9491950869560242, 0.4822144905726115, 0.4004635810852051]
20171023 : 20171120
[0.5949939688046774, 0.5280914306640625, 0.658048411210378, 0.17660345137119293, 0.8440783818562826, 0.02302190288901329, 0.9464573264122009, 0.7902607520421346, 0.6865093111991882, 0.27049265305201214, 0.8357558449109396, 0.035086605697870255, 0.9732725222905477, 0.4334492286046346, 0.3477756977081299]
20171120 : 20171218
[0.5636487801869711, 0.6233387589454651, 0.6986027161280314, 0.31632526715596515, 0.8656089504559835, 0.11647027482589085, 0.9493562976519266, 0.7688710689544678, 0.7036671241124471, 0.32529234886169434, 0.8361475268999735, 0.06481341272592545, 0.9492471814155579, 0.46224479873975116, 0.4124097327391307]
20171218 : 20180116
[0.6507112582524618, 0.5340342919031779, 0.6045805811882019, 0.20625622073809305, 0.7602420051892599, 0.039490206787983574, 0.9187212189038595, 0.6878354748090109, 0.6228431661923727, 0.17598187426726022, 0.71885613600413, 0.014866196550428867, 0.8887022137641907, 0.47196219364802044, 0.4168672462304433]
20180116 : 20180213
[0.7278844118118286, 0.4599837561448415, 0.4602884848912557, 0.1162630170583725, 0.5063264071941376, 0.008183811946461598, 0.502400149901708, 0.5823865334192911, 0.5821098883946737, 0.09899737934271495, 0.5990180373191833, 0.002371248362275461, 0.7354497412840525, 0.4363911946614583, 0.4361237386862437]
20180213 : 20180320
[0.7609819769859314, 0.40759652853012085, 0.5561315218607584, 0.07812982300917308, 0.5676490664482117, 0.0039500615481908126, 0.6099715133508047, 0.5780378381411234, 0.429262638092041, 0.0909203365445137, 0.3707800308863322, 0.0004082465699563424, 0.07407407462596893, 0.5647677580515543, 0.413836141427358]
20180320 : 20180419
[0.7320980230967203, 0.4390726884206136, 0.5026789903640747, 0.1000448614358902, 0.5208313961823782, 0.006649315978089969, 0.5265359580516815, 0.5986337661743164, 0.5359570980072021, 0.12316566954056422, 0.5758434136708578, 0.002742952046295007, 0.5703703860441843, 0.4802531798680623, 0.41945261756579083]
20180419 : 20180521
[0.7409135103225708, 0.47827096780141193, 0.5364595651626587, 0.1279920091231664, 0.5895991921424866, 0.01170201258112987, 0.6014203230539957, 0.5556497176488241, 0.49759358167648315, 0.078870490193367, 0.44944079717000324, 0.0011092253068151574, 0.3682539810736974, 0.5182312726974487, 0.46188820401827496]
20180521 : 20180619
[0.7485045591990153, 0.42101940512657166, 0.3634924689928691, 0.10312127818663915, 0.3786735236644745, 0.008763244065145651, 0.4955555597941081, 0.5734093983968099, 0.6313724716504415, 0.11546564350525539, 0.6275175015131632, 0.0019732690804327526, 0.5111111303170522, 0.3663189808527629, 0.4246233403682709]
20180619 : 20180717
[0.7536139289538065, 0.4777560234069824, 0.5488387942314148, 0.14947044352690378, 0.5735464890797933, 0.015432522632181644, 0.6164737145105997, 0.5742554068565369, 0.5035801033178965, 0.14320466915766397, 0.48119376103083294, 0.003156732146938642, 0.42606837550799054, 0.5201034148534139, 0.452794869740804]
20180717 : 20180814
[0.7543952465057373, 0.4393500288327535, 0.4530588189760844, 0.10608174155155818, 0.46461119254430133, 0.006722171790897846, 0.4158549904823303, 0.559668223063151, 0.5459282199541727, 0.11492372552553813, 0.5481685002644857, 0.0016486222545305889, 0.4662698457638423, 0.45359720786412555, 0.4397789041201274]
20180814 : 20180911
[0.7457842429478964, 0.45029595494270325, 0.4246404270331065, 0.12239157408475876, 0.45915862917900085, 0.011601551125446955, 0.4523486892382304, 0.5808447599411011, 0.605999767780304, 0.12287912021080653, 0.5751838684082031, 0.002706772298552096, 0.6160714328289032, 0.4072390099366506, 0.43184452255566913]
20180911 : 20181017
[0.7386055986086527, 0.4584916631380717, 0.48098509510358173, 0.10596313327550888, 0.48386502265930176, 0.009336340085913738, 0.5197132527828217, 0.5827573935190836, 0.560651163260142, 0.1124221682548523, 0.5634302099545797, 0.0013082946922319632, 0.5714285870393118, 0.4575198292732239, 0.4361237386862437]
20181017 : 20181114
[0.7217960158983866, 0.48923853039741516, 0.6020886500676473, 0.12335498879353206, 0.6565253138542175, 0.0061720275940994425, 0.5736111203829447, 0.5822777350743612, 0.46872907876968384, 0.12075733641783397, 0.5177871783574423, 0.0010121881496161222, 0.31111112236976624, 0.5637871225674947, 0.45805474122365314]
20181114 : 20181212
[0.7150911688804626, 0.45989829301834106, 0.5372519493103027, 0.09378999720017116, 0.5626257260640463, 0.004192532428229849, 0.47306398550669354, 0.6187326510747274, 0.5434430440266927, 0.09837564080953598, 0.5888131856918335, 0.0007019117280530432, 0.6666666368643442, 0.4904163380463918, 0.41980921228726703]
20181212 : 20190111
[0.7351961731910706, 0.5297191540400187, 0.48315619428952533, 0.1464984118938446, 0.5075726807117462, 0.014215043745934963, 0.6187636852264404, 0.5218847791353861, 0.5680830478668213, 0.07329785575469334, 0.5738670229911804, 0.001151621089472125, 0.36666667461395264, 0.4576089878877004, 0.5017384390036265]
20190111 : 20190215
[0.7424782117207845, 0.46431917945543927, 0.49375678102175397, 0.08034547666708629, 0.47481729586919147, 0.005360288700709741, 0.5629458427429199, 0.5253990491231283, 0.4959198832511902, 0.05608691523472468, 0.46760066350301105, 0.0001793078651341299, 0.08333333333333333, 0.49924222628275555, 0.4694659908612569]
20190215 : 20190315
[0.7225977381070455, 0.4271313150723775, 0.645826001962026, 0.04964143534501394, 0.7091577053070068, 0.002311936734865109, 0.5704665879408518, 0.6215667525927225, 0.40172789494196576, 0.04497983058293661, 0.3651146988073985, 0.0002305209830713769, 0.3333332935969035, 0.6178122361501058, 0.40848712126413983]
20190315 : 20190415
[0.7330876191457113, 0.5333328445752462, 0.5149634480476379, 0.10322515666484833, 0.5073318779468536, 0.007539146890242894, 0.5527859230836233, 0.5027465720971426, 0.5211473703384399, 0.0569533904393514, 0.5033248861630758, 0.0021315155706057944, 0.614814817905426, 0.4974592129389445, 0.515200138092041]
20190415 : 20190516
[0.7772349715232849, 0.5582201282183329, 0.465119868516922, 0.18224313855171204, 0.48673463861147565, 0.02209085536499818, 0.5040304362773895, 0.4351457456747691, 0.528003474076589, 0.07193617274363835, 0.5164011915524801, 0.0032036421665300927, 0.48055557409922284, 0.4679504334926605, 0.5614691972732544]
20190516 : 20190614
[0.770625114440918, 0.49911148349444073, 0.4433234731356303, 0.15570829312006632, 0.4420177141825358, 0.020726947113871574, 0.46438642342885333, 0.5272402366002401, 0.5824871857961019, 0.11702437698841095, 0.5896152456601461, 0.0018686645586664479, 0.28733766575654346, 0.4299723605314891, 0.48399750391642254]
20190614 : 20190712
[0.7580109437306722, 0.5463989575703939, 0.5106613536675771, 0.16729296743869781, 0.5230797231197357, 0.018630903214216232, 0.4902886251608531, 0.4822925428549449, 0.5181306600570679, 0.07482960323492686, 0.4963911871115367, 0.0023025821428745985, 0.5833333432674408, 0.4971026082833608, 0.5320495764414469]
20190712 : 20190809
[0.7350301941235861, 0.4402405619621277, 0.4303315778573354, 0.09702148040135701, 0.4221001962820689, 0.007121649881203969, 0.4404761989911397, 0.5892868041992188, 0.5989782015482584, 0.11840027074019115, 0.6435391902923584, 0.002883524284698069, 0.6515151659647623, 0.4134795367717743, 0.42292948563893634]
20190809 : 20190906
[0.7477307319641113, 0.36333195368448895, 0.5445758104324341, 0.06072922423481941, 0.5755683382352194, 0.003100892761722207, 0.5476190646489462, 0.631607711315155, 0.4500224788983663, 0.10058060536781947, 0.44978277881940204, 0.0011707139977564414, 0.48888889451821643, 0.5480074882507324, 0.36560577154159546]
20190906 : 20191014
[0.7384669780731201, 0.47571444511413574, 0.5011145075162252, 0.10083309312661488, 0.5169278581937155, 0.00750604597851634, 0.49126983682314557, 0.5304775635401408, 0.5050997932751974, 0.05927569419145584, 0.4959230919679006, 0.0007218405759582917, 0.4166666666666667, 0.4978158175945282, 0.4725862542788188]
20191014 : 20191111
[0.7220893700917562, 0.48640816410382587, 0.4252362052599589, 0.11220294733842213, 0.4682639042536418, 0.008166713795314232, 0.5461446841557821, 0.5729021032651266, 0.6318842768669128, 0.08572812626759212, 0.6430438955624899, 0.0017701273318380117, 0.725000003973643, 0.39386645952860516, 0.45038779576619464]
20191111 : 20191209
[0.7353871663411459, 0.42778585354487103, 0.5087565382321676, 0.09316737701495488, 0.5437935789426168, 0.00675405686100324, 0.6495310366153717, 0.6070831418037415, 0.5271624724070231, 0.09810450424750645, 0.5071024199326833, 0.001036028688152631, 0.45000000794728595, 0.48765265941619873, 0.40973522265752155]

在这里插入图片描述

2019年,按天衡量

In [28]: evaluate_total_time(model, date_step=1, steps=3, start_date=20190102, end_date='', lookback=61, delay=1, uprate=0.0)
20190102 : 20190103
[0.705572267373403, 0.16796444356441498, 0.34837595621744794, 0.019603818655014038, 0.2872556348641713, 0.0004671862892185648, 0.3333333333333333, 0.8054861227671305, 0.6101937095324198, 0.20878386000792185, 0.6408450206120809, 0.0030286312103271484, 1.0, 0.38209859530131024, 0.18436301747957864]
20190103 : 20190104
[0.7293160359064738, 0.4662709931532542, 0.9566953976949056, 0.11041913429896037, 0.9694114128748575, 0.005439450653890769, 0.9207017620404562, 0.5960521896680196, 0.05528130133946737, 0.078614491969347, 0.061956015725930534, 0.0, 0.0, 0.9502540628115336, 0.46313631534576416]
20190104 : 20190107
[0.611860195795695, 0.6754318475723267, 0.9231745402018229, 0.21614141762256622, 0.9518770178159078, 0.013156835610667864, 0.9726867278416952, 0.4886184235413869, 0.1419760783513387, 0.0676371989150842, 0.1624330331881841, 0.0, 0.0, 0.9009539087613424, 0.6591780384381613]
20190107 : 20190108
[0.7243722875912985, 0.4978080987930298, 0.4071895082791646, 0.08383707453807195, 0.3930290639400482, 0.0060423092606167, 0.5100233256816864, 0.5485713283220927, 0.6368274291356405, 0.05625045796235403, 0.6789639393488566, 0.001010074425721541, 0.8222221930821737, 0.3837924599647522, 0.4691985348860423]
20190108 : 20190109
[0.7111263871192932, 0.40813730160395306, 0.5202601154645284, 0.03683588653802872, 0.4601799746354421, 0.0007390297250822186, 0.26666667064030963, 0.6501630942026774, 0.5416418711344401, 0.03953922167420387, 0.5126555760701498, 0.0, 0.0, 0.4817687471707662, 0.37790852785110474]
20190109 : 20190110
[0.7426256934801737, 0.5168083707491556, 0.38925134142239887, 0.08652547498544057, 0.3327723840872447, 0.0026520504616200924, 0.35952381292978924, 0.5255582531293234, 0.6502917011578878, 0.05014399935801824, 0.5897094209988912, 0.00042585157401238877, 0.6666666269302368, 0.369082639614741, 0.4901488820711772]
20190110 : 20190111
[0.6785141626993815, 0.602329154809316, 0.7728579839070638, 0.12033901115258534, 0.7923774719238281, 0.004967429519941409, 0.8315789500872294, 0.5077767968177795, 0.31471818685531616, 0.04856595521171888, 0.3893883327643077, 0.0010114980395883322, 0.9999998807907104, 0.7355799078941345, 0.5732370416323344]
20190111 : 20190114
[0.6953628063201904, 0.2894180317719777, 0.28369874755541485, 0.034236966321865715, 0.2994111180305481, 0.0, 0.0, 0.6982697447141012, 0.704126238822937, 0.09876756866772969, 0.6752941211064657, 0.0003779436810873449, 0.9999998807907104, 0.29223500688870746, 0.29811891913414]
20190114 : 20190115
[0.7420263091723124, 0.43638981382052106, 0.822896420955658, 0.05986353134115537, 0.7901981472969055, 0.002437230432406068, 0.8773809472719828, 0.614015797773997, 0.20959109564622244, 0.036879474918047585, 0.14328667024771372, 0.0, 0.0, 0.8043148914972941, 0.4264954924583435]
20190115 : 20190116
[0.732903261979421, 0.509804348150889, 0.36012641588846844, 0.05011440689365069, 0.3243343234062195, 0.0024426247303684554, 0.4212121268113454, 0.482978622118632, 0.6332911849021912, 0.032063632582624756, 0.6195197502772013, 0.00041390730378528434, 0.3333333333333333, 0.36337701479593915, 0.5144869486490885]
20190116 : 20190117
[0.7116859952608744, 0.4611208339532216, 0.20259833335876465, 0.05057622243960699, 0.21190446615219116, 0.0027048024348914623, 0.3777777850627899, 0.5551392634709676, 0.8075803319613138, 0.05120836322506269, 0.7793605923652649, 0.0006700240968105694, 0.6666666666666666, 0.19702237844467163, 0.44798075159390766]
20190117 : 20190118
[0.7438547015190125, 0.4360784391562144, 0.6822526653607687, 0.062473613768815994, 0.6320405205090841, 0.0030737899554272494, 0.8671024044354757, 0.5344203511873881, 0.2924879988034566, 0.03547515037159125, 0.2521167993545532, 0.0, 0.0, 0.6962645848592123, 0.4450387756029765]
20190118 : 20190121
[0.7438400189081827, 0.39680179953575134, 0.6188981930414835, 0.04004719853401184, 0.6233495473861694, 0.0017499179036046069, 0.6500000059604645, 0.6146770914395651, 0.39254732926686603, 0.047346084068218865, 0.29914529124895733, 0.0, 0.0, 0.6119283239046732, 0.39235090216000873]
20190121 : 20190122
[0.6645412643750509, 0.3758576611677806, 0.21247625350952148, 0.06790528694788615, 0.3083601991335551, 0.003910915615657966, 0.55158731341362, 0.6874891320864359, 0.8308535814285278, 0.08534777909517288, 0.8197145859400431, 0.0003274783957749605, 0.9999998807907104, 0.18320405979951224, 0.3241508404413859]
# ***********************************数据过多,省略部分***********************************

在这里插入图片描述

历史预测曲线

2017年6月到2018年6月.

In [34]: history_predict(model, ts_code='600004.SH', date=20180601, delay=1, during=244, mod='simple')
600004.SH
[0.5915805101394653, 0.623115599155426, 0.6919642686843872, 0.22914573550224304, 0.8923678994178772, 0.07638190686702728, 0.9870129823684692, 0.7153171896934509, 0.649040699005127, 0.14698298275470734, 0.8357771039009094, 0.024755029007792473, 0.9599999785423279, 0.5064902305603027, 0.45609569549560547]
数据分割日: 3518

在这里插入图片描述

总结

多个模型均出现了过拟合现象(训练准确率可达0.61,以0.9作为baseline时准确率甚至高达0.96).经过分析,过拟合的原因应该是模型拟合了过去的大盘走势,毕竟800,000个参数(上述混合Attention模型)应该还不至于拟合深交所和上交所从2004年到2018年的所有数据.
训练期间尝试过统一归一化(即对所有股票计算同一组均值(mean)和标准差(std)进行归一化),发现这并不能提升性能,反而会导致训练过程不稳定.
之前做图像识别时有很多预训练模型可以直接拿来用,但找了一下,没有发现用于时间序列预测的预训练模型.
训练过程中进行过多种条件的预测,包括:

  • 使用过去一年的数据预测一个月后是否上涨10%(由于数据不平衡,训练曲线波动极大,模型训练不起来),
  • 使用过去一年的数据预测一个月后是否上涨(更容易过拟合),
  • 使用过去四个月的数据预测明天是否上涨(即上述例子),
  • 使用过去一个月的数据预测明天的上涨幅度(效果一直卡在baseline附近,训练不起来).

正常情况下,模型阈值(baseline)取得越高,recall越低,准确率越高.
尝试过在Generator的输出端进行归一化(直接在模型开头加了一个BN层),似乎能在一定程度上提高训练的稳定性.
一开始使用沪股作训练集、港股作测试集,无法发现过拟合现象;后来改为以时间作为分割标准,才准确地识别出过拟合现象.
最初使用的生成器每次只能生成同一只股票的数据,现在换成了每次生成完全随机数据的生成器,虽然性能有所下降,但应该有助于提高训练的稳定性.

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值