文章目录
Train
mmdetection支持多卡训练,有两种模式,分别是distributed模式和非distributed模式。官方推荐使用distributed模式。
distributed模式
那我们先讲一下distributed模式,mmdetection是使用tools/dist_train.sh来实现。其使用方法是如下:
./tools/dist_train.sh ${
CONFIG_FILE} ${
GPU_NUM} [optional arguments]
# CONFIG_FILE是指模型的参数文件,例如: ./configs/faster_rcnn_r50_fpn_1x.py
# GPU_NUM是指使用GPU个数
# optional arguments其中可以使用的有:“--validate”,这个表示在trian的过程中使用val数据集进行验证
其中--validate的默认值是1个epoch进行一次validation,如果需要修改,在模型参数文件中加入如下
# 例如在./configs/faster_rcnn_r50_fpn_1x.py加入如下
evaluation = dict(interval=1)
打开dist_train.sh文件,可以看到其实还是调用tools/train.py
但由于我在电脑上跑dist_train.sh总是卡住,也不知道原因,所以我就尝试了非distributed的模式。
非distributed模式
非distributed模式就直接调用tools/train.py就可以,调用格式如下:
python tools/train.py ${
CONFIG_FILE}
# CONFIG_FILE是指模型的参数文件,例如: ./configs/faster_rcnn_r50_fpn_1x.py
需要注意的有如下:
- 在官方推荐中,这个训练方式是单卡训练
- 在官方文档中,这个训练方式没有像dist_train.sh的optional arguments
train过程的追寻
这个解释过程是要从tools/train.py → \rightarrow → mmdet/apis/train.py → \rightarrow → mmcv/runner/runner.py
tools/train.py
想要了解训练过程,就需要仔细地看一下train.py中的内容
# tools/train.py
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# update configs according to CLI args
if args.work_dir is not None:
cfg.work_dir = args.work_dir
if args.resume_from is not None:
cfg.resume_from = args.resume_from
cfg.gpus = args.gpus
if args.autoscale_lr:
# apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
cfg.optimizer['lr'] = cfg.optimizer['lr'] * cfg.gpus / 8
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
distributed = False
else:
distributed = True
init_dist(args.launcher, **cfg.dist_params)
# init logger before other steps
logger = get_root_logger(cfg.log_level)
logger.info('Distributed training: {}'.format(distributed))
# set random seeds
if args.seed is not None:
logger.info('Set random seed to {}'.format(args.seed))
set_random_seed(args.seed)
#构建模型,其中cfg.model包含着模型的各种参数,加载自CONFIG_FILE
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
#构建训练数据集
datasets = [build_dataset(cfg.data.train