该存储库提供了RIPGeo框架的原始PyTorch实现。
一、基础用法
1、需求 Requirements
代码使用python 3.8.13、PyTorch 1.12.1、cudatoolkit 11.6.0和cudnn 7.6.5进行了测试。通过Anaconda安装依赖:
# create virtual environment
conda create --name RIPGeo python=3.8.13# activate environment
conda activate RIPGeo# install pytorch & cudatoolkit
conda install pytorch torchvision torchaudio cudatoolkit=11.6 -c pytorch -c conda-forge
# install other requirements
conda install numpy pandas
pip install scikit-learn
2、运行代码
# Open the "RIPGeo" folder
cd RIPGeo# data preprocess (executing IP clustering).
python preprocess.py --dataset "New_York"
python preprocess.py --dataset "Los_Angeles"
python preprocess.py --dataset "Shanghai"# run the model RIPGeo
python main.py --dataset "New_York" --dim_in 30 --lr 2e-3 --saved_epoch 100
python main.py --dataset "Los_Angeles" --dim_in 30 --lr 2e-3 --saved_epoch 100
python main.py --dataset "Shanghai" --dim_in 51 --lr 1e-3 --saved_epoch 70# load the checkpoint and then test
python test.py --dataset "New_York" --dim_in 30 --lr 2e-3 --load_epoch 100
python test.py --dataset "Los_Angeles" --dim_in 30 --lr 2e-3 --load_epoch 100
python test.py --dataset "Shanghai" --dim_in 51 --lr 1e-3 --load_epoch 70
二、main.py中使用的超参数描述
超参数 |
---|