1. 数据准备
下载urbanlanegraph数据集,并进行解析
python data_preparation/generate_raw_sample.py /path/to/dataset /path/to/raw/output miami train
# miami为城市名称,共6个城市,名称可根据需求更改,为paloalto,pittsburgh,austin,washington,detroit
# tain为训练集,可以为eval,官方数据中没有提供test
2. 训练过程,分3个模型
2.1 训练context模型
python methods/train_centerline_regression.py --dataset_root /path/to/raw/output/dataset --sdf_version centerlines_sdf_context
# 3 channels image 最后生成context_checkpoints.pth
2.2 训练ego_context模型
在2.1训练后训练该模型
python methods/train_centerline_regression.py --dataset_root /path/to/raw/output/ --sdf_version centerlines_sdf_ego_context --checkpoint_pth_context_regression /path/to/context_regression/context_checkpoints.pth
# 4 channels image 最后生成ego_context_checkpoint.pth
2.3 lanegnn训练
2.3.1 准备processed特征数据
python data_preparation/generate_pth_samples.py --config methods/lanegnn/config/config.yaml --raw_dataset /path/to/raw/output --processed_dataset /path/to/preocessed/output/ --context_regressor_ckpt /path/to/context_regression/context_checkpoints.pth --ego_regressor_ckpt /path/to/ego_context_regression/ego_context_checkpoint.pth
# 调用2.1和2.2生成的模型,以及1的raw数据,生成processed数据,供lanegnn使用。数据格式可直接生成./train/processed/ ./eval/processed/
2.3.2 lanegnn模型训练
python methods/train_lanegnn.py --config methods/lanegnn/config/config.yaml --dataroot /path/to/processed/dataset/
3 测试过程
3.1 测试数据格式
测试数据包括所测试图像以及所测试图像的context图像,拥有上下文信息的图像,可以观察train/eval数据集中以rgb和rgb_context.png格式
举例:
miami__00*00*-rgb.png 为24位的256*256图像
miami***_00_00-rgb-context.png 为24位512512图像
3.2 测试
可根据data_preparation/genertate_pth_samples.py更改,生成特征数据,格式为pth
根据train_lanegnn.py中相关代码部分,写test代码