【HiVT】HiVT轨迹预测代码环境配置及训练

本文介绍了如何从GitHub克隆HiVT项目,创建conda环境,下载并处理Argoverse1.1预测数据集,安装ArgoverseAPI,以及解决训练过程中的文件读取限制问题。详细步骤包括设置环境、数据准备和常见错误处理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

0.简介

github项目链接 论文链接
在这里插入图片描述
Argoverse 1.1验证集的预期性能是:
Models minADE minFDE MR
HiVT-64 0.69 1.03 0.10
HiVT-128 0.66 0.97 0.09

在这里插入图片描述

1. 拉取代码仓库

git clone https://github.com/ZikangZhou/HiVT.git
cd HiVT

2. 创建conda环境

conda create -n HiVT python=3.8
conda activate HiVT
conda install pytorch==1.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
conda install pytorch-geometric==1.7.2 -c rusty1s -c conda-forge
conda install pytorch-lightning==1.5.2 -c conda-forge

3. 下载Argoverse预测数据集

Argoverse Motion Forecasting Dataset v1.1
下载后解压成下面的形式

Argoverse1.1
├── train/
|   └── data/
|       ├── 1.csv
|       ├── 2.csv
|       ├── ...
└── val/
    └── data/
        ├── 1.csv
        ├── 2.csv

4. 安装Argoverse API

如果需要安装在HiVT的conda环境里,忽略下面连接中创建conda环境的操作
Ubuntu Argoverse API安装

5. 训练

如果数据集在home路径(~/Argoverse1.1/)下,执行

python train.py --root ~/Argoverse1.1/ --embed_dim 64  //To train HiVT-64
python train.py --root ~/Argoverse1.1/ --embed_dim 128 //To train HiVT-128

[Image]

6. 常见错误

6.1 RuntimeError(‘received %d items of ancdata’ %

训练过程中报该错误,是因为文件读取太多,修改文件读取限制

ulimit -n         //查看读取文件限制数量
ulimit -n 65536   //修改读取文件限制数量为65536
### Argoverse 数据集测试使用教程 #### 下载数据集 为了开始使用 Argoverse 数据集进行测试,首先需要下载相应的样本数据集。可以从官方网站获取 `Sample Datasets v1.1` 版本的数据集文件,但暂时不需要解压缩这些文件[^1]。 #### 环境配置 在准备运行基于 Argoverse 的实验之前,确保按照官方文档中的说明来设置开发环境。这通常涉及到创建虚拟环境并安装必要的依赖库。具体步骤可以参照 HiVT 项目提供的指南完成环境搭建工作[^4]。 #### 加载API接口 Argoverse 提供了一个 Python API 来帮助研究人员轻松访问和处理其发布的各种类型的数据集。通过研究 argoverse-api 项目的目录结构及其功能描述,能够更好地理解如何利用该工具包读取、解析以及可视化轨迹预测任务所需的信息[^3]。 #### 执行测试案例 一旦完成了上述准备工作,则可以根据个人需求编写脚本来加载已下载的数据集,并调用适当的方法来进行分析或训练模型。对于想要尝试的目标检测与跟踪等功能模块,可参考 YOLO系列算法的应用实例作为入门指导[^2]。 ```python from argoverse.data_loading.argoverse_tracking_loader import ArgoverseTrackingLoader root_dir = "/path/to/downloaded/dataset" argo_track_loader = ArgoverseTrackingLoader(root_dir) for seq_id,seq_df in argo_track_loader: print(f"Processing sequence {seq_id}") # Add your code here to process each sequence dataframe (seq_df) ```
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

HIT_Vanni

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值