简介
Diffusion Policy 可以理解为扩散模型在机器人控制中的应用,能够结合模仿学习,通过观察人类专家的演示来学习策略。本文初步带领部署以及使用Diffusion Policy进行模拟PushT任务。
环境配置
1.安装anaconda
2.在anaconda下赞装mamba
conda install mamba -n base -c conda-forge
3.接下来按照代码中readme文件一步一步配置就可以了
安装依赖
sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
创建环境
这里需要创建一个叫做robodiff的conda环境,创建过程耗时非常长,需要耐心等待。可以通过htop等软件监视内存变化以判断程序是否正常进行。
mamba env create -f conda_environment.yaml
教程同样给了一条备选方案,通过conda进行安装。作者这里使用该指令爆内存,请使用mamba并耐心等待。
conda env create -f conda_environment.yaml
安装wandb
pip install wandb
PushT demo运行
在主文件夹下,创建data文件夹
mkdir data && cd data
下载PushT数据集
wget https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip
解压数据集
unzip pusht.zip && cd ..
下载PushT的模型配置文件
wget -O image_pusht_diffusion_policy_cnn.yaml https://diffusion-policy.cs.columbia.edu/data/experiments/image/pusht/diffusion_policy_cnn/config.yaml
运行评估指令
-
设置环境变量:
export HYDRA_FULL_ERROR=1
-
运行指令:
python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output
-
产生报错:
ImportError: Error loading 'diffusion_policy.workspace.train_diffusion_unet_lowdim_workspace.TrainDiffusionUnetLowdimWorkspace':
ImportError("cannot import name 'cached_download' from 'huggingface_hub' (/home/kyle/anaconda3/envs/robodiff/lib/python3.9/site-packages/huggingface_hub/__init__.py)")
- 解决方案
pip install huggingface_hub==0.25.2
- 再次运行
python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
Output path data/pusht_eval_output already exists! Overwrite? [y/N]: y
pygame 2.1.2 (SDL 2.0.16, Python 3.9.18)
Hello from the pygame community. https://www.pygame.org/contribute.html
Eval PushtKeypointsRunner 1/1: 16%|█████ | 48/300 [00:16<01:26, 2.93it/s]
运行结果