放几个大佬分享:
Argoverse数据集API安装和HiVT代码调试踩坑及排坑记录
HiVT部署
https://github.com/ZikangZhou/HiVT
1. torch版本
python=3.8/3.9时,如果装了2.0.1版本的torch+cu118,引用torch.nn里面的tranformer模块时,在前向转播的过程中可能会报错
File "/home/xx/anaconda3/envs/argo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl return forward_call(*args, **kwargs) TypeError: forward() got an unexpected keyword argument 'is_causal'
建议安装低一点版本的torch
conda create -n HiVT python=3.8
conda activate HiVT
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install torch-geometric==1.7.2
pip install pytorch-lightning==1.5.2
pip install torchmetrics==0.8.2
这里不是pytorch-geometric,要去安装torch-gemetric对应的四个包:具体方法见链接:
图神经网络GNN所需PyTorch Geometric库的安装
Argoverse API
- 进入 setup.py , 将 sklearn 修改为 scikit-learn 再 pip install -e .进行安装
- 将 numpy ==1.19.0 修改为其他版本,如 numpy ==1.24.3 再 pip install -e .进行安装。