Tushare Advanced: Building a Data Class -- (3)

This article introduces the TushareProProcessor class from finrl, a financial data-processing library written in Python. The class fetches data from the Tushare Pro API, then cleans and organizes it and computes technical indicators, which makes it a good fit for stock-trading and machine-learning projects. The code covers data download, basic data cleaning, technical-indicator calculation, and data splitting.


This comes from finrl, a project I have drawn on before -- built by a great team. It also provides a Tushare data interface for fetching the raw source data, which serves as the first step of getting its project running.

The code offers a lot of functionality: not just downloading data, but also cleaning it, computing technical indicators, and more. It is also written very cleanly and is well worth studying.

Here is the example.

1. Create basic_processor.py

from typing import List

import numpy as np
import pandas as pd
import stockstats
import talib

class BasicProcessor:
    def __init__(self, data_source: str, start_date, end_date, time_interval, **kwargs):

        assert data_source in {
            "alpaca",
            "ccxt",
            "binance",
            "iexcloud",
            "joinquant",
            "quantconnect",
            "ricequant",
            "wrds",
            "yahoofinance",
            "tusharepro",
        }, "Data source input is NOT supported yet."
        self.data_source: str = data_source
        self.start_date = start_date
        self.end_date = end_date
        self.time_interval = time_interval
        self.time_zone: str = ""
        self.dataframe: pd.DataFrame = pd.DataFrame()
        self.dictnumpy: dict = {}

    def download_data(self, ticker_list: List[str]):
        pass

    def clean_data(self):
        if "date" in self.dataframe.columns.values.tolist():
            self.dataframe.rename(columns={'date': 'time'}, inplace=True)
        if "datetime" in self.dataframe.columns.values.tolist():
            self.dataframe.rename(columns={'datetime': 'time'}, inplace=True)
        if self.data_source == "ccxt":
            self.dataframe.rename(columns={'index': 'time'}, inplace=True)
        elif self.data_source == 'ricequant':
            ''' RiceQuant data is already cleaned, we only need to transform data format here.
                No need for filling NaN data'''
            self.dataframe.rename(columns={'order_book_id': 'tic'}, inplace=True)
            # raw df uses multi-index (tic,time), reset it to single index (time)
            self.dataframe.reset_index(level=[0, 1], inplace=True)
            # check if there is NaN values
            assert not self.dataframe.isnull().values.any()
        self.dataframe.dropna(inplace=True)
        # adj_close: adjusted close price
        if 'adj_close' not in self.dataframe.columns.values.tolist():
            self.dataframe['adj_close'] = self.dataframe['close']
        self.dataframe.sort_values(by=['time', 'tic'], inplace=True)
        self.dataframe = self.dataframe[['tic', 'time', 'open', 'high', 'low', 'close', 'adj_close', 'volume']]

    def get_trading_days(self, start: str, end: str) -> List[str]:
        if self.data_source in ["binance", "ccxt", "quantconnect", "ricequant", "tusharepro"]:
            print(f"Calculate get_trading_days not supported for {self.data_source} yet.")
            return None

    # use_stockstats_or_talib: 0 (stockstats, default), or 1 (use talib). Users can choose the method.
    def add_technical_indicator(self, tech_indicator_list: List[str], use_stockstats_or_talib:int = 0):
        """
        calculate technical indicators
        use the stockstats/talib package to add technical indicators
        :param data: (df) pandas dataframe
        :return: (df) pandas dataframe
        """
        if "date" in self.dataframe.columns.values.tolist():
            self.dataframe.rename(columns={'date': 'time'}, inplace=True)

        if self.data_source == "ccxt":
            self.dataframe.rename(columns={'index': 'time'}, inplace=True)

        self.dataframe.reset_index(drop=False, inplace=True)
        if "level_1" in self.dataframe.columns:
            self.dataframe.drop(columns=["level_1"], inplace=True)
        if "level_0" in self.dataframe.columns and "tic" not in self.dataframe.columns:
            self.dataframe.rename(columns={"level_0": "tic"}, inplace=True)
        assert use_stockstats_or_talib in {0,1}
        if use_stockstats_or_talib == 0:  # use stockstats
            stock = stockstats.StockDataFrame.retype(self.dataframe)
            unique_ticker = stock.tic.unique()
            for indicator in tech_indicator_list:
                indicator_df = pd.DataFrame()
                for i in range(len(unique_ticker)):
                    try:
                        temp_indicator = stock[stock.tic == unique_ticker[i]][indicator]
                        temp_indicator = pd.DataFrame(temp_indicator)
                        temp_indicator["tic"] = unique_ticker[i]
                        temp_indicator["time"] = self.dataframe[self.dataframe.tic == unique_ticker[i]][
                            "time"
                        ].to_list()
                        indicator_df = pd.concat(
                            [indicator_df, temp_indicator], ignore_index=True
                        )
                    except Exception as e:
                        print(e)
                self.dataframe = self.dataframe.merge(
                    indicator_df[["tic", "time", indicator]], on=["tic", "time"], how="left"
                )
        else:  # use talib
            final_df = pd.DataFrame()
            for i in self.dataframe.tic.unique():
                # work on a copy so the column assignments below don't trigger SettingWithCopyWarning
                tic_df = self.dataframe[self.dataframe.tic == i].copy()
                tic_df['macd'], tic_df['macd_signal'], tic_df['macd_hist'] = talib.MACD(
                    tic_df['close'], fastperiod=12, slowperiod=26, signalperiod=9)
                tic_df['rsi'] = talib.RSI(tic_df['close'], timeperiod=14)
                tic_df['cci'] = talib.CCI(tic_df['high'], tic_df['low'], tic_df['close'], timeperiod=14)
                tic_df['dx'] = talib.DX(tic_df['high'], tic_df['low'], tic_df['close'], timeperiod=14)
                final_df = pd.concat([final_df, tic_df])
            self.dataframe = final_df

        self.dataframe.sort_values(by=["time", "tic"], inplace=True)
        time_to_drop = self.dataframe[self.dataframe.isna().any(axis=1)].time.unique()
        self.dataframe = self.dataframe[~self.dataframe.time.isin(time_to_drop)]
        print("Succesfully add technical indicators")

    def add_turbulence(self):
        """
        add turbulence index from a precalculated dataframe
        :param data: (df) pandas dataframe
        :return: (df) pandas dataframe
        """
        # df = data.copy()
        # turbulence_index = self.calculate_turbulence(df)
        # df = df.merge(turbulence_index, on="time")
        # df = df.sort_values(["time", "tic"]).reset_index(drop=True)
        # return df
        if self.data_source in ["binance", "ccxt", "iexcloud", "joinquant", "quantconnect"]:
            print(f"Turbulence not supported for {self.data_source} yet. Return original DataFrame.")
        if self.data_source in ["alpaca", "ricequant", "tusharepro", "wrds", "yahoofinance"]:
            turbulence_index = self.calculate_turbulence()
            self.dataframe = self.dataframe.merge(turbulence_index, on="time")
            self.dataframe.sort_values(["time", "tic"], inplace=True).reset_index(drop=True, inplace=True)

    def calculate_turbulence(self, time_period: int = 252) -> pd.DataFrame:
        """calculate turbulence index based on dow 30"""
        # can add other market assets
        df_price_pivot = self.dataframe.pivot(index="time", columns="tic", values="close")
        # use returns to calculate turbulence
        df_price_pivot = df_price_pivot.pct_change()

        unique_date = self.dataframe['time'].unique()
        # start after a year
        start = time_period
        turbulence_index = [0] * start
        # turbulence_index = [0]
        count = 0
        for i in range(start, len(unique_date)):
            current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
            # use a one-year rolling window to calculate covariance
            hist_price = df_price_pivot[
                (df_price_pivot.index < unique_date[i])
                & (df_price_pivot.index >= unique_date[i - time_period])
                ]
            # Drop tickers whose number of missing values exceeds that of the "oldest" ticker
            filtered_hist_price = hist_price.iloc[
                                  hist_price.isna().sum().min():
                                  ].dropna(axis=1)

            cov_temp = filtered_hist_price.cov()
            current_temp = (current_price[list(filtered_hist_price)] - np.mean(
                filtered_hist_price, axis=0
            ))
            # cov_temp = hist_price.cov()
            # current_temp=(current_price - np.mean(hist_price,axis=0))

            temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
                current_temp.values.T
            )
            if temp > 0:
                count += 1
                # avoid large outliers when the calculation has just begun; otherwise set turbulence_temp = 0
                turbulence_temp = temp[0][0] if count > 2 else 0
            else:
                turbulence_temp = 0
            turbulence_index.append(turbulence_temp)

        turbulence_index = pd.DataFrame(
            {"time": df_price_pivot.index, "turbulence": turbulence_index}
        )
        return turbulence_index

    def add_vix(self):
        """
        add vix from processors
        :param data: (df) pandas dataframe
        :return: (df) pandas dataframe
        """
        if self.data_source in ['binance', 'ccxt', 'iexcloud', 'joinquant', 'quantconnect', 'ricequant', 'tusharepro']:
            print(f'VIX is not applicable for {self.data_source}. Return original DataFrame')
            return

        # if self.data_source == 'yahoofinance':
        #     df = data.copy()
        #     df_vix = self.download_data(
        #         start_date=df.time.min(),
        #         end_date=df.time.max(),
        #         ticker_list=["^VIX"],
        #         time_interval=self.time_interval,
        #     )
        #     df_vix = self.clean_data(df_vix)
        #     vix = df_vix[["time", "adj_close"]]
        #     vix.columns = ["time", "vix"]
        #
        #     df = df.merge(vix, on="time")
        #     df = df.sort_values(["time", "tic"]).reset_index(drop=True)
        # elif self.data_source == 'alpaca':
        #     vix_df = self.download_data(["VIXY"], self.start, self.end, self.time_interval)
        #     cleaned_vix = self.clean_data(vix_df)
        #     vix = cleaned_vix[["time", "close"]]
        #     vix = vix.rename(columns={"close": "VIXY"})
        #
        #     df = data.copy()
        #     df = df.merge(vix, on="time")
        #     df = df.sort_values(["time", "tic"]).reset_index(drop=True)
        # elif self.data_source == 'wrds':
        #     vix_df = self.download_data(['vix'], self.start, self.end_date, self.time_interval)
        #     cleaned_vix = self.clean_data(vix_df)
        #     vix = cleaned_vix[['date', 'close']]
        #
        #     df = data.copy()
        #     df = df.merge(vix, on="date")
        #     df = df.sort_values(["date", "tic"]).reset_index(drop=True)

        if self.data_source == 'yahoofinance':
            ticker = "^VIX"
        elif self.data_source == 'alpaca':
            ticker = "VIXY"
        elif self.data_source == 'wrds':
            ticker = "vix"
        else:
            return
        df = self.dataframe.copy()
        self.download_data([ticker])
        self.clean_data()
        # keep only time and close from the downloaded index data, and rename close to vix
        cleaned_vix = self.dataframe[["time", "close"]].rename(columns={"close": "vix"})

        df = df.merge(cleaned_vix, on="time")
        df = df.sort_values(["time", "tic"]).reset_index(drop=True)
        self.dataframe = df

    def df_to_array(self, tech_indicator_list: list, if_vix: bool):
        unique_ticker = self.dataframe.tic.unique()
        price_array = np.column_stack([self.dataframe[self.dataframe.tic == tic].close for tic in unique_ticker])
        tech_array = np.hstack([self.dataframe.loc[(self.dataframe.tic == tic), tech_indicator_list] for tic in unique_ticker])
        if if_vix:
            risk_array = np.column_stack([self.dataframe[self.dataframe.tic == tic].vix for tic in unique_ticker])
        else:
            risk_array = np.column_stack(
                [self.dataframe[self.dataframe.tic == tic].turbulence for tic in unique_ticker]) if "turbulence" in self.dataframe.columns else None
        print("Successfully transformed into array")
        return price_array, tech_array, risk_array
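A quick note on calculate_turbulence: the turbulence value for day t is the Mahalanobis distance of that day's cross-sectional returns from their trailing one-year history,

turbulence_t = (y_t − μ) Σ⁺ (y_t − μ)ᵀ

where y_t is the vector of asset returns on day t, μ and Σ are the mean and covariance over the previous time_period (default 252) trading days, and Σ⁺ is the pseudo-inverse computed with np.linalg.pinv, so the expression stays well-defined even when the covariance matrix is singular. The first couple of positive values are forced to 0 (the count > 2 check) to suppress spurious spikes right after the rolling window fills up.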

2. Next, processor_Tusharepro.py

This one inherits from the basic class built above. The original intent of the basic class was to support more data interfaces -- not just Tushare, but also external sources such as Yahoo Finance (a rough sketch of such an extension follows).
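As an illustration of that design (a minimal sketch only -- the class name YahooFinanceProcessor and the use of the yfinance package are my own assumptions, not code from the project), a new data source only needs to subclass BasicProcessor and implement download_data; clean_data, add_technical_indicator, add_turbulence and df_to_array are all inherited:

import pandas as pd
import yfinance as yf  # hypothetical downloader choice, not part of the original project

from basic_processor import BasicProcessor


class YahooFinanceProcessor(BasicProcessor):
    """Hypothetical sketch of extending BasicProcessor to another data source."""

    def download_data(self, ticker_list):
        frames = []
        for tic in ticker_list:
            # yf.Ticker(...).history() returns OHLCV indexed by date
            df_temp = yf.Ticker(tic).history(start=self.start_date, end=self.end_date).reset_index()
            df_temp["tic"] = tic
            frames.append(df_temp)
        df = pd.concat(frames, ignore_index=True)
        # align column names with what the inherited clean_data expects
        # (adj_close is filled from close by clean_data if it is missing)
        df = df.rename(columns={"Date": "time", "Open": "open", "High": "high",
                                "Low": "low", "Close": "close", "Volume": "volume"})
        self.dataframe = df


# "yahoofinance" is already in BasicProcessor's data_source whitelist
yf_processor = YahooFinanceProcessor("yahoofinance", "2015-01-01", "2021-01-03", "1D")
yf_processor.download_data(["AAPL", "MSFT"])
yf_processor.clean_data()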

import copy
import time
import warnings
from copy import deepcopy
from typing import List

import pandas as pd
import tushare as ts
from tqdm import tqdm

from basic_processor import BasicProcessor

warnings.filterwarnings("ignore")


class TushareProProcessor(BasicProcessor):
    """Provides methods for retrieving daily stock data from tusharepro API
    Attributes
    ----------
        start_date : str
            start date of the data
        end_date : str
            end date of the data
        ticker_list : list
            a list of stock tickers 
        token : str
            get from https://waditu.com/ after registration
        adj: str
            Whether to use the adjusted closing price. Default is None.
            To use the forward-adjusted closing price (前复权), please pass 'qfq'.
            To use the backward-adjusted closing price (后复权), please pass 'hfq'.
    Methods
    -------
    download_data()
        Fetches data from tusharepro API
    
    """

    def __init__(self, data_source: str, start_date, end_date, time_interval, **kwargs):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        if 'token' not in kwargs.keys():
            raise ValueError("pleses input token!")
        self.token = kwargs["token"]
        if 'adj' in kwargs.keys():
            self.adj = kwargs["adj"]
            print(f"Using {self.adj} method")
        else:
            self.adj = None

    def get_data(self, id) -> pd.DataFrame:
        # df1 = ts.pro_bar(ts_code=id, start_date=self.start_date,end_date='20180101')
        # dfb=pd.concat([df, df1], ignore_index=True)
        # print(dfb.shape)
        return ts.pro_bar(
            ts_code=id, start_date=self.start_date, end_date=self.end_date, adj=self.adj
        )

    def download_data(self, ticker_list: List[str]):
        """Fetches data from tusharepro API
        Parameters
        ----------
        Returns
        -------
        `pd.DataFrame`
            7 columns: A tick symbol, date, open, high, low, close and volume 
            for the specified stock ticker
        """
        self.ticker_list = ticker_list

        if self.time_interval != "1D":
            raise ValueError('not supported currently')

        ts.set_token(self.token)

        self.df = pd.DataFrame()
        for i in tqdm(ticker_list, total=len(ticker_list)):
            df_temp = self.get_data(i)
            self.df = pd.concat([self.df, df_temp], ignore_index=True)
            # print("{} ok".format(i))
            time.sleep(0.25)

        self.df.columns = ['tic', 'date', 'open', 'high', 'low', 'close', 'pre_close', 'change', 'pct_chg', 'volume',
                           'amount']
        self.df = self.df.sort_values(by=['date', 'tic']).reset_index(drop=True)

        df = self.df[['tic', 'date', 'open', 'high', 'low', 'close', 'volume']].copy()
        df["date"] = pd.to_datetime(df["date"], format="%Y%m%d")
        df["day"] = df["date"].dt.dayofweek
        df["date"] = df.date.apply(lambda x: x.strftime("%Y-%m-%d"))

        df = df.dropna()
        df = df.sort_values(by=['date', 'tic']).reset_index(drop=True)

        print("Shape of DataFrame: ", df.shape)

        self.dataframe = df

    def clean_data(self):
        dfc = copy.deepcopy(self.dataframe)

        dfcode = pd.DataFrame(columns=['tic'])
        dfdate = pd.DataFrame(columns=['date'])

        dfcode.tic = dfc.tic.unique()

        if "time" in dfc.columns.values.tolist():
            dfc = dfc.rename(columns={'time': 'date'})

        dfdate.date = dfc.date.unique()
        dfdate.sort_values(by="date", ascending=False, ignore_index=True, inplace=True)

        # the old pandas may not support pd.merge(how="cross")
        try:
            df1 = pd.merge(dfcode, dfdate, how="cross")
        except:
            print("Please wait for a few seconds...")
            df1 = pd.DataFrame(columns=["tic", "date"])
            for i in range(dfcode.shape[0]):
                for j in range(dfdate.shape[0]):
                    df1 = df1.append(pd.DataFrame(data={"tic": dfcode.iat[i, 0], "date": dfdate.iat[j, 0]},
                                                  index=[(i + 1) * (j + 1) - 1]))

        df2 = pd.merge(df1, dfc, how="left", on=["tic", "date"])

        # back fill missing data then front fill
        df3 = pd.DataFrame(columns=df2.columns)
        for i in self.ticker_list:
            df4 = df2[df2.tic == i].fillna(method="bfill").fillna(method="ffill")
            df3 = pd.concat([df3, df4], ignore_index=True)

        df3 = df3.fillna(0)

        # reshape dataframe
        df3 = df3.sort_values(by=['date', 'tic']).reset_index(drop=True)

        print("Shape of DataFrame: ", df3.shape)

        self.dataframe = df3

    # def add_technical_indicator(self, tech_indicator_list: List[str], use_stockstats_or_talib: int=0):
    #     """
    #     calculate technical indicators
    #     use stockstats/talib package to add technical inidactors
    #     :param data: (df) pandas dataframe
    #     :return: (df) pandas dataframe
    #     """
    #     df = self.dataframe.copy()
    #     if "date" in df.columns.values.tolist():
    #         df = df.rename(columns={'date': 'time'})
    #
    #     if self.data_source == "ccxt":
    #         df = df.rename(columns={'index': 'time'})
    #
    #     # df = df.reset_index(drop=False)
    #     # df = df.drop(columns=["level_1"])
    #     # df = df.rename(columns={"level_0": "tic", "date": "time"})
    #     if use_stockstats_or_talib == 0:  # use stockstats
    #         stock = stockstats.StockDataFrame.retype(df.copy())
    #         unique_ticker = stock.tic.unique()
    #         #print(unique_ticker)
    #         for indicator in tech_indicator_list:
    #             indicator_df = pd.DataFrame()
    #             for i in range(len(unique_ticker)):
    #                 try:
    #                     temp_indicator = stock[stock.tic == unique_ticker[i]][indicator]
    #                     temp_indicator = pd.DataFrame(temp_indicator)
    #                     temp_indicator["tic"] = unique_ticker[i]
    #                     temp_indicator["time"] = df[df.tic == unique_ticker[i]][
    #                         "time"
    #                     ].to_list()
    #                     indicator_df = indicator_df.append(
    #                         temp_indicator, ignore_index=True
    #                     )
    #                 except Exception as e:
    #                     print(e)
    #             #print(indicator_df)
    #             df = df.merge(
    #                 indicator_df[["tic", "time", indicator]], on=["tic", "time"], how="left"
    #             )
    #     else:  # use talib
    #         final_df = pd.DataFrame()
    #         for i in df.tic.unique():
    #             tic_df = df[df.tic == i]
    #             tic_df['macd'], tic_df['macd_signal'], tic_df['macd_hist'] = MACD(tic_df['close'], fastperiod=12,
    #                                                                               slowperiod=26, signalperiod=9)
    #             tic_df['rsi'] = RSI(tic_df['close'], timeperiod=14)
    #             tic_df['cci'] = CCI(tic_df['high'], tic_df['low'], tic_df['close'], timeperiod=14)
    #             tic_df['dx'] = DX(tic_df['high'], tic_df['low'], tic_df['close'], timeperiod=14)
    #             final_df = final_df.append(tic_df)
    #         df = final_df
    #
    #     df = df.sort_values(by=["time", "tic"])
    #     df = df.rename(columns={'time': 'date'})    # 1/11 added by hx
    #     df = df.dropna()
    #     print("Succesfully add technical indicators")
    #     self.dataframe = df

    # def get_trading_days(self, start: str, end: str) -> List[str]:
    #     print('not supported currently!')
    #     return ['not supported currently!']

    # def add_turbulence(self, data: pd.DataFrame) \
    #         -> pd.DataFrame:
    #     print('not supported currently!')
    #     return pd.DataFrame(['not supported currently!'])

    # def calculate_turbulence(self, data: pd.DataFrame, time_period: int = 252) \
    #         -> pd.DataFrame:
    #     print('not supported currently!')
    #     return pd.DataFrame(['not supported currently!'])

    # def add_vix(self, data: pd.DataFrame) \
    #         -> pd.DataFrame:
    #     print('not supported currently!')
    #     return pd.DataFrame(['not supported currently!'])

    # def df_to_array(self, df: pd.DataFrame, tech_indicator_list: List[str], if_vix: bool) \
    #         -> List[np.array]:
    #     print('not supported currently!')
    #     return pd.DataFrame(['not supported currently!'])

    def data_split(self, df, start, end, target_date_col="date"):
        """
        split the dataset into training or testing using date
        :param data: (df) pandas dataframe, start, end
        :return: (df) pandas dataframe
        """
        data = df[(df[target_date_col] >= start) & (df[target_date_col] < end)]
        data = data.sort_values([target_date_col, "tic"], ignore_index=True)
        data.index = data[target_date_col].factorize()[0]
        return data
    
    def data_refine(self,drop_list:List[str]):
        self.dataframe = self.dataframe.drop(columns = drop_list)
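Alongside the processor, the code also defines a ReturnPlotter helper for comparing the strategy's cumulative return against a baseline -- the CSI 300 Index (399300), the SSE 50 Index (000016), or an equal-weight portfolio: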
        


import tushare as ts
import pandas as pd
from matplotlib import pyplot as plt


class ReturnPlotter:
    """
    An easy-to-use plotting tool to plot cumulative returns over time.
    Baseline supports equal weighting (default) and any stocks you want to use for comparison.
    """

    def __init__(self, df_account_value, df_trade, start_date, end_date):
        self.start = start_date
        self.end = end_date
        self.trade = df_trade
        self.df_account_value = df_account_value

    def get_baseline(self, ticket):
        df = ts.get_hist_data(ticket, start=self.start, end=self.end)
        df.loc[:, 'dt'] = df.index
        df.index = range(len(df))
        df.sort_values(axis=0, by='dt', ascending=True, inplace=True)
        df["date"] = pd.to_datetime(df["dt"], format='%Y-%m-%d')
        return df

    def plot(self, baseline_ticket=None):
        """
        Plot cumulative returns over time.
        use baseline_ticket to specify stock you want to use for comparison
        (default: equal weighted returns)
        """
        baseline_label = "Equal-weight portfolio"
        tic2label = {"399300": "CSI 300 Index", "000016": "SSE 50 Index"}
        if baseline_ticket:
            # use the specified ticket as the baseline
            baseline_df = self.get_baseline(baseline_ticket)
            baseline_df = baseline_df[baseline_df.dt != "2020-06-26"]  # ours don't have date=="2020-06-26"
            baseline = baseline_df.close.tolist()
            baseline_label = tic2label.get(baseline_ticket, baseline_ticket)
        else:
            # equal-weight baseline
            all_date = self.trade.date.unique().tolist()
            baseline = []
            for day in all_date:
                day_close = self.trade[self.trade["date"] == day].close.tolist()
                avg_close = sum(day_close) / len(day_close)
                baseline.append(avg_close)

        ours = self.df_account_value.account_value.tolist()
        ours = self.pct(ours)
        baseline = self.pct(baseline)

        days_per_tick = 60  # you should scale this variable according to the total trading days
        time = list(range(len(ours)))
        datetimes = self.df_account_value.date.tolist()
        ticks = [tick for t, tick in zip(time, datetimes) if t % days_per_tick == 0]
        plt.title("Cumulative Returns")
        plt.plot(time, ours, label="DDPG Agent", color="green")
        plt.plot(time, baseline, label=baseline_label, color="grey")
        plt.xticks([i * days_per_tick for i in range(len(ticks))], ticks, fontsize=7)

        plt.xlabel("Date")
        plt.ylabel("Cumulative Return")

        plt.legend()
        plt.show()

    def plot_all(self):
        baseline_label = "Equal-weight portfolio"
        tic2label = {"399300": "CSI 300 Index", "000016": "SSE 50 Index"}

        # 399300
        baseline_ticket = "399300"
        baseline_df = self.get_baseline(baseline_ticket)
        baseline_df = baseline_df[baseline_df.dt != "2020-06-26"]  # ours don't have date=="2020-06-26"
        baseline_300 = baseline_df.close.tolist()
        baseline_label_300 = tic2label[baseline_ticket]

        # 000016
        baseline_ticket = "000016"
        baseline_df = self.get_baseline(baseline_ticket)
        baseline_df = baseline_df[baseline_df.dt != "2020-06-26"]  # ours don't have date=="2020-06-26"
        baseline_50 = baseline_df.close.tolist()
        baseline_label_50 = tic2label[baseline_ticket]

        # equal-weight baseline
        all_date = self.trade.date.unique().tolist()
        baseline_equal_weight = []
        for day in all_date:
            day_close = self.trade[self.trade["date"] == day].close.tolist()
            avg_close = sum(day_close) / len(day_close)
            baseline_equal_weight.append(avg_close)

        ours = self.df_account_value.account_value.tolist()

        ours = self.pct(ours)
        baseline_300 = self.pct(baseline_300)
        baseline_50 = self.pct(baseline_50)
        baseline_equal_weight = self.pct(baseline_equal_weight)

        days_per_tick = 60  # you should scale this variable according to the total trading days
        time = list(range(len(ours)))
        datetimes = self.df_account_value.date.tolist()
        ticks = [tick for t, tick in zip(time, datetimes) if t % days_per_tick == 0]
        plt.title("Cumulative Returns")
        plt.plot(time, ours, label="DDPG Agent", color="darkorange")
        plt.plot(time, baseline_equal_weight, label=baseline_label, color="cornflowerblue")  # equal weight
        plt.plot(time, baseline_300, label=baseline_label_300, color="lightgreen")  # 399300
        plt.plot(time, baseline_50, label=baseline_label_50, color="silver")  # 000016
        plt.xlabel("Date")
        plt.ylabel("Cumulative Return")

        plt.xticks([i * days_per_tick for i in range(len(ticks))], ticks, fontsize=7)
        plt.legend()
        plt.show()

    def pct(self, l):
        """Get percentage"""
        base = l[0]
        return [x / base for x in l]

    def get_return(self, df, value_col_name="account_value"):
        df = deepcopy(df)
        df["daily_return"] = df[value_col_name].pct_change(1)
        df["date"] = pd.to_datetime(df["date"], format='%Y-%m-%d')
        df.set_index("date", inplace=True, drop=True)
        df.index = df.index.tz_localize("UTC")
        return pd.Series(df["daily_return"], index=df.index)

3. Finally, put it to use -- the functionality is really rich!

In the end you only need to call the Tushare class to use it. If you move the files, remember to adjust the import statements.

Here the data is split into two segments to serve the later machine-learning workflow, one for training and one for backtesting (a sketch of the split follows the snippet below).

from processor_Tusharepro import TushareProProcessor

ticket_list=['600000.SH', '600009.SH', '600016.SH', '600028.SH', '600030.SH',
       '600031.SH', '600036.SH', '600050.SH', '600104.SH', '600196.SH',
       '600276.SH', '600309.SH', '600519.SH', '600547.SH', '600570.SH']

train_start_date='2015-01-01'
train_stop_date='2019-08-01'
val_start_date='2019-08-01'
val_stop_date='2021-01-03'

token='……………………'  # use your own token here


# download and clean
ts_processor = TushareProProcessor("tusharepro",train_start_date, val_stop_date, "1D", token=token)
ts_processor.download_data(ticket_list)
ts_processor.clean_data()
ts_processor.dataframe
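From here, a typical next step (a sketch only; the indicator names are just common stockstats choices, adjust them as needed) is to add indicators and then split the cleaned data into a training window and a backtesting window with data_split, as mentioned above:

# add technical indicators with stockstats (use_stockstats_or_talib=0)
ts_processor.add_technical_indicator(["macd", "rsi_30", "cci_30", "dx_30"], use_stockstats_or_talib=0)

# add_technical_indicator renames "date" to "time", so rename it back before splitting
ts_processor.dataframe.rename(columns={"time": "date"}, inplace=True)

# train / backtest split by date
train = ts_processor.data_split(ts_processor.dataframe, train_start_date, train_stop_date)
trade = ts_processor.data_split(ts_processor.dataframe, val_start_date, val_stop_date)
print(train.shape, trade.shape)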
