stable_baselines3与imitation的结合——关于模型的调用和再训练的问题

最近在尝试使用stable_baselines3和imitation库在python下做模仿学习,遇到了一些问题,查询了官方文档还有GitHub都没有合适的解决方法,因此在这里分享一下解决方案。

 

首先贴一下官方文档的网址:

stable_baselines3
imitation

本文章的使用的环境是pybullet,环境代码是在网上找到的一个比较好的,做了一点小修改,原作者是这位,我用了他的环境(env)。

下面简单说一下如何使用imitation库下的BC(行为克隆),官网代码如下:

import numpy as np
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms import bc
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper

rng = np.random.default_rng(0)
env = gym.make("CartPole-v1")
expert = PPO(policy=MlpPolicy, env=env)
expert.learn(1000)

rollouts = rollout.rollout(
    expert,
    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),
    rollout.make_sample_until(min_timesteps=None, min_episodes=50),
    rng=rng,
)
transitions = rollout.flatten_trajectories(rollouts)

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
)
bc_trainer.train(n_epochs=1)
reward, _ = evaluate_policy(bc_trainer.policy, env, 10)
print("Reward:", reward)

对于代码不做过多介绍,详情请参考官方文档。

之前我有一些问题:

1. 对于模仿学习后保存的模型,如何进行再调用或者继续模仿学习训练?

2. 对于模仿学习后保存的模型,如何将其中的网络参数拿出来给PPO用,也就是完成模仿学习+强化学习的联合?

3. 简单谈谈专家数据增强问题

下面依次对问题进行解释:

1. 保存加载模型,使用如下代码

#保存模型
bc_trainer.save_policy("bc_policy")
#加载模型
bc_trainer2=bc.reconstruct_policy("bc_policy")

第二个问题则比较简单,只需执行这个语句:

expert.policy=bc_trainer2
#这是因为通过bc读取的格式,就是policy

第三个问题之前读过一篇论文,其中谈到如何对机械臂任务的专家数据做扩充。一个简单的思路就是将专家数据的关键帧提取出来,帧之间添加噪声,这样一个专家数据可以扩充很多条轨迹。关键帧的定义是机械臂末端速度为零的节点。如果我写出这个代码并验证成功,会分享一下。

  • 11
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值