量化投资从0开始系列 ---- 2. 获取数据的基础架构

环境搭建完成后 (见 https://blog.csdn.net/kengxie/article/details/118084858),接着把沪深两市所有的股票基本信息获取到本地。

因为是通过internet远程连接tushare的服务器获取数据,所以要实现一个tushare的客户端程序。这里要根据tushare提供的api做一些适当的封装,以提升程序的开发效率,降低维护成本。tushare给数据作了分类,每种数据有独立的api接口。但这里可以先不看具体的数据,先把整体的步骤抽象出来封装成一个父类,每种数据获取的实现类最后都从父类继承。

考虑网络连接波动的影响,获取数据的时候,要加入重试机制,如果遇到异常,自动重试三次:

    def retrieve(self, **kwargs):
        df = None
        ex = None
        retry = 3
        for _ in range(retry):
            try:
                df = self._get_data(**kwargs)
            except Exception as e:
                self.logger.debug("Failed retrieving data:", e)
                if _ == retry - 1:
                    ex = e;
                time.sleep(1)
            else:
                break

        if df is not None:
            self._save(df)

        if ex is not None:
            self.logger.exception(ex)

        self.logger.info(f"completed")

股票信息是不断更新的,所以要考虑两种方式,一是全量获取至今已有的所有数据,二是每天增量获取发生变化的数据。如果本地从来没有获取过某种数据,那么就做一次全量的更新,之后就只需要每天做一次增量更新了。

    def _get_data(self, **kwargs):
        if self._initialized(**kwargs):
            self.logger.info(f"_delta: {kwargs}")
            return self._delta(**kwargs)
        else:
            self.logger.info(f"_full: {kwargs}")
            return self._full(**kwargs)

怎么判断是哪种情况呢?看对应的数据库表是否存在,如果存在就增量更新,不存在就全量更新。

    def _initialized(self, **kwargs):
        sql = f"select count(*) from information_schema.tables where table_name = '{self.table_name}';"
        df = pd.read_sql_query(sql, engine_ts)
        return df.iat[0,0] > 0

以上的逻辑都封装到父类中,子类只要具体实现增量获取和全量获取的方法就行了。

    @abstractmethod
    def _full(self, **kwargs):
        pass

    @abstractmethod
    def _delta(self, **kwargs):
        pass

此外,所有数据都要保存到数据库中去,所以父类还可以放上读写mysql的方法,方便子类调用。这里也是参考tushare的例子,直接用dataframe读写数据库。这个不是对外的web应用,不可能被恶意攻击,所以也不用考虑sql注入之类的安全问题,直接用字符串拼接就好了。

    def _save(self, df):
        df.to_sql(self.table_name, engine_ts, index=False, if_exists=self.if_exists, chunksize=5000)

    def query(self, drop_meta=True, **kwargs):
        sql = f"select {kwargs.setdefault('fields', '*')} " \
              f"from {self.table_name} " \
              f"where {kwargs.setdefault('where', '1=1')} " \
              f"{('order by ' + kwargs.get('order_by')) if 'order_by' in kwargs.keys() else '' }"
        self.logger.debug(f"query: {sql}")
        df = pd.read_sql_query(sql, engine_ts)
        if drop_meta and 'update_time' in df.columns:
            return df.drop(columns=['update_time'])
        else:
            return df

父类的完整实现如下:

class AbstractDataRetriever(object):
    def __init__(self, table_name, if_exists='append'):
        self.table_name = table_name
        self.if_exists = if_exists
        self.logger = CustomLogger(extra={'classname': self.__class__.__name__})

    def retrieve(self, **kwargs):
        df = None
        ex = None
        retry = 3
        for _ in range(retry):
            try:
                df = self._get_data(**kwargs)
            except Exception as e:
                self.logger.debug("Failed retrieving data:", e)
                if _ == retry - 1:
                    ex = e;
                time.sleep(1)
            else:
                break

        if df is not None:
            self._save(df)

        if ex is not None:
            self.logger.exception(ex)

        self.logger.info(f"completed")

    def _save(self, df):
        df.to_sql(self.table_name, engine_ts, index=False, if_exists=self.if_exists, chunksize=5000)

    def _initialized(self, **kwargs):
        sql = f"select count(*) from information_schema.tables where table_name = '{self.table_name}';"
        df = pd.read_sql_query(sql, engine_ts)
        return df.iat[0,0] > 0

    def _get_data(self, **kwargs):
        if self._initialized(**kwargs):
            self.logger.info(f"_delta: {kwargs}")
            return self._delta(**kwargs)
        else:
            self.logger.info(f"_full: {kwargs}")
            return self._full(**kwargs)

    def query(self, drop_meta=True, **kwargs):
        sql = f"select {kwargs.setdefault('fields', '*')} " \
              f"from {self.table_name} " \
              f"where {kwargs.setdefault('where', '1=1')} " \
              f"{('order by ' + kwargs.get('order_by')) if 'order_by' in kwargs.keys() else '' }"
        self.logger.debug(f"query: {sql}")
        df = pd.read_sql_query(sql, engine_ts)
        if drop_meta and 'update_time' in df.columns:
            return df.drop(columns=['update_time'])
        else:
            return df

    @abstractmethod
    def _full(self, **kwargs):
        pass

    @abstractmethod
    def _delta(self, **kwargs):
        pass

日志类做了一点点增强,把类名添加到格式中去了,这样在读日志的时候,方便分辨是哪个类的输出内容。如果输出日志的方法不在我们定义的类中,就只放一个占位符。

class CustomFormatter(logging.Formatter):
    def __init__(self, fmt=None, datefmt=None, style='{', validate=True):
        super().__init__(fmt, datefmt, style, validate)

    def formatMessage(self, record):
        return self._fmt.format_map(CustomFormatter.Default(record.__dict__))

    class Default(dict):
        def __missing__(self, key):
            return '{' + key + '}'


def default_logger():
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(CustomFormatter(LOG_FORMAT))
    logger.addHandler(console_handler)

    return logger


class CustomLogger(object):
    def __init__(self, logger=default_logger(), extra={}):
        self.logger = logger
        self.extra = extra

    def __getattr__(self, name):
        return partial(getattr(self.logger, name), extra=self.extra)

全部代码上传到https://github.com/xiekeng/tushare-client,感兴趣的可以自取。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

blkq

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值