环境搭建完成后 (见 https://blog.csdn.net/kengxie/article/details/118084858),接着把沪深两市所有的股票基本信息获取到本地。
因为是通过internet远程连接tushare的服务器获取数据,所以要实现一个tushare的客户端程序。这里要根据tushare提供的api做一些适当的封装,以提升程序的开发效率,降低维护成本。tushare给数据作了分类,每种数据有独立的api接口。但这里可以先不看具体的数据,先把整体的步骤抽象出来封装成一个父类,每种数据获取的实现类最后都从父类继承。
考虑网络连接波动的影响,获取数据的时候,要加入重试机制,如果遇到异常,自动重试三次:
def retrieve(self, **kwargs):
df = None
ex = None
retry = 3
for _ in range(retry):
try:
df = self._get_data(**kwargs)
except Exception as e:
self.logger.debug("Failed retrieving data:", e)
if _ == retry - 1:
ex = e;
time.sleep(1)
else:
break
if df is not None:
self._save(df)
if ex is not None:
self.logger.exception(ex)
self.logger.info(f"completed")
股票信息是不断更新的,所以要考虑两种方式,一是全量获取至今已有的所有数据,二是每天增量获取发生变化的数据。如果本地从来没有获取过某种数据,那么就做一次全量的更新,之后就只需要每天做一次增量更新了。
def _get_data(self, **kwargs):
if self._initialized(**kwargs):
self.logger.info(f"_delta: {kwargs}")
return self._delta(**kwargs)
else:
self.logger.info(f"_full: {kwargs}")
return self._full(**kwargs)
怎么判断是哪种情况呢?看对应的数据库表是否存在,如果存在就增量更新,不存在就全量更新。
def _initialized(self, **kwargs):
sql = f"select count(*) from information_schema.tables where table_name = '{self.table_name}';"
df = pd.read_sql_query(sql, engine_ts)
return df.iat[0,0] > 0
以上的逻辑都封装到父类中,子类只要具体实现增量获取和全量获取的方法就行了。
@abstractmethod
def _full(self, **kwargs):
pass
@abstractmethod
def _delta(self, **kwargs):
pass
此外,所有数据都要保存到数据库中去,所以父类还可以放上读写mysql的方法,方便子类调用。这里也是参考tushare的例子,直接用dataframe读写数据库。这个不是对外的web应用,不可能被恶意攻击,所以也不用考虑sql注入之类的安全问题,直接用字符串拼接就好了。
def _save(self, df):
df.to_sql(self.table_name, engine_ts, index=False, if_exists=self.if_exists, chunksize=5000)
def query(self, drop_meta=True, **kwargs):
sql = f"select {kwargs.setdefault('fields', '*')} " \
f"from {self.table_name} " \
f"where {kwargs.setdefault('where', '1=1')} " \
f"{('order by ' + kwargs.get('order_by')) if 'order_by' in kwargs.keys() else '' }"
self.logger.debug(f"query: {sql}")
df = pd.read_sql_query(sql, engine_ts)
if drop_meta and 'update_time' in df.columns:
return df.drop(columns=['update_time'])
else:
return df
父类的完整实现如下:
class AbstractDataRetriever(object):
def __init__(self, table_name, if_exists='append'):
self.table_name = table_name
self.if_exists = if_exists
self.logger = CustomLogger(extra={'classname': self.__class__.__name__})
def retrieve(self, **kwargs):
df = None
ex = None
retry = 3
for _ in range(retry):
try:
df = self._get_data(**kwargs)
except Exception as e:
self.logger.debug("Failed retrieving data:", e)
if _ == retry - 1:
ex = e;
time.sleep(1)
else:
break
if df is not None:
self._save(df)
if ex is not None:
self.logger.exception(ex)
self.logger.info(f"completed")
def _save(self, df):
df.to_sql(self.table_name, engine_ts, index=False, if_exists=self.if_exists, chunksize=5000)
def _initialized(self, **kwargs):
sql = f"select count(*) from information_schema.tables where table_name = '{self.table_name}';"
df = pd.read_sql_query(sql, engine_ts)
return df.iat[0,0] > 0
def _get_data(self, **kwargs):
if self._initialized(**kwargs):
self.logger.info(f"_delta: {kwargs}")
return self._delta(**kwargs)
else:
self.logger.info(f"_full: {kwargs}")
return self._full(**kwargs)
def query(self, drop_meta=True, **kwargs):
sql = f"select {kwargs.setdefault('fields', '*')} " \
f"from {self.table_name} " \
f"where {kwargs.setdefault('where', '1=1')} " \
f"{('order by ' + kwargs.get('order_by')) if 'order_by' in kwargs.keys() else '' }"
self.logger.debug(f"query: {sql}")
df = pd.read_sql_query(sql, engine_ts)
if drop_meta and 'update_time' in df.columns:
return df.drop(columns=['update_time'])
else:
return df
@abstractmethod
def _full(self, **kwargs):
pass
@abstractmethod
def _delta(self, **kwargs):
pass
日志类做了一点点增强,把类名添加到格式中去了,这样在读日志的时候,方便分辨是哪个类的输出内容。如果输出日志的方法不在我们定义的类中,就只放一个占位符。
class CustomFormatter(logging.Formatter):
def __init__(self, fmt=None, datefmt=None, style='{', validate=True):
super().__init__(fmt, datefmt, style, validate)
def formatMessage(self, record):
return self._fmt.format_map(CustomFormatter.Default(record.__dict__))
class Default(dict):
def __missing__(self, key):
return '{' + key + '}'
def default_logger():
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setFormatter(CustomFormatter(LOG_FORMAT))
logger.addHandler(console_handler)
return logger
class CustomLogger(object):
def __init__(self, logger=default_logger(), extra={}):
self.logger = logger
self.extra = extra
def __getattr__(self, name):
return partial(getattr(self.logger, name), extra=self.extra)
全部代码上传到https://github.com/xiekeng/tushare-client,感兴趣的可以自取。