Preference Transformer:基于Transformer的强化学习人类偏好建模
项目介绍
Preference Transformer 是一个基于Transformer架构的开源项目,旨在通过建模人类偏好来优化强化学习(RL)中的非马尔可夫奖励。该项目由来自KAIST、密歇根大学、LG AI Research、UC Berkeley和Google Research的研究团队共同开发,并在ICLR 2023会议上发表。
项目通过构建隐藏嵌入 ${\mathbf{x}_t}$ 来捕捉从初始时间步到时间步 $t$ 的上下文信息,利用偏好注意力层计算非马尔可夫奖励 ${\hat{r}t}$ 及其凸组合 ${z_t }$,最终聚合这些信息以建模非马尔可夫奖励的加权总和 $\sum{t}{w_t \hat{r}_t }$。
项目技术分析
核心技术
- Transformer架构:利用Transformer的强大序列建模能力,捕捉时间序列数据中的复杂依赖关系。
- 偏好注意力机制:通过双向自注意力计算非马尔可夫奖励,提升模型对人类偏好的理解。
- Jax/Flax实现:基于Jax和Flax框架,提供高效的计算性能和可扩展性。
技术优势
- 非马尔可夫奖励建模:有效处理复杂环境中奖励的时序依赖性。
- 高灵活性:支持多种环境和数据集,如D4RL和Robosuite。
- 开源友好:提供详细的安装指南和运行示例,便于研究和应用。
项目及技术应用场景
应用领域
- 机器人控制:在Robosuite等机器人模拟环境中,优化机器人行为以符合人类偏好。
- 游戏AI:在复杂游戏环境中,通过学习人类玩家的偏好,提升AI智能体的表现。
- 自动驾驶:在自动驾驶系统中,根据人类驾驶员的偏好,优化驾驶策略。
典型案例
- D4RL数据集:通过训练奖励模型,提升强化学习算法在多种任务中的性能。
- Robosuite环境:利用人类演示数据,训练机器人执行复杂任务,如抓取和放置。
项目特点
- 强大的建模能力:基于Transformer架构,能够有效捕捉和建模复杂的人类偏好。
- 高效实现:利用Jax/Flax框架,提供高效的计算和优化能力。
- 广泛适用性:支持多种环境和数据集,适用于多种强化学习任务。
- 开源便捷:提供详细的安装和运行指南,便于研究人员和开发者快速上手。
结语
Preference Transformer 项目为强化学习领域提供了一个全新的视角,通过建模人类偏好,优化非马尔可夫奖励,显著提升了智能体的学习效果。无论是学术界还是工业界,该项目都具有广泛的应用前景。欢迎各位研究人员和开发者尝试和使用这个强大的开源工具!
安装与运行
安装依赖
conda create -y -n offline python=3.8
conda activate offline
pip install --upgrade pip
conda install -y -c conda-forge cudatoolkit=11.1 cudnn=8.2.1
pip install -r requirements.txt
cd d4rl
pip install -e .
cd ..
pip install "jax[cuda11_cudnn805]>=0.2.27" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install protobuf==3.20.1 gym<0.24.0 distrax==0.1.2 wandb
pip install transformers
运行示例
D4RL环境
# Preference Transfomer (PT)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env {D4RL env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len 100 --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type PrefTransformer
Robosuite环境
# Preference Transfomer (PT)
CUDA_VISIBLE_DEVICES=0 python -m JaxPref.new_preference_reward_main --use_human_label True --comment {experiment_name} --robosuite True --robosuite_dataset_type {dataset_type} --robosuite_dataset_path {path for robomimic demonstrations} --transformer.embd_dim 256 --transformer.n_layer 1 --transformer.n_head 4 --env {Robosuite env name} --logging.output_dir './logs/pref_reward' --batch_size 256 --num_query {number of query} --query_len {100|50} --n_epochs 10000 --skip_flag 0 --seed {seed} --model_type PrefTransformer
更多详细信息和运行示例,请参考项目官方文档。