OpenAI baseline GAIL代码讲解及其可视化

最近在研究关于强化学习的部分工作,首先从OpenAI的Baseline中的小型GAIL算法出发。

首先参考了大神的文章从《西部世界》到GAIL(Generative Adversarial Imitation Learning)算法。

原文链接:https://blog.csdn.net/jinzhuojun/article/details/85220327#commentBox

对大神写的文章做一些补充和细节解释。

在baseline 的文件夹中运行即可以进行模型的训练

python3 -m baselines.gail.run_mujoco

在run_mujoco.py代码中写到

   parser.add_argument('--task', type=str, choices=['train', 'evaluate', 'sample'], default='train')

 可以在命令行后面添加 --task 改变任务为train 和evaluate。evaluate后面要加上存储的模型的地址

# 假设训练模型放在/home/jzj/source/baselines/checkpoint/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0/
python3 -m baselines.gail.run_mujoco --task=evaluate  --load_model_path=/home/jzj/source/baselines/checkpoint/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0/trpo_gail.transition_limitation_-1.Hopper.g_step_3.d_step_1.policy_entcoeff_0.adversary_entcoeff_0.001.seed_0

在baseline 中使用tensorflow方式存储模型:  在trpo_mpi.py  232行。

        # Save model
        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            #U.save_variables(fname)
            #print("the save path is ",fname)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname)

所以在checkpoint中存储了可以用tensorflow方式读取模型的三个文件,而在运行评估模型时读取模型的方式采用的是baseline 中common自己定义的    U.load_variables(load_model_path)来读取文件,读取文件的类型是上面由tensorflow生成的文件的集合体。

    U.load_variables(load_model_path)

因此在存储模型的时候也应该采用common中的定义的save_variables来存储模型生成集成文件:

        if rank == 0 and iters_so_far % save_per_iter == 0 and ckpt_dir is not None:
            fname = os.path.join(ckpt_dir, task_name)
            U.save_variables(fname)
            print("the save path is ",fname)
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            saver = tf.train.Saver()
            saver.save(tf.get_default_session(), fname)

 然后运行train的命令行,在训练100次迭代之后就可以在保存模型的文件夹中发现一个无.data/.index/.meta后缀的集成文件。

此时再运行evaluate命令行就可以出现对模型的评估返回数据

在run_mujoco.py中的traj_1_generator函数中的while函数中插入env.render()就可以渲染出模型可视化结果。

    while True:
        ac, vpred = pi.act(stochastic, ob)
        obs.append(ob)
        news.append(new)
        acs.append(ac)

        ob, rew, new, _ = env.step(ac)
        rews.append(rew)
        env.render()
        
        cur_ep_ret += rew
        cur_ep_len += 1
        if new or t >= horizon:
            break
        t += 1

感谢大佬的分享,同时在遇到困难的时候还是要敢于挑战权威呀。

  • 4
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值