【Python量化交易笔记】股票数据获取 (一)

获取股票数据主要是靠网页爬虫或者现成的库。

关于股票数据获取,python的接口有非常多,教程也有很多。

最后我选择了使用tushare和baostock。由于tushare升级之后有积分限制,很多数据是需要获取足够的积分才能得到,所以在学习量化交易的最开始使用了baostock获取数据。由于最后只需要获取沪深300和中证500的一些基础数据,所以也就联合baostock和tushare解决了这个问题(tushare获取沪深300和中证500成分股好像需要600积分)。实际上感觉tushare的接口比baostock更完善一些。

话不多说 直接上代码:

import tushare as ts
import os
import datetime
import baostock as bs
import pandas as pd
import json
import logger
import csv

class stockDump():
    def __init__(self, objLog=None, dataDir="stockData"):
        self.dictStock = {}
        self.dictMarket = {}
        self.objLog    = objLog
        self.dataDir   = dataDir
        self.maxLine   = 0
        self.maxName   = ""

        self.data_init()

    def data_init(self):
        if os.path.exists(self.dataDir) is False:
            os.makedirs(self.dataDir)

    def log(self, string):
        if self.objLog is None:
            print(string)
        else:
            self.objLog.addDrc(string)

    def login(self):
        bs.login()

        token = "your token"
        self.objTu = ts.pro_api(token)
        ts.set_token(token)

    def logout(self):
        bs.logout()

    def get_one_stock_kline(self, mark, startDate, endDate, dumpList=False):
        self.log("INFO: get info of '%s'"%(mark))
        rs = bs.query_history_k_data_plus(mark,
            "date,code,open,high,low,close,volume,amount,adjustflag,turn,pctChg,peTTM",
            start_date=startDate, end_date=endDate,
            frequency="d", adjustflag="2")

        data_list = []
        while (rs.error_code == '0') & rs.next():
            data_list.append(rs.get_row_data())
        result = pd.DataFrame(data_list, columns=rs.fields)
        self.maxLine = max(len(data_list), self.maxLine)
        if len(data_list) > self.maxLine:
            self.maxLine = len(data_list)
            self.maxName = mark

        fName = mark.split(".")[-1]+".csv"
        fPath = os.path.join(self.dataDir, fName)
        result.to_csv(fPath, index=False)

        if dumpList is True:
            self.dictStock[mark] = "Add"

    def dump_stock_list(self):
        fPath = os.path.join(self.dataDir, "stocklist.js")
        fp = open(fPath, "w")
        json.dump(self.dictStock, fp)
        fp.close()

    def dump_industry(self):
        rs = bs.query_stock_industry()
        dictIndustry = {}
        while (rs.error_code == '0') & rs.next():
            rowData = rs.get_row_data()
            mark = rowData[1].split(".")[-1]
            dictIndustry[mark] = rowData[1:]

        fPath = os.path.join(self.dataDir, "industry.js")
        fp = open(fPath, "w")
        json.dump(dictIndustry, fp)
        fp.close()

    def get_his300_stocks(self, startDate, endDate, dumpList=True):
        rs = bs.query_hs300_stocks()
        #query_zz500_stocks
        hs300_stocks = []
        while (rs.error_code == '0') & rs.next():
            hs300_stocks.append(rs.get_row_data())

        for item in hs300_stocks:
            mark = item[1]
            name = item[2]
            if dumpList is True:
                self.dictStock[mark] = name
            self.get_one_stock_kline(mark, startDate, endDate)

    def get_zz500_stocks(self, startDate, endDate, dumpList=True):
        rs = bs.query_zz500_stocks()
        zz500_stocks = []
        while (rs.error_code == '0') & rs.next():
            zz500_stocks.append(rs.get_row_data())

        for item in zz500_stocks:
            mark = item[1]
            name = item[2]
            if dumpList is True:
                self.dictStock[mark] = name
            self.get_one_stock_kline(mark, startDate, endDate)


    def align_csv(self, name=None):
        maxName = self.maxName if name is None else name
        maxName = maxName.split(".")[-1]
        fName = maxName +".csv"

        objCsv = pd.read_csv(os.path.join(self.dataDir, fName))

        dateList = list(objCsv["date"])
        headerList = list(objCsv.columns.values)
        dateInd  = headerList.index("date")
        codeInd  = headerList.index("code")
        openInd  = headerList.index("open")
        closeInd = headerList.index("close")
        highInd  = headerList.index("high")
        lowInd   = headerList.index("low")
        tmpLine = [0 for item in headerList]

        fPath = os.path.join(self.dataDir, "stocklist.js")
        fp = open(fPath, "r")
        dictStockList = json.load(fp)
        fp.close()

        for code in list(dictStockList.keys()):
            closePri = 0
            preFix = code.split(".")[0]
            code = code.split(".")[-1]
            pCode = "%s.%s"%(code, preFix.upper())
            if code == maxName:
                continue
            fPath = os.path.join(self.dataDir, "%s.csv"%(code))

            self.log("Align '%s'"%(fPath))

            objCsv = pd.read_csv(fPath)

            curDateList = list(objCsv["date"])

            #objNew = pd.DataFrame([[]], columns=headerList)

            tmpList = []
            addNum = 0
            for i, date in enumerate(dateList):
                if date in curDateList:
                    lineList = list(objCsv.loc[i-addNum])
                    closePri = lineList[closeInd]
                    tmpList.append(lineList)
                    continue
                addNum += 1
                tmpLine[dateInd]  = date
                tmpLine[codeInd]  = pCode
                tmpLine[openInd]  = closePri
                tmpLine[closeInd] = closePri
                tmpLine[highInd]  = closePri
                tmpLine[lowInd]   = closePri
                tmpList.append(tmpLine.copy())

            objNew = pd.DataFrame(tmpList, columns=headerList)

            objNew.to_csv(fPath, index=False)

    def get_his300_stocks_ts(self, startDate, endDate):
        rs = bs.query_hs300_stocks()
        hs300_stocks = []
        while (rs.error_code == '0') & rs.next():
            hs300_stocks.append(rs.get_row_data())

        for item in hs300_stocks:
            mark = item[1]
            name = item[2]
            self.dictStock[mark] = name
            mark = "%s.%s"%(mark.split(".")[-1], mark.split(".")[0].upper())
            startDate = startDate.replace("-", "")
            endDate = endDate.replace("-", "")

            self.get_one_stock_kline_ts(mark, startDate, endDate)

    def get_zz500_stocks_ts(self, startDate, endDate, dumpList=True):
        rs = bs.query_zz500_stocks()
        zz500_stocks = []
        while (rs.error_code == '0') & rs.next():
            zz500_stocks.append(rs.get_row_data())

        for item in zz500_stocks:
            mark = item[1]
            name = item[2]
            self.dictStock[mark] = name
            mark = "%s.%s"%(mark.split(".")[-1], mark.split(".")[0].upper())
            startDate = startDate.replace("-", "")
            endDate = endDate.replace("-", "")

            self.get_one_stock_kline_ts(mark, startDate, endDate)

    def get_one_stock_kline_ts(self, mark, startDate, endDate, dumpList=False):

        self.log("INFO: get info of '%s'"%(mark))

        markNum = mark.split(".")[0]
        dPath = os.path.join(self.dataDir, "%s.csv"%(markNum))

        objData  = ts.pro_bar(ts_code=mark, adj='qfq', start_date=startDate, end_date=endDate)

        objData.rename(columns={"trade_date": "date"}, inplace=True)
        objData.rename(columns={"ts_code": "code"}, inplace=True)

        for i in range(0, len(objData)):
            date = datetime.datetime.strptime(objData.loc[i, "date"], "%Y%m%d")
            date = date.strftime("%Y-%m-%d")
            objData.loc[i, "date"] = date

        indexList = list(reversed(objData["date"]))

        if self.maxLine < len(objData):
            self.maxLine = len(objData)
            self.maxName = markNum

        objData = objData.set_index("date")
        objData = objData.reindex(index=indexList)
        objData.to_csv(dPath)

        if dumpList is True:
            self.dictStock[markNum] = "Add"

if __name__ == "__main__":
    objDump = stockDump()
    objDump.login()
    endDate = str(datetime.datetime.today()).split()[0]

    objDump.dump_industry()
    objDump.get_his300_stocks_ts("2020-10-01", endDate)
    objDump.get_zz500_stocks_ts(startDate='2020-10-01', endDate=endDate)
    objDump.dump_stock_list()
    objDump.align_csv()
    objDump.logout()

需要注意几点:

1. token变量设置为用户自己的token,通过tushare官网获取

2. align_csv函数用于对齐股票数据,是由于个股存在停牌机制,停牌的时间段就没有数据,需要手动补充上去,否则不同股票的数据长度就会不一致,后续处理可能会有问题

3. 沪深300成分股和中证500成分股会不断更新

4. 注意是需要前复权的股价还是后复权

5. 获取的是日K数据

6. 历史前复权数据是会变化的,好像是遇到增发这种情况的时候,复权股票价格就会变化

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Python中,数据获取量化交易中必不可少的一步。以下是一些获取数据的方法: 1. tushare库:tushare是一个免费、开源、易于使用Python财经数据接口包,可以提供股票、基金、期货等市场数据,非常适合量化交易数据获取。你可以使用pip安装: ``` pip install tushare ``` 然后按照如下代码获取股票数据: ```python import tushare as ts # 获取股票数据 df = ts.get_hist_data('600519', '2020-01-01', '2021-01-01') print(df) ``` 这里的参数'600519'表示茅台股票的代码,'2020-01-01'和'2021-01-01'分别表示开始日期和结束日期。 2. jqdatasdk库:jqdatasdk是一个免费的Python金融数据接口库,可以获取股票、基金、期货、外汇等市场数据。你可以使用pip安装: ``` pip install jqdatasdk ``` 然后按照如下代码获取股票数据: ```python import jqdatasdk # 登录聚宽账号(需要先注册) jqdatasdk.auth('username', 'password') # 获取股票数据 df = jqdatasdk.get_price('000001.XSHE', start_date='2020-01-01', end_date='2021-01-01', frequency='daily') print(df) ``` 这里的参数'000001.XSHE'表示平安银行股票的代码,'2020-01-01'和'2021-01-01'分别表示开始日期和结束日期。 3. akshare库:akshare是一个免费、开源的Python财经数据接口库,可以提供股票、基金、期货等市场数据。你可以使用pip安装: ``` pip install akshare ``` 然后按照如下代码获取股票数据: ```python import akshare as ak # 获取股票数据 df = ak.stock_zh_a_daily(symbol='sh600519', start_date='20200101', end_date='20210101') print(df) ``` 这里的参数'sh600519'表示茅台股票的代码,'20200101'和'20210101'分别表示开始日期和结束日期。 以上是几种获取股票数据的方法,你可以根据自己的需求选择其中一种。另外,对于其他市场的数据获取,也可以使用类似的方法。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值