Preface
We introduce MuJoCo Playground, a fully open-source framework for robot learning built on MJX, with the explicit goal of streamlining simulation, training, and sim-to-real transfer for robots. With a simple installation (pip install playground), researchers can train policies on a single GPU within minutes. Playground supports a wide range of robot platforms, including quadrupeds, humanoids, dexterous hands, and robotic arms, and enables zero-shot sim-to-real transfer from both state and pixel inputs. This is made possible by an integrated stack comprising a physics engine, a batch renderer, and training environments. MuJoCo Playground is a community effort involving multiple groups, and we hope it will prove a valuable resource for researchers and developers.
1. Learning RL Agents
In this directory we demonstrate training RL agents on MuJoCo Playground environments with Brax and RSL-RL. We provide two command-line entry points: python train_jax_ppo.py and python train_rsl_rl.py.
For more detailed tutorials on RL with MuJoCo Playground, see:
- Intro to the Playground with the DM Control Suite
- Locomotion environments
- Manipulation environments
- Training CartPole from Vision
- Robotic Manipulation from Vision
1.1 Training with brax PPO
To train with brax PPO, use the train_jax_ppo.py script, which trains an agent in the given environment with the brax PPO algorithm.
python train_jax_ppo.py --env_name=CartpoleBalance
To train a vision-based policy from pixel observations:
python train_jax_ppo.py --env_name=CartpoleBalance --vision
Run python train_jax_ppo.py --help to see the available options and usage. Logs and checkpoints are saved in the logs directory.
1.2 Training with RSL-RL
To train with RSL-RL, use the train_rsl_rl.py script, which trains an agent in the given environment with the RSL-RL algorithm.
python train_rsl_rl.py --env_name=LeapCubeReorient
To render behavior from the resulting policy:
python learning/train_rsl_rl.py --env_name LeapCubeReorient --play_only --load_run_name <run_name>
where run_name is the name of the run to load (printed to the console when the training run is launched).
Logs and checkpoints are saved under the logs directory.
2. Welcome to MuJoCo Playground!
Welcome to MuJoCo Playground, we're glad you're here! MuJoCo Playground contains a full suite of environments for reinforcement learning and robotics research. In this notebook we walk through the DM Control Suite environments, ported to run on GPU via MJX.
A Colab runtime with GPU acceleration is required. If you are on a CPU-only runtime, you can switch via the menu Runtime > Change runtime type.
#@title Install pre-requisites
!pip install mujoco
!pip install mujoco_mjx
!pip install brax
# @title Check if MuJoCo installation was successful
import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.'
  )
# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")
# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl
try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".'
  )
print('Installation successful.')
# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
# @title Import packages for plotting and creating graphics
import itertools
import time
from typing import Callable, List, NamedTuple, Optional, Union
import numpy as np
# Graphics and plotting.
print("Installing mediapy:")
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt
# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)
# @title Import MuJoCo, MJX, and Brax
from datetime import datetime
import functools
import os
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.io import html, mjcf, model
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import networks as sac_networks
from brax.training.agents.sac import train as sac
from etils import epath
from flax import struct
from flax.training import orbax_utils
from IPython.display import HTML, clear_output, display
import jax
from jax import numpy as jp
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
import numpy as np
from orbax import checkpoint as ocp
#@title Install MuJoCo Playground
!pip install playground
2.1 Introduction
MuJoCo Playground contains DM Control Suite, locomotion, and manipulation environments. You can load any of them through the environment registry:
from mujoco_playground import registry
env = registry.load('CartpoleBalance')
env
Each environment is also associated with an environment config, which can be overridden if needed. Let's load the config as well:
env_cfg = registry.get_default_config('CartpoleBalance')
env_cfg
Note that the environment config contains sim_dt and ctrl_dt. ctrl_dt determines the amount of simulated time per env.step, so each env.step advances the physics by ctrl_dt / sim_dt substeps.
Other notable parameters include vision_config, which we cover in detail in the vision-based notebooks! For now, we focus on privileged state observations.
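As a concrete example of overriding a config, the object returned by get_default_config behaves like an ml_collections ConfigDict, so individual fields can be modified before loading the environment. The sketch below is illustrative only; the ctrl_dt value is not a recommended setting.
# Override the control timestep before constructing the environment (illustrative value).
env_cfg = registry.get_default_config('CartpoleBalance')
env_cfg.ctrl_dt = 0.02
# Each env.step advances the physics by ctrl_dt / sim_dt substeps.
print('physics steps per env.step:', int(round(env_cfg.ctrl_dt / env_cfg.sim_dt)))
env = registry.load('CartpoleBalance', config=env_cfg)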
2.2 Rollout
Now let's do a simple rollout. This closely mirrors the MJX tutorial notebook; if you are familiar with MJX it should look familiar, and if not, take a look at the MJX tutorial first!
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state]
f = 0.5
for i in range(env_cfg.episode_length):
  action = []
  for j in range(env.action_size):
    action.append(
        jp.sin(
            state.data.time * 2 * jp.pi * f + j * 2 * jp.pi / env.action_size
        )
    )
  action = jp.array(action)
  state = jit_step(state, action)
  rollout.append(state)
frames = env.render(rollout)
media.show_video(frames, fps=1.0 / env.dt)
If you are running this notebook on a GPU instance (as you should be), note that the environment runs and lives on device! This, of course, means we can do large-batch RL directly on the GPU.
state.obs.device
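As a minimal sketch of what large-batch simulation looks like, reset and step can be vmapped over a batch of environments directly on device; this is similar in spirit to what the training wrappers do internally. The batch size below is an arbitrary illustrative choice.
num_envs = 1024  # illustrative batch size
batched_reset = jax.jit(jax.vmap(env.reset))
batched_step = jax.jit(jax.vmap(env.step))
rngs = jax.random.split(jax.random.PRNGKey(0), num_envs)
batched_state = batched_reset(rngs)
actions = jp.zeros((num_envs, env.action_size))
batched_state = batched_step(batched_state, actions)
# Observations now carry a leading batch dimension.
print(jax.tree_util.tree_map(lambda x: x.shape, batched_state.obs))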
2.3 RL
We'll use brax to train RL policies here; examples with RSL-RL are shown in the main repository via python train_rsl_rl.py. We encourage PyTorch users to take a look!
For now, let's start with brax PPO.
MuJoCo Playground provides a set of hyperparameters for every available environment via mujoco_playground.config. Some environments also include configs for RSL-RL and for vision-based brax PPO. Let's load the brax PPO config for CartpoleBalance!
from mujoco_playground.config import dm_control_suite_params
ppo_params = dm_control_suite_params.brax_ppo_config('CartpoleBalance')
ppo_params
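The returned hyperparameter config supports both attribute and key access (it is used both ways later in this notebook), so individual values can be adjusted before training if desired; the commented lines below are illustrative only, not recommended settings.
# Optionally tweak hyperparameters before launching training (illustrative, not recommendations).
# ppo_params.num_timesteps = 5_000_000
# ppo_params.num_envs = 1024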
And of course, let's train a policy!
2.4 PPO
x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]
def progress(num_steps, metrics):
  clear_output(wait=True)
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])
  plt.xlim([0, ppo_params["num_timesteps"] * 1.25])
  plt.ylim([0, 1100])
  plt.xlabel("# environment steps")
  plt.ylabel("reward per episode")
  plt.title(f"y={y_data[-1]:.3f}")
  plt.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")
  display(plt.gcf())
ppo_training_params = dict(ppo_params)
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
  del ppo_training_params["network_factory"]
  network_factory = functools.partial(
      ppo_networks.make_ppo_networks,
      **ppo_params.network_factory
  )
train_fn = functools.partial(
    ppo.train, **dict(ppo_training_params),
    network_factory=network_factory,
    progress_fn=progress
)
from mujoco_playground import wrapper
make_inference_fn, params, metrics = train_fn(
    environment=env,
    wrap_env_fn=wrapper.wrap_for_brax_training,
)
print(f"time to jit: {times[1] - times[0]}")
print(f"time to train: {times[-1] - times[1]}")
If you are familiar with brax, you'll notice that we pass a custom wrapper to train_fn. This is because MuJoCo Playground environments are not compatible with the vanilla brax wrappers. For RSL-RL, we provide the wrappers in wrapper_torch.py.
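For reference, here is roughly what passing wrap_env_fn amounts to: the brax trainer calls the wrapper on the environment before training, adding episode bookkeeping and batching. The keyword arguments in this sketch are assumptions modeled on brax's usual wrapping interface, not a documented signature.
# Hypothetical manual wrapping, mirroring what the trainer does via wrap_env_fn.
# The episode_length / action_repeat keyword arguments are assumptions, not confirmed API.
wrapped_env = wrapper.wrap_for_brax_training(
    env,
    episode_length=ppo_params.episode_length,
    action_repeat=1,
)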
2.5 Visualizing the Rollout
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))
rng = jax.random.PRNGKey(42)
rollout = []
n_episodes = 1
for _ in range(n_episodes):
  state = jit_reset(rng)
  rollout.append(state)
  for i in range(env_cfg.episode_length):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)
render_every = 1
frames = env.render(rollout[::render_every])
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / env.dt / render_every)
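Since the per-step rewards were collected above, plotting them is a quick sanity check alongside the video; this is an optional addition.
# Plot the per-step reward of the evaluation rollout collected above.
plt.figure()
plt.plot(np.asarray(rewards))
plt.xlabel("environment step")
plt.ylabel("reward")
plt.show()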
2.6 DM Control Suite: Try It Out!
You can now play with any of the DM Control Suite environments. The world is your oyster!
env_name = "FishSwim" # @param ["AcrobotSwingup", "AcrobotSwingupSparse", "BallInCup", "CartpoleBalance", "CartpoleBalanceSparse", "CartpoleSwingup", "CartpoleSwingupSparse", "CheetahRun", "FingerSpin", "FingerTurnEasy", "FingerTurnHard", "FishSwim", "HopperHop", "HopperStand", "HumanoidStand", "HumanoidWalk", "HumanoidRun", "PendulumSwingup", "PointMass", "ReacherEasy", "ReacherHard", "SwimmerSwimmer6", "WalkerRun", "WalkerStand", "WalkerWalk"]
CAMERAS = {
"AcrobotSwingup": "fixed",
"AcrobotSwingupSparse": "fixed",
"BallInCup": "cam0",
"CartpoleBalance": "fixed",
"CartpoleBalanceSparse": "fixed",
"CartpoleSwingup": "fixed",
"CartpoleSwingupSparse": "fixed",
"CheetahRun": "side",
"FingerSpin": "cam0",
"FingerTurnEasy": "cam0",
"FingerTurnHard": "cam0",
"FishSwim": "fixed_top",
"HopperHop": "cam0",
"HopperStand": "cam0",
"HumanoidStand": "side",
"HumanoidWalk": "side",
"HumanoidRun": "side",
"PendulumSwingup": "fixed",
"PointMass": "cam0",
"ReacherEasy": "fixed",
"ReacherHard": "fixed",
"SwimmerSwimmer6": "tracking1",
"WalkerRun": "side",
"WalkerWalk": "side",
"WalkerStand": "side",
}
camera_name = CAMERAS[env_name]
env_cfg = registry.get_default_config(env_name)
env = registry.load(env_name, config=env_cfg)
2.7 Visualizing the Environment
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
state = jit_reset(jax.random.PRNGKey(0))
rollout = [state]
f = 0.5
for i in range(env_cfg.episode_length):
  action = []
  for j in range(env.action_size):
    action.append(
        jp.sin(
            state.data.time * 2 * jp.pi * f + j * 2 * jp.pi / env.action_size
        )
    )
  action = jp.array(action)
  state = jit_step(state, action)
  rollout.append(state)
frames = env.render(rollout, camera=camera_name)
media.show_video(frames, fps=1.0 / env.dt)
2.8 Training
from mujoco_playground.config import dm_control_suite_params
ppo_params = dm_control_suite_params.brax_ppo_config(env_name)
sac_params = dm_control_suite_params.brax_sac_config(env_name)
2.9 PPO
x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]
def progress(num_steps, metrics):
  clear_output(wait=True)
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])
  plt.xlim([0, ppo_params["num_timesteps"] * 1.25])
  plt.ylim([0, 1100])
  plt.xlabel("# environment steps")
  plt.ylabel("reward per episode")
  plt.title(f"y={y_data[-1]:.3f}")
  plt.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")
  display(plt.gcf())
ppo_training_params = dict(ppo_params)
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
  del ppo_training_params["network_factory"]
  network_factory = functools.partial(
      ppo_networks.make_ppo_networks,
      **ppo_params.network_factory
  )
train_fn = functools.partial(
    ppo.train, **dict(ppo_training_params),
    network_factory=network_factory,
    progress_fn=progress
)
make_inference_fn, params, metrics = train_fn(
    environment=env,
    wrap_env_fn=wrapper.wrap_for_brax_training,
)
print(f"time to jit: {times[1] - times[0]}")
print(f"time to train: {times[-1] - times[1]}")
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))
rng = jax.random.PRNGKey(42)
rollout = []
n_episodes = 1
for _ in range(n_episodes):
  state = jit_reset(rng)
  rollout.append(state)
  for i in range(env_cfg.episode_length):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)
render_every = 1
frames = env.render(rollout[::render_every], camera=camera_name)
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / env.dt / render_every)
2.10 SAC
x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]
def progress(num_steps, metrics):
  clear_output(wait=True)
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])
  plt.xlim([0, sac_params["num_timesteps"] * 1.25])
  plt.ylim([0, 1100])
  plt.xlabel("# environment steps")
  plt.ylabel("reward per episode")
  plt.title(f"y={y_data[-1]:.3f}")
  plt.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")
  display(plt.gcf())
sac_training_params = dict(sac_params)
network_factory = sac_networks.make_sac_networks
if "network_factory" in sac_params:
  del sac_training_params["network_factory"]
  network_factory = functools.partial(
      sac_networks.make_sac_networks,
      **sac_params.network_factory
  )
train_fn = functools.partial(
    sac.train, **dict(sac_training_params),
    network_factory=network_factory,
    progress_fn=progress
)
make_inference_fn, params, metrics = train_fn(
    environment=env,
    wrap_env_fn=wrapper.wrap_for_brax_training,
)
print(f"time to jit: {times[1] - times[0]}")
print(f"time to train: {times[-1] - times[1]}")
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)
jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))
rng = jax.random.PRNGKey(0)
rollout = []
n_episodes = 1
for _ in range(n_episodes):
  state = jit_reset(rng)
  rollout.append(state)
  for i in range(env_cfg.episode_length):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_step(state, ctrl)
    rollout.append(state)
render_every = 1
frames = env.render(rollout[::render_every], camera=camera_name)
rewards = [s.reward for s in rollout]
media.show_video(frames, fps=1.0 / env.dt / render_every)