from __future__ import annotations
import os
from argparse import ArgumentParser
from typing import List
from finrl.config import ALPACA_API_BASE_URL
from finrl.config import DATA_SAVE_DIR
from finrl.config import ERL_PARAMS
from finrl.config import INDICATORS
from finrl.config import RESULTS_DIR
from finrl.config import TENSORBOARD_LOG_DIR
from finrl.config import TEST_END_DATE
from finrl.config import TEST_START_DATE
from finrl.config import TRADE_END_DATE
from finrl.config import TRADE_START_DATE
from finrl.config import TRAIN_END_DATE
from finrl.config import TRAIN_START_DATE
from finrl.config import TRAINED_MODEL_DIR
from finrl.config_tickers import DOW_30_TICKER
from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
# construct environment
# try:
# from finrl.config_private import ALPACA_API_KEY, ALPACA_API_SECRET
# except ImportError:
# raise FileNotFoundError(
# "Please set your own ALPACA_API_KEY and ALPACA_API_SECRET in config_private.py"
# )
def build_parser():
parser = ArgumentParser()
parser.add_argument(
"--mode",
dest="mode",
help="start mode, train, download_data" " backtest",
metavar="MODE",
default="train",
)
return parser
# "./" will be added in front of each directory
def check_and_make_directories(directories: list[str]):
for directory in directories:
if not os.path.exists("./" + directory):
os.makedirs("./" + directory)
def main() -> int:
parser = build_parser()
options = parser.parse_args()
check_and_make_directories(
[DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR]
)
if options.mode == "train":
from finrl import train
env = StockTradingEnv
# demo for elegantrl
kwargs = (
{}
) # in current meta, with respect yahoofinance, kwargs is {}. For other data sources, such as joinquant, kwargs is not empty
train(
start_date=TRAIN_START_DATE,
end_date=TRAIN_END_DATE,
ticker_list=DOW_30_TICKER,
data_source="yahoofinance",
time_interval="1D",
technical_indicator_list=INDICATORS,
drl_lib="elegantrl",
env=env,
model_name="ppo",
cwd="./test_ppo",
erl_params=ERL_PARAMS,
break_step=1e5,
kwargs=kwargs,
)
elif options.mode == "test":
from finrl import test
env = StockTradingEnv
# demo for elegantrl
# in current meta, with respect yahoofinance, kwargs is {}. For other data sources, such as joinquant, kwargs is not empty
kwargs = {}
account_value_erl = test( # noqa
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
ticker_list=DOW_30_TICKER,
data_source="yahoofinance",
time_interval="1D",
technical_indicator_list=INDICATORS,
drl_lib="elegantrl",
env=env,
model_name="ppo",
cwd="./test_ppo",
net_dimension=512,
kwargs=kwargs,
)
elif options.mode == "trade":
from finrl import trade
try:
from finrl.config_private import ALPACA_API_KEY, ALPACA_API_SECRET
except ImportError:
raise FileNotFoundError(
"Please set your own ALPACA_API_KEY and ALPACA_API_SECRET in config_private.py"
)
env = StockTradingEnv
kwargs = {}
trade(
start_date=TRADE_START_DATE,
end_date=TRADE_END_DATE,
ticker_list=DOW_30_TICKER,
data_source="yahoofinance",
time_interval="1D",
technical_indicator_list=INDICATORS,
drl_lib="elegantrl",
env=env,
model_name="ppo",
API_KEY=ALPACA_API_KEY,
API_SECRET=ALPACA_API_SECRET,
API_BASE_URL=ALPACA_API_BASE_URL,
trade_mode="paper_trading",
if_vix=True,
kwargs=kwargs,
state_dim=len(DOW_30_TICKER) * (len(INDICATORS) + 3)
+ 3, # bug fix: for ppo add dimension of state/observations space = len(stocks)* len(INDICATORS) + 3+ 3*len(stocks)
action_dim=len(
DOW_30_TICKER
), # bug fix: for ppo add dimension of action space = len(stocks)
)
else:
raise ValueError("Wrong mode.")
return 0
# Users can input the following command in terminal
# python main.py --mode=train
# python main.py --mode=test
# python main.py --mode=trade
if __name__ == "__main__":
raise SystemExit(main())
这段代码主要包含了三个函数的定义,它们分别是 build_parser()
, check_and_make_directories(directories: list[str])
, 和 main() -> int
。我会逐行为你解释它们的内容。
-
def build_parser():
: 定义了一个名为build_parser
的函数,这个函数用来创建一个命令行参数解析器。 -
parser = ArgumentParser()
: 创建了一个ArgumentParser
对象,该对象用于解析命令行参数。 -
parser.add_argument(...
: 这部分代码向解析器添加了一个命令行参数--mode
,它有以下属性:dest="mode"
: 参数的名称是 "mode"。help=...
: 提供的是参数的描述。metavar="MODE"
: 在帮助文档中,参数值的名称是 "MODE"。default="train"
: 如果用户没有提供这个参数,它的默认值是 "train"。
-
return parser
: 返回创建的解析器对象。 -
def check_and_make_directories(directories: list[str]):
: 定义了一个函数,它接受一个目录名列表,并检查这些目录是否存在。如果目录不存在,则创建它。 -
for directory in directories:
: 遍历提供的目录列表。 -
if not os.path.exists("./" + directory):
: 使用os.path.exists
检查目录是否存在。如果不存在... -
os.makedirs("./" + directory)
: ...使用os.makedirs
创建该目录。 -
def main() -> int:
: 定义了主函数。 -
parser = build_parser()
: 通过前面定义的build_parser
函数创建命令行参数解析器。 -
options = parser.parse_args()
: 解析命令行参数,并将结果保存在options
中。 -
check_and_make_directories([...
: 使用前面定义的函数检查并创建四个目录。 -
if options.mode == "train":
: 检查mode
参数是否为 "train"。 -
from finrl import train
: 如果是 "train" 模式,从finrl
模块中导入train
函数。 -
env = StockTradingEnv
: 创建一个名为env
的环境对象。 -
kwargs = ...
: 创建一个关键字参数字典。根据注释,当前的数据源为 "yahoofinance" 时,该字典为空。 -
train(...)
: 调用前面导入的train
函数并传入一系列参数来进行训练。 -
elif options.mode == "test":
: 检查mode
参数是否为 "test"。 -
from finrl import test
: 如果是 "test" 模式,从finrl
模块中导入test
函数。 -
env = StockTradingEnv
: 创建一个名为env
的环境对象。 -
kwargs = {}
: 创建一个关键字参数字典。 -
account_value_erl = test(...)
: 调用前面导入的test
函数,并传入一系列参数来进行测试。 -
elif options.mode == "trade":
: 检查mode
参数是否为 "trade"。 -
from finrl import trade
: 如果是 "trade" 模式,从finrl
模块中导入trade
函数。
该代码段主要负责创建命令行参数解析器,确保必要的目录存在,并基于提供的命令行参数(train、test或trade)执行相应的操作。