polylane 复现

源代码:https://github.com/lucastabelini/PolyLaneNet

1.数据集准备

参考Ultra-Fast-Lane-Detection 复现

2.环境搭建

conda create -n lane python=3.7
conda activate polylane  # 激活环境
pip install -r requirements.txt

3.修改配置文件

cfgs 下的配置文件,tusimple.yaml,修改数据集和测试集的路径(即下面中的root)

# Training settings
exps_dir: 'experiments'
iter_log_interval: 1
iter_time_window: 100
model_save_interval: 1
seed: 1
backup:
model:
  name: PolyRegression
  parameters:
    num_outputs: 35 # (5 lanes) * (1 conf + 2 (upper & lower) + 4 poly coeffs)
    pretrained: true
    backbone: 'efficientnet-b0'
    pred_category: false
    curriculum_steps: [0, 0, 0, 0]
loss_parameters:
  conf_weight: 1
  lower_weight: 1
  upper_weight: 1
  cls_weight: 0
  poly_weight: 300
batch_size: 16
epochs: 2695
optimizer:
  name: Adam
  parameters:
    lr: 3.0e-4
lr_scheduler:
  name: CosineAnnealingLR
  parameters:
    T_max: 385

# Testing settings
test_parameters:
  conf_threshold: 0.5

# Dataset settings
datasets:
  train:
    type: LaneDataset
    parameters:
      dataset: tusimple
      split: train
      img_size: [360, 640]
      normalize: true
      aug_chance: 0.9090909090909091 # 10/11
      augmentations:
       - name: Affine
         parameters:
           rotate: !!python/tuple [-10, 10]
       - name: HorizontalFlip
         parameters:
           p: 0.5
       - name: CropToFixedSize
         parameters:
           width: 1152
           height: 648
      root:  "/home/***/code/PolyLaneNet-master/tusimple" 

  test: &test
    type: LaneDataset
    parameters:
      dataset: tusimple
      split: test          #val
      max_lanes: 5
      img_size: [360, 640]
      root: "/home/***/code/PolyLaneNet-master/tusimple" 
      normalize: true
      augmentations: []

  # val = test
  val:
    <<: *test

要将test中的split由val改为test

4.训练与测试

#训练
python3 train.py --exp_name tusimple --cfg cfgs/tusimple.yaml
#测试
python3 test.py --exp_name tusimple --cfg cfgs/tusimple.yaml --epoch 2695  # 可以直接github里下载提供的模型进行测试

5.查看模型的计算量与参数量

在test.py中的第25行(即model.eval())添加以下代码

    from thop import profile, clever_format
    x = torch.zeros((1, 3, 360, 640)).to(device) + 1
    macs, params = profile(model, inputs=(x,))
    macs, params = clever_format([macs, params], "%.3f")
    print('MACs: {}'.format(macs))
    print('Params: {}'.format(params))
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值