用 Python 构建一个“聚宽式”的策略回测平台

1. 用户编写的策略脚本(userStrategy.py):风格与聚宽保持一致,代码可以直接在聚宽平台运行。

# Backtest window read by the engine (Env.usercfg); dates are "YYYY-MM-DD"
# strings that the engine parses with strptime.
userconfig = dict(
    start="2018-01-01",
    end="2018-03-29",
)

def initialize(context):
    """Strategy set-up hook: seed the clock, stash a date on g, schedule market_open."""
    # Starting value for the context clock (the first backtest day).
    context.current_dt = '2018-01-01'
    # Example of strategy state kept on the global `g` object.
    g.today = "2018-03-23"
    # Ask the platform to call market_open once a day at the open.
    run_daily(market_open, time='open')

def market_open(context):
    """Daily open callback: print the context's current time."""
    stamp = context.current_dt
    print('context:', stamp)

def handle_data(context, data):
    """Per-bar callback; this demo strategy does nothing here."""
    return None

2. 策略环境配置(Env.py)

# -*- coding: utf-8 -*-
from userStrategy import userconfig
# Engine configuration. Each key under "mod" is the name of a module that
# ModHandler imports when its "enabled" flag is true.
config = {
    "mod": dict(
        stock=dict(enabled=True),
        future=dict(enabled=False),
    ),
}
from events import EventBus
class Env(object):
    """Process-wide backtest environment, accessed as a singleton.

    Holds the engine config, the user's config, the event bus, the strategy's
    global ``g`` object, the simulated clock and the event source.
    """

    # The most recently constructed Env; returned by get_instance().
    _env = None

    def __init__(self, config):
        # Register as the singleton first so later lookups find this object.
        Env._env = self
        self.config = config
        self.event_bus = EventBus()
        self.usercfg = userconfig
        self.global_vars = None
        self.current_dt = None
        self.event_source = None

    @classmethod
    def get_instance(cls):
        """Return the Env constructed most recently.

        Raises:
            RuntimeError: if no Env has been constructed yet.
        """
        instance = Env._env
        if instance is None:
            raise RuntimeError("策略还未初始化")
        return instance

    def set_global_vars(self, global_vars):
        """Install the object exposed to strategies as ``g``."""
        self.global_vars = global_vars

    def set_event_source(self, event_source):
        """Install the source whose events drive the backtest."""
        self.event_source = event_source
3. 事件定义及事件总线(events.py)

# -*- coding: utf-8 -*-
from enum import Enum
from collections import defaultdict


class Event(object):
    """A bus message: an ``event_type`` plus arbitrary keyword payload attributes."""

    def __init__(self, event_type, **kwargs):
        # The keyword payload becomes the instance's attribute dict;
        # event_type is added after it.
        self.__dict__ = kwargs
        self.event_type = event_type

    def __repr__(self):
        """Render as space-separated ``key:value`` pairs in attribute order."""
        parts = ['{}:{}'.format(name, val) for name, val in self.__dict__.items()]
        return ' '.join(parts)


class EventBus(object):
    """Synchronous publish/subscribe hub keyed by event type."""

    def __init__(self):
        # event type -> listeners, in the order they are called
        self._listeners = defaultdict(list)

    def add_listener(self, event, listener):
        """Register *listener* to be called after existing listeners of *event*."""
        self._listeners[event].append(listener)

    def prepend_listener(self, event, listener):
        """Register *listener* to be called before existing listeners of *event*."""
        self._listeners[event].insert(0, listener)

    def publish_event(self, event):
        """Call each listener of ``event.event_type`` in order, passing *event*.

        A listener that returns a truthy value stops delivery to the
        listeners after it.
        """
        listeners = self._listeners[event.event_type]
        for callback in listeners:
            handled = callback(event)
            if handled:
                break

class EVENT(Enum):
    """Event types carried on the EventBus."""

    STOCK = 'stock'    # stock events
    FUTURE = 'future'  # futures events
    TIME = 'time'      # time / scheduler events

def parse_event(event_str):
    """Return the EVENT member named *event_str* (case-insensitive), or None."""
    try:
        return EVENT[event_str.upper()]
    except KeyError:
        return None
4. 事件分发器(Executor)
from  events import  EVENT,Event

class Executor(object):
    """Drive a backtest by replaying the env's event source onto its event bus."""

    def __init__(self, env):
        self._env = env

    # Event types that are forwarded to the bus. Only members that EVENT
    # actually declares may be referenced here: naming an undeclared member
    # makes the class definition itself fail with AttributeError.
    KNOWN_EVENTS = frozenset(EVENT)

    def run(self, bar_dict):
        """Publish every known event from the event source over the user's date range.

        NOTE(review): *bar_dict* is currently unused.
        """
        usercfg = self._env.usercfg
        start = usercfg['start']
        end = usercfg['end']
        # Env.config is a plain dict, so read base.frequency by key when it is
        # configured. '1d' is an assumed default (daily) — TODO confirm against
        # what event_source.events() expects.
        frequency = self._env.config.get('base', {}).get('frequency', '1d')
        event_bus = self._env.event_bus

        for event in self._env.event_source.events(start, end, frequency):
            if event.event_type not in self.KNOWN_EVENTS:
                continue
            # Event(...) objects only carry the keyword attributes they were
            # built with; advance the clock only when the source supplied one.
            calendar_dt = getattr(event, 'calendar_dt', None)
            if calendar_dt is not None:
                self._env.calendar_dt = calendar_dt
            event_bus.publish_event(event)

5. 业务模块加载器(Mod.py)及各业务模块的 handler(future.py、stock.py)

# -*- coding: utf-8 -*-

from collections import OrderedDict
from importlib import import_module

class ModHandler(object):
    """Find the mods enabled in the environment config and start them.

    A mod is an importable module named after its key under ``config['mod']``.
    Each entry of its ``__all__`` is a start-up hook called with the
    environment's event bus so the mod can register listeners.
    """

    def __init__(self):
        self._env = None
        # Names of the enabled mods, in config order.
        self._mod_list = list()
        # Imported mod modules by name, filled in by start_up().
        self._mod_dict = OrderedDict()

    def set_env(self, environment):
        """Attach *environment* and record which mods its config enables.

        A mod whose config omits ``enabled`` or sets it falsy is skipped.
        Calling this again replaces the mod list instead of appending
        duplicates to it.
        """
        self._env = environment
        mod_configs = environment.config['mod']
        self._mod_list = [
            mod_name
            for mod_name, mod_cfg in mod_configs.items()
            if mod_cfg.get('enabled', False)
        ]

    def start_up(self):
        """Import every enabled mod and call each hook in its ``__all__``.

        A hook may be the callable itself (as the bundled mods register
        them) or an attribute name that is resolved on the module.
        """
        for mod_name in self._mod_list:
            # Dynamically import the mod by name.
            mod = import_module(mod_name)
            self._mod_dict[mod_name] = mod
            for entry in getattr(mod, '__all__', ()):
                hook = getattr(mod, entry) if isinstance(entry, str) else entry
                hook(self._env.event_bus)
# -*- coding: utf-8 -*-
from events import EVENT
# Start-up hooks exported to ModHandler; filled in by @export_as_api.
# NOTE(review): this holds function objects rather than names, so
# `from <this module> import *` would raise TypeError once it is non-empty.
__all__ = []


def export_as_api(func):
    """Decorator: register *func* in ``__all__`` and return it unchanged."""
    registry = __all__
    registry.append(func)
    return func

@export_as_api
def startup(event_bus):
    """Start-up hook of the future mod: subscribe ``handler`` on *event_bus*.

    ModHandler.start_up calls every hook listed in ``__all__`` with the bus.
    """
    # This is the futures mod, so it listens for FUTURE events; subscribing
    # to STOCK would make it fire on stock events instead.
    event_bus.add_listener(EVENT.FUTURE, handler)
    print('load and register future mod')

def handler(event):
    """Listener for the future mod: print a fixed success message (event is ignored)."""
    message = 'do future handler success'
    print(message)
# -*- coding: utf-8 -*-
import random
from events import EVENT
# Start-up hooks exported to ModHandler; filled in by @export_as_api.
# NOTE(review): this holds function objects rather than names, so
# `from <this module> import *` would raise TypeError once it is non-empty.
__all__ = []


def export_as_api(func):
    """Decorator: register *func* in ``__all__`` and return it unchanged."""
    registry = __all__
    registry.append(func)
    return func

@export_as_api
def startup(event_bus):
    """Start-up hook of the stock mod: subscribe ``handler`` to STOCK events."""
    listen = event_bus.add_listener
    listen(EVENT.STOCK, handler)
    print('load and register stock mod')

def handler(event):
    """Listener for the stock mod: randomly print a canned trade record.

    Draws one integer in [0, 9]. For 7, 5 and 2 it prints a fixed,
    hard-coded record; for any other value it prints nothing. *event* is
    ignored.
    """
    canned_trades = {
        7: '成交记录:UserTrade({''secu:000001.XHSG,''order_id'': 1522310538, ''trade_id'': 1522310538, ''price'': 10.52, ''amount'': 2300',
        5: '成交记录:UserTrade({''secu:000254.XHSG,''order_id'': 1522310538, ''trade_id'': 1522310538, ''price'': 23.52, ''amount'': 1700',
        2: '成交记录:UserTrade({''secu:600012.XHSG,''order_id'': 1522310538, ''trade_id'': 1522310538, ''price'': 13.52, ''amount'': 1700',
    }
    record = canned_trades.get(random.randint(0, 9))
    if record is not None:
        print(record)

6. 加载并运行策略文件的主程序(main.py)

# -*- coding: utf-8 -*-
from Env import  Env, config
from Mod import  ModHandler
from events import EVENT, Event
from globalVars import GlobalVars
from CodeLoader import  CodeLoader
from strategy_context import StrategyContext
import api
from api import Scheduler
import  datetime
def run_file(strategy_file_path):
    """Load the 聚宽-style strategy file at *strategy_file_path* and run it.

    Steps:
    - Build the Env singleton and start the enabled mods.
    - Execute the strategy source with ``g`` available.
    - Wire the api scheduler to a fresh StrategyContext.
    - Call the strategy's ``initialize(context)``.
    - If the strategy defines ``handle_data(context, data)``, call it once.

    Raises:
        RuntimeError: if the strategy does not define ``initialize``.
    """
    # Build the environment; Env() registers itself as the singleton that
    # the api module looks up.
    env = Env(config)

    # Load the mods enabled in config and let them subscribe to the bus.
    mod = ModHandler()
    mod.set_env(env)
    mod.start_up()

    # Create `g` before executing the strategy so module-level strategy code
    # can already use it, then re-bind it afterwards so `g` always refers to
    # the engine's GlobalVars object.
    env.set_global_vars(GlobalVars())
    scope = CodeLoader(strategy_file_path).load({"g": env.global_vars})
    scope.update({"g": env.global_vars})

    env.current_dt = datetime.datetime.strptime(env.usercfg['start'], "%Y-%m-%d")
    context = StrategyContext()
    scheduler = Scheduler()
    scheduler.set_user_context(context)
    # api.run_daily() dispatches through this module-level scheduler.
    api._scheduler = scheduler

    initialize = scope.get('initialize')
    if initialize is None:
        raise RuntimeError(
            "strategy file %r does not define initialize(context)" % strategy_file_path)
    initialize(context)

    # handle_data may be omitted by a strategy; call it once only if present.
    handle_data = scope.get('handle_data')
    if handle_data is not None:
        data = {}
        handle_data(context, data)

7. 策略代码加载器(CodeLoader.py)

import codecs
import copy
from six import exec_

class CodeLoader:
    """Read a user strategy file and execute it into a namespace dict."""

    # Executed into the scope before the user code so strategies can call the
    # api functions (run_daily, ...) without importing them, as on 聚宽.
    _PRELUDE = 'from api import *\n'

    def __init__(self, strategyfile):
        self._strategyfile = strategyfile

    def compile_strategy(self, source_code, strategyfile, scope):
        """Compile *source_code* (reported as *strategyfile*) and exec it in *scope*.

        Returns *scope*, which now also holds the names the code defined.
        """
        code = compile(source_code, strategyfile, 'exec')
        exec(code, scope)
        return scope

    def load(self, scope):
        """Execute the strategy file into *scope* and return *scope*.

        The api star-import is executed separately rather than prepended to
        the user's source. This keeps traceback line numbers aligned with the
        strategy file and keeps ``from __future__`` imports legal on its
        first line.
        """
        with codecs.open(self._strategyfile, encoding="utf-8") as h:
            source_code = h.read()
        exec(compile(self._PRELUDE, '<strategy prelude>', 'exec'), scope)
        return self.compile_strategy(source_code, self._strategyfile, scope)

8. 策略 API(api.py)

# -*- coding: utf-8 -*-
from Env import Env
import datetime
from  events import Event, EVENT

# Scheduler installed by the engine before the strategy's initialize() runs.
_scheduler = None


def run_daily(func, time):
    """Ask the installed scheduler to call ``func(context)`` daily at *time*.

    Raises:
        RuntimeError: if called before the engine installed a scheduler.
    """
    if _scheduler is None:
        raise RuntimeError("run_daily() called before the backtest scheduler was set up")
    _scheduler._run_daily(func, time)

class Scheduler(object):
    """Minimal scheduler behind the api's ``run_daily``.

    NOTE(review): callbacks are not queued. ``_run_daily`` replays the whole
    configured date range immediately, so everything runs inside the
    strategy's ``initialize`` call.
    """

    # Wall-clock time used for time='open'.
    _OPEN_HOUR = 9
    _OPEN_MINUTE = 30

    def __init__(self):
        # Strategy context handed to every scheduled callback.
        self._ucontext = None

    def set_user_context(self, ucontext):
        """Set the context object passed to scheduled callbacks."""
        self._ucontext = ucontext

    @staticmethod
    def _weekdays(start, end):
        """Yield each Monday-Friday datetime from *start* to *end*, both inclusive."""
        for offset in range((end - start).days + 1):
            day = start + datetime.timedelta(days=offset)
            if day.weekday() < 5:
                yield day

    def _run_daily(self, func, time):
        """Call ``func(context)`` at the open of every weekday in the user's range.

        For each day:
        1. Advance the clocks (context and Env) to that day's open.
        2. Publish a STOCK event so the mods can react.
        3. Call *func*.

        Only time='open' is handled; any other value schedules nothing.
        Exchange holidays are not skipped.
        """
        env = Env.get_instance()
        if time != 'open':
            return
        start = datetime.datetime.strptime(env.usercfg['start'], "%Y-%m-%d")
        end = datetime.datetime.strptime(env.usercfg['end'], "%Y-%m-%d")
        for day in self._weekdays(start, end):
            bar_dt = day.replace(hour=self._OPEN_HOUR, minute=self._OPEN_MINUTE)
            print(bar_dt)
            # Advance the clocks before the callback so it sees today's time.
            self._ucontext.current_dt = bar_dt
            env.current_dt = bar_dt
            env.event_bus.publish_event(Event(EVENT.STOCK))
            func(self._ucontext)
9. 全局变量 g(globalVars.py)
import six
import pickle
class GlobalVars(object):
    """The strategy's global ``g`` object, with pickle-based snapshots.

    Strategy code sets attributes freely (e.g. ``g.today = ...``);
    ``get_state``/``set_state`` round-trip every picklable attribute.
    """

    def get_state(self):
        """Return a pickle of all picklable attributes.

        Attributes that cannot be pickled are skipped with a printed warning.
        """
        dict_data = {}
        for key, value in self.__dict__.items():
            try:
                dict_data[key] = pickle.dumps(value)
            except Exception as e:
                # Best-effort snapshot: report the attribute instead of failing.
                print('GlobalVars: skip g.%s, cannot pickle it: %r' % (key, e))
        return pickle.dumps(dict_data)

    def set_state(self, state):
        """Restore attributes from bytes produced by ``get_state``.

        NOTE: ``pickle.loads`` can execute arbitrary code; only pass state
        this engine produced itself, never untrusted input.
        """
        dict_data = pickle.loads(state)
        for key, value in dict_data.items():
            try:
                self.__dict__[key] = pickle.loads(value)
            except Exception as e:
                print('GlobalVars: skip g.%s, cannot unpickle it: %r' % (key, e))
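10. context 场景(strategy_context.py)

注:原文此处贴出的代码与第 9 节的 GlobalVars 完全相同,疑为误贴;main.py 所需的 StrategyContext 类的实现未给出。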

import six
import pickle
class GlobalVars(object):
    """The strategy's global ``g`` object, with pickle-based snapshots.

    Strategy code sets attributes freely (e.g. ``g.today = ...``);
    ``get_state``/``set_state`` round-trip every picklable attribute.
    """

    def get_state(self):
        """Return a pickle of all picklable attributes.

        Attributes that cannot be pickled are skipped with a printed warning.
        """
        dict_data = {}
        for key, value in self.__dict__.items():
            try:
                dict_data[key] = pickle.dumps(value)
            except Exception as e:
                # Best-effort snapshot: report the attribute instead of failing.
                print('GlobalVars: skip g.%s, cannot pickle it: %r' % (key, e))
        return pickle.dumps(dict_data)

    def set_state(self, state):
        """Restore attributes from bytes produced by ``get_state``.

        NOTE: ``pickle.loads`` can execute arbitrary code; only pass state
        this engine produced itself, never untrusted input.
        """
        dict_data = pickle.loads(state)
        for key, value in dict_data.items():
            try:
                self.__dict__[key] = pickle.loads(value)
            except Exception as e:
                print('GlobalVars: skip g.%s, cannot unpickle it: %r' % (key, e))

11. 执行策略的入口脚本

# -*- coding: utf-8 -*-
from main import run_file
import os

# Resolve the strategy next to this script rather than against the current
# working directory. Env imports `userStrategy` from the script's directory,
# so both must refer to the same file even when launched from elsewhere.
file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "userStrategy.py")

if __name__ == "__main__":
    run_file(file_path)

12. 所有类图


13. 测试


14. 把第 1 节中的策略代码拷贝到聚宽中执行


15.结束。

阅读更多
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

关闭
关闭
关闭