1-baselines/run.py解读

1-baselines/run.py解读

前言

就个人来说,比较喜欢从main函数入手,根据其运行的流程,一步一步找到每一步中涉及的参数以及函数的用意,最终摸清整个一个项目的流程和框架。
书接上文,上文书说到我们根据GitHub上的提示,可以使用下面的命令成功的运行了一个例子。

 python -m baselines.run --alg=deepq --env=PongNoFrameskip-v4 --num_timesteps=1e6

于是我们顺藤摸瓜,找到了baselines文件夹下的ren.py文件,这个文件就是这个项目的入口。

run.py

参数设置

我们根据上面的这个命令,先简单的分析一下他的参数构成,这里主要的参数有两个:

  • 参数alg=deepq表示的是采用的深度强化学习的算法,这里是指DQN。
  • 参数env=PongNoFrameskip-v4表示的是采用的测试的环境,这里是用的是PongNoFrameskip-v4,这是一个乒乓球小游戏,通过控制球拍上下移动接球,没接到球的一方就会丢失一分,先打到21分的一方就获胜了。

main函数

接下来,我们看一下run.py具体是怎么实现的(只讲解一些与程序相关的较大的语句
首先我们先找到main函数

def main(args):
    # configure logger, disable logging in child MPI processes (with rank > 0)

    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    extra_args = parse_cmdline_kwargs(unknown_args)

    if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
        rank = 0
        configure_logger(args.log_path)
    else:
        rank = MPI.COMM_WORLD.Get_rank()
        configure_logger(args.log_path, format_strs=[])

    model, env = train(args, extra_args)

    if args.save_path is not None and rank == 0:
        save_path = osp.expanduser(args.save_path)
        model.save(save_path)

    if args.play:
        logger.log("Running trained model")
        obs = env.reset()

        state = model.initial_state if hasattr(model, 'initial_state') else None
        dones = np.zeros((1,))

        episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值