【mmdetection实践】(二)训练自己的网络

Train

mmdetection支持多卡训练,有两种模式,分别是distributed模式和非distributed模式。官方推荐使用distributed模式。

distributed模式

那我们先讲一下distributed模式,mmdetection是使用tools/dist_train.sh来实现。其使用方法是如下:

./tools/dist_train.sh ${
   CONFIG_FILE} ${
   GPU_NUM} [optional arguments]
# CONFIG_FILE是指模型的参数文件,例如: ./configs/faster_rcnn_r50_fpn_1x.py
# GPU_NUM是指使用GPU个数
# optional arguments其中可以使用的有:“--validate”,这个表示在trian的过程中使用val数据集进行验证

其中--validate的默认值是1个epoch进行一次validation,如果需要修改,在模型参数文件中加入如下

# 例如在./configs/faster_rcnn_r50_fpn_1x.py加入如下
evaluation = dict(interval=1)

打开dist_train.sh文件,可以看到其实还是调用tools/train.py

但由于我在电脑上跑dist_train.sh总是卡住,也不知道原因,所以我就尝试了非distributed的模式。

非distributed模式

非distributed模式就直接调用tools/train.py就可以,调用格式如下:

python tools/train.py ${
   CONFIG_FILE}
# CONFIG_FILE是指模型的参数文件,例如: ./configs/faster_rcnn_r50_fpn_1x.py

需要注意的有如下:

  • 在官方推荐中,这个训练方式是单卡训练
  • 在官方文档中,这个训练方式没有像dist_train.sh的optional arguments

train过程的追寻

这个解释过程是要从tools/train.py → \rightarrow mmdet/apis/train.py → \rightarrow mmcv/runner/runner.py

tools/train.py

想要了解训练过程,就需要仔细地看一下train.py中的内容

# tools/train.py
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    # update configs according to CLI args
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    cfg.gpus = args.gpus
    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # init logger before other steps
    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed training: {}'.format(distributed))

    # set random seeds
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)
    
    #构建模型,其中cfg.model包含着模型的各种参数,加载自CONFIG_FILE
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    
    #构建训练数据集
    datasets = [build_dataset(cfg.data.train
  • 12
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值