获取股票数据主要是靠网页爬虫或者现成的库。
关于股票数据获取,python的接口有非常多,教程也有很多。
最后我选择了使用tushare和baostock。由于tushare升级之后有积分限制,很多数据是需要获取足够的积分才能得到,所以在学习量化交易的最开始使用了baostock获取数据。由于最后只需要获取沪深300和中证500的一些基础数据,所以也就联合baostock和tushare解决了这个问题(tushare获取沪深300和中证500成分股好像需要600积分)。实际上感觉tushare的接口比baostock更完善一些。
话不多说,直接上代码:
import tushare as ts
import os
import datetime
import baostock as bs
import pandas as pd
import json
import logger
import csv
class stockDump():
    """Download daily K-line data for HS300 / ZZ500 constituent stocks via
    baostock and tushare, write one CSV per stock under ``dataDir``, and
    align all series onto a common trading calendar.
    """

    def __init__(self, objLog=None, dataDir="stockData"):
        # objLog: optional logger exposing addDrc(); None falls back to print().
        # dataDir: directory where all CSV/JSON output files are written.
        self.dictStock = {}   # code -> stock name (or "Add"); dumped by dump_stock_list()
        self.dictMarket = {}  # not used in this class; kept for interface compatibility
        self.objLog = objLog
        self.dataDir = dataDir
        self.maxLine = 0      # row count of the longest K-line series seen so far
        self.maxName = ""     # code of that longest series; reference for align_csv()
        self.data_init()

    def data_init(self):
        """Ensure the output directory exists."""
        if os.path.exists(self.dataDir) is False:
            os.makedirs(self.dataDir)

    def log(self, string):
        """Log through the injected logger when present, else print."""
        if self.objLog is None:
            print(string)
        else:
            self.objLog.addDrc(string)

    def login(self):
        """Log in to baostock and initialise the tushare pro API.

        NOTE: replace the placeholder with your own tushare token
        (obtained from the tushare website).
        """
        bs.login()
        token = "your token"
        self.objTu = ts.pro_api(token)
        ts.set_token(token)

    def logout(self):
        """Log out of baostock."""
        bs.logout()

    def get_one_stock_kline(self, mark, startDate, endDate, dumpList=False):
        """Fetch one stock's daily K-line from baostock and write it to
        <dataDir>/<number>.csv.

        mark is baostock-style, e.g. "sh.600000"; adjustflag="2" requests
        forward-adjusted (qfq) prices.
        """
        self.log("INFO: get info of '%s'"%(mark))
        rs = bs.query_history_k_data_plus(mark,
            "date,code,open,high,low,close,volume,amount,adjustflag,turn,pctChg,peTTM",
            start_date=startDate, end_date=endDate,
            frequency="d", adjustflag="2")
        data_list = []
        while rs.error_code == '0' and rs.next():
            data_list.append(rs.get_row_data())
        result = pd.DataFrame(data_list, columns=rs.fields)
        # Track the longest series and its code: align_csv() uses that
        # stock's date column as the reference calendar.
        # BUGFIX: the comparison must happen BEFORE maxLine is updated;
        # the original code updated maxLine first, so maxName was never set.
        if len(data_list) > self.maxLine:
            self.maxLine = len(data_list)
            self.maxName = mark
        fName = mark.split(".")[-1]+".csv"
        fPath = os.path.join(self.dataDir, fName)
        result.to_csv(fPath, index=False)
        if dumpList is True:
            self.dictStock[mark] = "Add"

    def dump_stock_list(self):
        """Dump the collected {code: name} map to stocklist.js as JSON."""
        fPath = os.path.join(self.dataDir, "stocklist.js")
        # 'with' guarantees the handle is closed even if json.dump raises.
        with open(fPath, "w") as fp:
            json.dump(self.dictStock, fp)

    def dump_industry(self):
        """Query the industry classification of all stocks from baostock
        and dump it to industry.js as {code_number: row} JSON."""
        rs = bs.query_stock_industry()
        dictIndustry = {}
        while rs.error_code == '0' and rs.next():
            rowData = rs.get_row_data()
            # rowData[1] is the full code like "sh.600000"; key by the number.
            mark = rowData[1].split(".")[-1]
            dictIndustry[mark] = rowData[1:]
        fPath = os.path.join(self.dataDir, "industry.js")
        with open(fPath, "w") as fp:
            json.dump(dictIndustry, fp)

    def get_his300_stocks(self, startDate, endDate, dumpList=True):
        """Fetch the HS300 constituent list from baostock and download
        each constituent's daily K-line (baostock path)."""
        rs = bs.query_hs300_stocks()
        hs300_stocks = []
        while rs.error_code == '0' and rs.next():
            hs300_stocks.append(rs.get_row_data())
        for item in hs300_stocks:
            mark = item[1]   # code, e.g. "sh.600000"
            name = item[2]   # stock name
            if dumpList is True:
                self.dictStock[mark] = name
            self.get_one_stock_kline(mark, startDate, endDate)

    def get_zz500_stocks(self, startDate, endDate, dumpList=True):
        """Fetch the ZZ500 constituent list from baostock and download
        each constituent's daily K-line (baostock path)."""
        rs = bs.query_zz500_stocks()
        zz500_stocks = []
        while rs.error_code == '0' and rs.next():
            zz500_stocks.append(rs.get_row_data())
        for item in zz500_stocks:
            mark = item[1]
            name = item[2]
            if dumpList is True:
                self.dictStock[mark] = name
            self.get_one_stock_kline(mark, startDate, endDate)

    def align_csv(self, name=None):
        """Pad every stock CSV onto the reference stock's trading calendar.

        Suspended stocks have no rows for the suspension period, so series
        lengths differ.  Missing dates are filled with a flat bar
        (open = high = low = close = last known close) so all CSVs end up
        the same length.  The reference is the longest series recorded in
        self.maxName, unless 'name' overrides it.
        """
        maxName = self.maxName if name is None else name
        maxName = maxName.split(".")[-1]
        fName = maxName + ".csv"
        objCsv = pd.read_csv(os.path.join(self.dataDir, fName))
        dateList = list(objCsv["date"])  # reference trading calendar
        headerList = list(objCsv.columns.values)
        dateInd = headerList.index("date")
        codeInd = headerList.index("code")
        openInd = headerList.index("open")
        closeInd = headerList.index("close")
        highInd = headerList.index("high")
        lowInd = headerList.index("low")
        tmpLine = [0 for item in headerList]  # template row for synthetic bars
        fPath = os.path.join(self.dataDir, "stocklist.js")
        with open(fPath, "r") as fp:
            dictStockList = json.load(fp)
        for code in list(dictStockList.keys()):
            closePri = 0  # last known close; stays 0 until the first real row
            preFix = code.split(".")[0]
            code = code.split(".")[-1]
            pCode = "%s.%s"%(code, preFix.upper())  # tushare-style code
            if code == maxName:
                continue  # the reference file is already complete
            fPath = os.path.join(self.dataDir, "%s.csv"%(code))
            self.log("Align '%s'"%(fPath))
            objCsv = pd.read_csv(fPath)
            curDateList = list(objCsv["date"])
            tmpList = []
            addNum = 0  # number of synthetic rows inserted so far
            for i, date in enumerate(dateList):
                if date in curDateList:
                    # i - addNum maps the reference index back onto this
                    # stock's own (shorter) row index; assumes both files
                    # are in the same chronological order.
                    lineList = list(objCsv.loc[i-addNum])
                    closePri = lineList[closeInd]
                    tmpList.append(lineList)
                    continue
                # Missing date: synthesize a flat bar at the last close.
                addNum += 1
                tmpLine[dateInd] = date
                tmpLine[codeInd] = pCode
                tmpLine[openInd] = closePri
                tmpLine[closeInd] = closePri
                tmpLine[highInd] = closePri
                tmpLine[lowInd] = closePri
                tmpList.append(tmpLine.copy())
            objNew = pd.DataFrame(tmpList, columns=headerList)
            objNew.to_csv(fPath, index=False)

    def get_his300_stocks_ts(self, startDate, endDate):
        """Fetch the HS300 constituent list from baostock, then download
        each constituent's K-line from tushare (which expects YYYYMMDD
        dates and '600000.SH'-style codes)."""
        rs = bs.query_hs300_stocks()
        hs300_stocks = []
        while rs.error_code == '0' and rs.next():
            hs300_stocks.append(rs.get_row_data())
        # tushare wants dates without dashes; convert once, outside the loop.
        startDate = startDate.replace("-", "")
        endDate = endDate.replace("-", "")
        for item in hs300_stocks:
            mark = item[1]
            name = item[2]
            self.dictStock[mark] = name
            # "sh.600000" -> "600000.SH"
            mark = "%s.%s"%(mark.split(".")[-1], mark.split(".")[0].upper())
            self.get_one_stock_kline_ts(mark, startDate, endDate)

    def get_zz500_stocks_ts(self, startDate, endDate, dumpList=True):
        """Fetch the ZZ500 constituent list from baostock, then download
        each constituent's K-line from tushare."""
        rs = bs.query_zz500_stocks()
        zz500_stocks = []
        while rs.error_code == '0' and rs.next():
            zz500_stocks.append(rs.get_row_data())
        startDate = startDate.replace("-", "")
        endDate = endDate.replace("-", "")
        for item in zz500_stocks:
            mark = item[1]
            name = item[2]
            self.dictStock[mark] = name
            mark = "%s.%s"%(mark.split(".")[-1], mark.split(".")[0].upper())
            self.get_one_stock_kline_ts(mark, startDate, endDate)

    def get_one_stock_kline_ts(self, mark, startDate, endDate, dumpList=False):
        """Fetch one stock's daily forward-adjusted (qfq) K-line from
        tushare and write it to <dataDir>/<number>.csv in chronological
        order.

        mark is tushare-style, e.g. "600000.SH"; dates are "YYYYMMDD".
        """
        self.log("INFO: get info of '%s'"%(mark))
        markNum = mark.split(".")[0]
        dPath = os.path.join(self.dataDir, "%s.csv"%(markNum))
        objData = ts.pro_bar(ts_code=mark, adj='qfq', start_date=startDate, end_date=endDate)
        # Normalize column names to match the baostock CSV layout.
        objData.rename(columns={"trade_date": "date"}, inplace=True)
        objData.rename(columns={"ts_code": "code"}, inplace=True)
        # Reformat dates "YYYYMMDD" -> "YYYY-MM-DD" (baostock format).
        for i in range(0, len(objData)):
            date = datetime.datetime.strptime(objData.loc[i, "date"], "%Y%m%d")
            date = date.strftime("%Y-%m-%d")
            objData.loc[i, "date"] = date
        # tushare returns newest-first; reverse to chronological order.
        indexList = list(reversed(objData["date"]))
        if self.maxLine < len(objData):
            self.maxLine = len(objData)
            self.maxName = markNum
        objData = objData.set_index("date")
        objData = objData.reindex(index=indexList)
        objData.to_csv(dPath)
        if dumpList is True:
            self.dictStock[markNum] = "Add"
if __name__ == "__main__":
    # Driver: fetch HS300 + ZZ500 daily bars from 2020-10-01 up to today
    # via the tushare path, then pad every CSV onto a common calendar.
    dumper = stockDump()
    dumper.login()
    today = datetime.date.today().isoformat()
    dumper.dump_industry()
    dumper.get_his300_stocks_ts("2020-10-01", today)
    dumper.get_zz500_stocks_ts(startDate="2020-10-01", endDate=today)
    dumper.dump_stock_list()
    dumper.align_csv()
    dumper.logout()
需要注意几点:
1. token变量设置为用户自己的token,通过tushare官网获取
2. align_csv函数用于对齐股票数据,是由于个股存在停牌机制,停牌的时间段就没有数据,需要手动补充上去,否则不同股票的数据长度就会不一致,后续处理可能会有问题
3. 沪深300成分股和中证500成分股会不断更新
4. 注意确认需要的是前复权还是后复权的股价:本文代码中 baostock 接口使用 adjustflag="2"、tushare 接口使用 adj='qfq',获取的均为前复权数据
5. 获取的是日K数据
6. 历史前复权数据是会变化的,好像是遇到增发这种情况的时候,复权股票价格就会变化