前言
就个人来说,比较喜欢从main函数入手,根据其运行的流程,一步一步找到每一步中涉及的参数以及函数的用意,最终摸清整个一个项目的流程和框架。
书接上文,上文书说到我们根据GitHub上的提示,可以使用下面的命令成功的运行了一个例子。
python -m baselines.run --alg=deepq --env=PongNoFrameskip-v4 --num_timesteps=1e6
于是我们顺藤摸瓜,找到了baselines文件夹下的ren.py文件,这个文件就是这个项目的入口。
run.py
参数设置
我们根据上面的这个命令,先简单的分析一下他的参数构成,这里主要的参数有两个:
- 参数alg=deepq表示的是采用的深度强化学习的算法,这里是指DQN。
- 参数env=PongNoFrameskip-v4表示的是采用的测试的环境,这里是用的是PongNoFrameskip-v4,这是一个乒乓球小游戏,通过控制球拍上下移动接球,没接到球的一方就会丢失一分,先打到21分的一方就获胜了。
main函数
接下来,我们看一下run.py具体是怎么实现的(只讲解一些与程序相关的较大的语句)
首先我们先找到main函数
def main(args):
# configure logger, disable logging in child MPI processes (with rank > 0)
arg_parser = common_arg_parser()
args, unknown_args = arg_parser.parse_known_args(args)
extra_args = parse_cmdline_kwargs(unknown_args)
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
rank = 0
configure_logger(args.log_path)
else:
rank = MPI.COMM_WORLD.Get_rank()
configure_logger(args.log_path, format_strs=[])
model, env = train(args, extra_args)
if args.save_path is not None and rank == 0:
save_path = osp.expanduser(args.save_path)
model.save(save_path)
if args.play:
logger.log("Running trained model")
obs = env.reset()
state = model.initial_state if hasattr(model, 'initial_state') else None
dones = np.zeros((1,))
episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv