FinRL源码解析之:main.py

from __future__ import annotations

import os
from argparse import ArgumentParser
from typing import List

from finrl.config import ALPACA_API_BASE_URL
from finrl.config import DATA_SAVE_DIR
from finrl.config import ERL_PARAMS
from finrl.config import INDICATORS
from finrl.config import RESULTS_DIR
from finrl.config import TENSORBOARD_LOG_DIR
from finrl.config import TEST_END_DATE
from finrl.config import TEST_START_DATE
from finrl.config import TRADE_END_DATE
from finrl.config import TRADE_START_DATE
from finrl.config import TRAIN_END_DATE
from finrl.config import TRAIN_START_DATE
from finrl.config import TRAINED_MODEL_DIR
from finrl.config_tickers import DOW_30_TICKER
from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv

# construct environment

# try:
#     from finrl.config_private import ALPACA_API_KEY, ALPACA_API_SECRET
# except ImportError:
#     raise FileNotFoundError(
#         "Please set your own ALPACA_API_KEY and ALPACA_API_SECRET in config_private.py"
#     )


def build_parser():
    parser = ArgumentParser()
    parser.add_argument(
        "--mode",
        dest="mode",
        help="start mode, train, download_data" " backtest",
        metavar="MODE",
        default="train",
    )
    return parser


# "./" will be added in front of each directory
def check_and_make_directories(directories: list[str]):
    for directory in directories:
        if not os.path.exists("./" + directory):
            os.makedirs("./" + directory)


def main() -> int:
    parser = build_parser()
    options = parser.parse_args()
    check_and_make_directories(
        [DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
    )

    if options.mode == "train":
        from finrl import train

        env = StockTradingEnv

        # demo for elegantrl
        kwargs = (
            {}
        )  # in current meta, with respect yahoofinance, kwargs is {}. For other data sources, such as joinquant, kwargs is not empty
        train(
            start_date=TRAIN_START_DATE,
            end_date=TRAIN_END_DATE,
            ticker_list=DOW_30_TICKER,
            data_source="yahoofinance",
            time_interval="1D",
            technical_indicator_list=INDICATORS,
            drl_lib="elegantrl",
            env=env,
            model_name="ppo",
            cwd="./test_ppo",
            erl_params=ERL_PARAMS,
            break_step=1e5,
            kwargs=kwargs,
        )
    elif options.mode == "test":
        from finrl import test

        env = StockTradingEnv

        # demo for elegantrl
        # in current meta, with respect yahoofinance, kwargs is {}. For other data sources, such as joinquant, kwargs is not empty
        kwargs = {}

        account_value_erl = test(  # noqa
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
            ticker_list=DOW_30_TICKER,
            data_source="yahoofinance",
            time_interval="1D",
            technical_indicator_list=INDICATORS,
            drl_lib="elegantrl",
            env=env,
            model_name="ppo",
            cwd="./test_ppo",
            net_dimension=512,
            kwargs=kwargs,
        )
    elif options.mode == "trade":
        from finrl import trade

        try:
            from finrl.config_private import ALPACA_API_KEY, ALPACA_API_SECRET
        except ImportError:
            raise FileNotFoundError(
                "Please set your own ALPACA_API_KEY and ALPACA_API_SECRET in config_private.py"
            )
        env = StockTradingEnv
        kwargs = {}
        trade(
            start_date=TRADE_START_DATE,
            end_date=TRADE_END_DATE,
            ticker_list=DOW_30_TICKER,
            data_source="yahoofinance",
            time_interval="1D",
            technical_indicator_list=INDICATORS,
            drl_lib="elegantrl",
            env=env,
            model_name="ppo",
            API_KEY=ALPACA_API_KEY,
            API_SECRET=ALPACA_API_SECRET,
            API_BASE_URL=ALPACA_API_BASE_URL,
            trade_mode="paper_trading",
            if_vix=True,
            kwargs=kwargs,
            state_dim=len(DOW_30_TICKER) * (len(INDICATORS) + 3)
            + 3,  # bug fix: for ppo add dimension of state/observations space =  len(stocks)* len(INDICATORS) + 3+ 3*len(stocks)
            action_dim=len(
                DOW_30_TICKER
            ),  # bug fix: for ppo add dimension of action space = len(stocks)
        )
    else:
        raise ValueError("Wrong mode.")
    return 0


# Users can input the following command in terminal
# python main.py --mode=train
# python main.py --mode=test
# python main.py --mode=trade
if __name__ == "__main__":
    raise SystemExit(main())

这段代码主要包含了三个函数的定义,它们分别是 build_parser(), check_and_make_directories(directories: list[str]), 和 main() -> int。我会逐行为你解释它们的内容。

  1. def build_parser():: 定义了一个名为 build_parser 的函数,这个函数用来创建一个命令行参数解析器

  2. parser = ArgumentParser(): 创建了一个 ArgumentParser 对象,该对象用于解析命令行参数。

  3. parser.add_argument(...: 这部分代码向解析器添加了一个命令行参数 --mode,它有以下属性:

    • dest="mode": 参数的名称是 "mode"
    • help=...: 提供的是参数的描述。
    • metavar="MODE": 在帮助文档中,参数值的名称是 "MODE"。
    • default="train": 如果用户没有提供这个参数,它的默认值是 "train"。
  4. return parser: 返回创建的解析器对象

  5. def check_and_make_directories(directories: list[str]):: 定义了一个函数,它接受一个目录名列表,并检查这些目录是否存在。如果目录不存在,则创建它。

  6. for directory in directories:: 遍历提供的目录列表

  7. if not os.path.exists("./" + directory):: 使用 os.path.exists 检查目录是否存在。如果不存在...

  8. os.makedirs("./" + directory): ...使用 os.makedirs 创建该目录。

  9. def main() -> int:: 定义了主函数。

  10. parser = build_parser(): 通过前面定义的 build_parser 函数创建命令行参数解析器。

  11. options = parser.parse_args(): 解析命令行参数,并将结果保存在 options 中。

  12. check_and_make_directories([...: 使用前面定义的函数检查并创建四个目录。

  13. if options.mode == "train":: 检查 mode 参数是否为 "train"。

  14. from finrl import train: 如果是 "train" 模式,从 finrl 模块中导入 train 函数。

  15. env = StockTradingEnv: 创建一个名为 env 的环境对象。

  16. kwargs = ...: 创建一个关键字参数字典。根据注释,当前的数据源为 "yahoofinance" 时,该字典为空。

  17. train(...): 调用前面导入的 train 函数并传入一系列参数来进行训练。

  18. elif options.mode == "test":: 检查 mode 参数是否为 "test"。

  19. from finrl import test: 如果是 "test" 模式,从 finrl 模块中导入 test 函数。

  20. env = StockTradingEnv: 创建一个名为 env 的环境对象。

  21. kwargs = {}: 创建一个关键字参数字典。

  22. account_value_erl = test(...): 调用前面导入的 test 函数,并传入一系列参数来进行测试。

  23. elif options.mode == "trade":: 检查 mode 参数是否为 "trade"。

  24. from finrl import trade: 如果是 "trade" 模式,从 finrl 模块中导入 trade 函数

该代码段主要负责创建命令行参数解析器,确保必要的目录存在,并基于提供的命令行参数(train、test或trade)执行相应的操作。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值