使用jqdata和hikyuu平台进行C++/python混合策略编写的方法

很多时候为了运行复杂的策略用python速度会很慢,而核心部分用C 编写可以大幅提升策略的运行速度。另外通达信、金字塔等主流证券软件都支持C 的dll库,而且可以很方便地图形化展示策略结果,那么策略核心部分用C 编写成dll库也是一种通用的跨平台方案。

而传统的python对C 库调用方法,需要自己编写很多封装函数,且聚宽的策略回测平台本身也不支持调用本地的C 库。
这时可以借助一个开源的第三方平台hikyuu来方便地完成该需求。将jqdata与hikyuu整合起来实现C /python混合编程。
首先我们需要在hikyuu的C 工程文件中添加自己的策略代码,自己的策略代码可以作为自定义指标的一部分。

在hikyuu_msvc10工程下增加一个指标,首先在indicator/crt目录下增加一个策略的包装头文件,例如一个移动平均线的策略:

#ifndef EMA_H_
#define EMA_H_

#include "../Indicator.h"

namespace hku {
/**
 * 指数移动平均线(Exponential Moving Average)
 * @param n 计算均值的周期窗口,必须为大于0的整数
 * @ingroup Indicator
 */

Indicator HKU_API EMA(int n = 22);

/**
 * 指数移动平均线(Exponential Moving Average)
 * @param data 待计算的源数据
 * @param n 计算均值的周期窗口,必须为大于0的整数
 * @ingroup Indicator
 */

Indicator HKU_API EMA(const Indicator& data, int n = 22);
} /* namespace */

#endif /* EMA_H_ */



然后在indicator/imp目录中增加这个策略的实现类,包含头文件和实现文件:

#ifndef EMA_H_
#define EMA_H_

#include "../Indicator.h"

namespace hku {

/*
 * 指数移动平均线(Exponential Moving Average)
 * 参数: n: 计算均值的周期窗口,必须为大于0的整数
 * 抛弃数 = 0
 */

class Ema: public IndicatorImp {
    INDICATOR_IMP(Ema)
    INDICATOR_IMP_NO_PRIVATE_MEMBER_SERIALIZATION
public:
    Ema();
    virtual ~Ema();
};

} /* namespace hku */

#endif /* EMA_H_ */



#include "Ema.h"

namespace hku {
Ema::Ema(): IndicatorImp("EMA", 1) {
    setParam("n", 22);
}

Ema::~Ema() {

}

bool Ema::check() {
    int n = getParam("n");
&nbsp;&nbsp;&nbsp; if (n <= 0) {
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; HKU_ERROR("Invalid param[n] must > 0 ! [Ema::check]");
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; return false;
&nbsp;&nbsp;&nbsp; }

&nbsp;&nbsp;&nbsp; return true;
}

void Ema::_calculate(const Indicator&amp; indicator) {
&nbsp;&nbsp;&nbsp; size_t total = indicator.size();
&nbsp;&nbsp;&nbsp; int n = getParam("n");
&nbsp;&nbsp;&nbsp; m_discard = indicator.discard();
&nbsp;&nbsp;&nbsp; if (total <= m_discard) {
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; return;
&nbsp;&nbsp;&nbsp; }

&nbsp;&nbsp;&nbsp; size_t startPos = discard();
&nbsp;&nbsp;&nbsp; price_t ema = indicator[startPos];
&nbsp;&nbsp;&nbsp; _set(ema, startPos);
&nbsp;&nbsp;&nbsp; price_t multiplier = 2.0 / (n   1);
&nbsp;&nbsp;&nbsp; for (size_t i = startPos   1; i < total;   i) {
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ema = (indicator[i] - ema) * multiplier   ema;
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; _set(ema, i);
&nbsp;&nbsp;&nbsp; }
}

Indicator HKU_API EMA(int n) {
&nbsp;&nbsp;&nbsp; IndicatorImpPtr p = make_shared();
&nbsp;&nbsp;&nbsp; p->setParam("n", n);
&nbsp;&nbsp;&nbsp; return Indicator(p);
}

Indicator HKU_API EMA(const Indicator&amp; data, int n) {
&nbsp;&nbsp;&nbsp; IndicatorImpPtr p = make_shared();
&nbsp;&nbsp;&nbsp; p->setParam("n", n);
&nbsp;&nbsp;&nbsp; p->calculate(data);
&nbsp;&nbsp;&nbsp; return Indicator(p);
}

} /* namespace hku */

最后在indicator目录下的build_in.h文件中增加包含关系:

#include "crt/EMA.h"

在indicator目录下的export.cpp文件中导出策略类:

#include "imp/Ema.h"
BOOST_CLASS_EXPORT(hku::Ema)

最后编译整个hikyuu_msvc10工程得到一个新的dll库,直接代替原hikyuu相应的dll库就实现了C 策略类的导出。

第二步就是在python中直接使用这个新的移动平均线策略指标。
首先导入jqdata和hikyuu

from jqdatasdk import *
from hikyuu.interactive.interactive import *

然后封装一个由jqdata作为数据源的自定义数据源类,具体的要实现的接口可以参考hikyuu平台的文档。这个封装只需要编写一次即可,不需要每个策略都编写。
封装类如下:

from .._hikyuu import KDataDriver, DataDriverFactory
from hikyuu import KRecord, Query, Datetime, Parameter

from jqdatasdk import *
from datetime import *

class jqdataKDataDriver(KDataDriver):
    def __init__(self):
        super(jqdataKDataDriver, self).__init__('jqdata')

    def _init(self):
        """【重载接口】(可选)初始化子类私有变量"""
        self._max = {Query.DAY:10,
                     Query.WEEK:2,
                     Query.MONTH:1,
                     Query.QUARTER:1,
                     #Query.HALFYEAR:1,
                     Query.YEAR:1,
                     Query.MIN:25,
                     Query.MIN5:25,
                     Query.MIN15:25,
                     Query.MIN30:25,
                     Query.MIN60:25}
        return 

    def loadKData(self, market, code, ktype, start_ix, end_ix, out_buffer):
        """
        【重载接口】(必须)按指定的位置[start_ix, end_ix)读取K线数据至out_buffer
        :param str market: 市场标识
        :param str code: 证券代码
        :param KQuery.KType ktype: K线类型
        :param int start_ix: 起始位置
        :param int end_ix: 结束位置
        :param KRecordListPtr out_buffer: 传入的数据缓存,读取数据后使用 
                                           out_buffer.append(krecord) 加入数据        
        """
        if start_ix >= end_ix or start_ix <0 or end_ix <0:
            return

        data = self._get_bars(market, code, ktype)

        if len(data) < start_ix:
            return

        total = end_ix if end_ix < len(data) else len(data)
        for i in range(start_ix, total):
            record = KRecord()
            record.datetime = Datetime(data.index[i])
            record.openPrice = data['open'][i]
            record.highPrice = data['high'][i]
            record.lowPrice = data['low'][i]
            record.closePrice = data['close'][i]
            record.transAmount = data['money'][i]
            record.transCount = data['volume'][i]
            out_buffer.append(record)


    def getCount(self, market, code, ktype):
        """
        【重载接口】(必须)获取K线数量
        :param str market: 市场标识
        :param str code: 证券代码
        :param KQuery.KType ktype: K线类型        
        """
        data = self._get_bars(market, code, ktype)
        return len(data)

    def _getIndexRangeByDate(self, market, code, query):
        """
        【重载接口】(必须)按日期获取指定的K线数据
        :param str market: 市场标识
        :param str code: 证券代码
        :param KQuery query: 日期查询条件(QueryByDate)        
        """
        print("getIndexRangeByDate")

        if query.queryType != Query.DATE:
            return (0, 0)

        start_datetime = query.startDatetime
        end_datetime = query.endDatetime
        if start_datetime >= end_datetime or start_datetime > Datetime.max():
            return (0, 0)

        data = self._get_bars(market, code, query.kType)
        total = len(data)
        if total == 0:
            return (0, 0)

        mid, low = 0, 0
        high = total-1
        while low <= high:
            tmp_datetime = Datetime(data.index[high])
            if start_datetime > tmp_datetime:
                mid = high   1
                break

            tmp_datetime = Datetime(data.index[low])
            if tmp_datetime >= start_datetime:
                mid = low
                break

            mid = (low   high) // 2
            tmp_datetime = Datetime(data.index[mid])
            if start_datetime > tmp_datetime:
                low = mid   1
            else:
                high = mid - 1

        if mid >= total:
            return (0, 0)

        start_pos = mid
        low = mid
        high = total - 1
        while low <= high:
            tmp_datetime = Datetime(data.index[high])
            if end_datetime > tmp_datetime:
                mid = high   1
                break

            tmp_datetime = Datetime(data.index[low])
            if tmp_datetime >= end_datetime:
                mid = low
                break

            mid = (low   high) // 2
            tmp_datetime = Datetime(data.index[mid])
            if end_datetime > tmp_datetime:
                low = mid   1
            else:
                high = mid - 1

        end_pos = total if mid >= total else mid
        if start_pos >= end_pos:
            return (0,0)

        return (start_pos, end_pos)


    def getKRecord(self, market, code, pos, ktype):
        """
        【重载接口】(必须)获取指定位置的K线记录
        :param str market: 市场标识
        :param str code: 证券代码
        :param int pos: 指定位置(大于等于0)
        :param KQuery.KType ktype: K线类型        
        """
        record = KRecord()
        if pos < 0:
            return record

        data = self._get_bars(market, code, ktype)
        if data is None:
            return record

        if pos < len(data):
            record.datetime =  Datetime(data.index[pos])
            record.openPrice = data['open'][pos]
            record.highPrice = data['high'][pos]
            record.lowPrice = data['low'][pos]
            record.closePrice = data['close'][pos]
            record.transAmount = data['money'][pos]
            record.transCount = data['volume'][pos]

        return record


    def _trans_ktype(self, ktype): #此处的周月季年数据只是近似的,目前jqdata未提供聚宽网络平台上的get_bar函数,不能直接取,需要自行用日线数据拼装
        ktype_map = {Query.MIN: '1m',
                     Query.MIN5: '5m',
                     Query.MIN15: '15m',
                     Query.MIN30: '30m',
                     Query.MIN60: '60m',
                     Query.DAY: '1d',
                     Query.WEEK: '7d',
                     Query.MONTH: '30d',
                     Query.QUARTER: '90d',
                     Query.YEAR: '365d'}
        return ktype_map.get(ktype)

    def _get_bars(self, market, code, ktype):
        data = []
        username = self.getParam('username')
        password = self.getParam('password')
        auth(username, password)

        jqdataCode = normalize_code(code)
        jqdata_ktype = self._trans_ktype(ktype)

        if jqdata_ktype is None:
            print("jqdata_ktype == None")
            return data

        print(jqdataCode)
        security_info = get_security_info(jqdataCode)

        if security_info is None: #有可能取不到任何信息
            return data
        #print(security_info)

        data = get_price(jqdataCode, security_info.start_date, datetime.now(), jqdata_ktype)

        return data

在interactive.py文件中替换原来的数据源即可

DataDriverFactory.regKDataDriver(jqdataKDataDriver())

jqdata_param = Parameter()
jqdata_param.set('type', 'jqdata')
jqdata_param.set('username', '用户名')
jqdata_param.set('password', '密码')

base_param = sm.getBaseInfoDriverParameter()
block_param = sm.getBlockDriverParameter()
kdata_param = sm.getKDataDriverParameter()
preload_param = sm.getPreloadParameter()
hku_param = sm.getHikyuuParameter()

#切换K线数据驱动,重新初始化
sm.init(base_param, block_param, jqdata_param, preload_param, hku_param)

最后一步就是在python中直接使用jqdata数据源调用C 编写的指标了

s = sm['sz000001']
k = s.getKData(Query(-200))
#抽取K线收盘价指标,一般指标计算参数只能是指标类型,所以必须先将K线数据生成指标类型
c = CLOSE(k)
#调用自定义的C  均线策略计算收盘价的EMA指标
a = EMA(c)
#绘制指标
c.plot(legend_on=True)
a.plot(new=False, legend_on=True)
#绘制柱状图
a.bar()

以上三步,中最复杂的第二步,写一次后就可以通用,这样可以大大简化,python中,调用C 策略库的难度。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值