目录
(一)build_detector(mmdet/models/builder.py)
(二) build_dataset(mmdet/datasets/builder)
(三) train_detector(mmdet/apis/train.py)
一、tools/train.py
可选参数:
# =========== optional arguments ===========
# --work-dir 存储日志和模型的目录
# --resume-from 加载 checkpoint 的目录
# --no-validate 是否在训练的时候进行验证
# 互斥组:
# --gpus 使用的 GPU 数量
# --gpu_ids 使用指定 GPU 的 id
# --seed 随机数种子
# --deterministic 是否设置 cudnn 为确定性行为
# --options 其他参数
# --launcher 分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
# none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
# --local_rank 本地进程编号,此参数 torch.distributed.launch 会自动传入。
对于 tools/train.py 其主要的流程如下:
对于 train.py 来说,首先从命令行和配置文件读取配置,然后分别用 build_detector、build_dataset 构建模型和数据集,最后将模型和数据集传入 train_detector 进行训练。
(一)从命令行和配置文件获取参数配置
cfg = Config.fromfile(args.config)
(二)构建模型
# 构建模型: 需要传入 cfg.model,cfg.train_cfg,cfg.test_cfg
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
(三)构建数据集
# 构建数据集: 需要传入 cfg.data.train,表明是训练集
datasets = [build_dataset(cfg.data.train)]
(四)训练模型
# 训练检测器:需要传入模型、数据集、配置参数等
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=(not args.no_validate),
timestamp=timestamp,
meta=meta)
二、源码详解
import argparse
import copy
import os
import os.path as osp
import time
import mmcv
import torch
# Config 用于读取配置文件, DictAction 将命令行字典类型参数转化为 key-value 形式
from mmcv import Config, DictAction
from mmcv.runner import init_dist
from mmdet import __version__
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger
# python tools/train.py ${CONFIG_FILE} [optional arguments]
# =========== optional arguments ===========
# --work-dir 存储日志和模型的目录
# --resume-from 加载 checkpoint 的目录
# --no-validate 是否在训练的时候进行验证
# 互斥组:
# --gpus 使用的 GPU 数量
# --gpu_ids 使用指定 GPU 的 id
# --seed 随机数种子
# --deterministic 是否设置 cudnn 为确定性行为
# --options 其他参数
# --launcher 分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
# none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
# --local_rank 本地进程编号,此参数 torch.distributed.launch 会自动传入。
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume-from', help='the checkpoint file to resume from')
# action: store (默认, 表示保存参数)
# action: store_true, store_false (如果指定参数, 则为 True, False)
parser.add_argument(
'--no-validate',
action='store_true',
help='whether not to evaluate the checkpoint during training')
# --------- 创建一个互斥组. argparse 将会确保互斥组中的参数只能出现一个 ---------
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--gpus',
type=int,
help='number of gpus to use '
'(only applicable to non-distributed training)')
# 可以使用 python train.py --gpu-ids 0 1 2 3 指定使用的 GPU id
# 参数结果:[0, 1, 2, 3]
# nargs = '*':参数个数可以设置0个或n个
# nargs = '+':参数个数可以设置1个或n个
# nargs = '?':参数个数可以设置0个或1个
group_gpus.add_argument(
'--gpu-ids',
type=int,
nargs='+',
help='ids of gpus to use '
'(only applicable to non-distributed training)')
# ------------------------------------------------------------------------
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
# 其他参数: 可以使用 --options a=1,2,3 指定其他参数
# 参数结果: {'a': [1, 2, 3]}
parser.add_argument(
'--options', nargs='+', action=DictAction, help='arguments in dict')
# 如果使用 dist_utils.sh 进行分布式训练, launcher 默认为 pytorch
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
# 本地进程编号,此参数 torch.distributed.launch 会自动传入。
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
# 如果环境中没有 LOCAL_RANK,就设置它为当前的 local_rank
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def main():
args = parse_args()
# 从文件读取配置
cfg = Config.fromfile(args.config)
# 从命令行读取额外的配置
if args.options is not None:
cfg.merge_from_dict(args.options)
# 设置 cudnn_benchmark = True 可以加速输入大小固定的模型. 如:SSD300
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
# work_dir 的优先程度为: 命令行 > 配置文件
if args.work_dir is not None:
cfg.work_dir = args.work_dir
# 当 work_dir 为 None 的时候, 使用 ./work_dir/配置文件名 作为默认工作目录
elif cfg.get('work_dir', None) is None:
# os.path.basename(path) 返回文件名
# os.path.splitext(path) 分割路径, 返回路径名和文件扩展名的元组
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])
# 是否继续上次的训练
if args.resume_from is not None:
cfg.resume_from = args.resume_from
# gpu id
if args.gpu_ids is not None:
cfg.gpu_ids = args.gpu_ids
else:
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
# 如果 launcher 为 none,不启用分布式训练。不使用 dist_train.sh 默认参数为 none.
if args.launcher == 'none':
distributed = False
# launcher 不为 none,启用分布式训练。使用 dist_train.sh,会传 ‘pytorch’
else:
distributed = True
# 初始化 dist 里面会调用 ini

本文详细介绍了如何使用MMDetection框架训练和优化一个目标检测模型,包括从命令行参数解析配置,构建模型(如Faster R-CNN)和数据集,以及核心的训练流程。通过train.py脚本,模型的训练过程涉及了构建detector、dataset和train_detector函数,这些函数在源码中被详细解释。此外,文章还讨论了设置随机种子、日志记录和分布式训练的相关细节。
最低0.47元/天 解锁文章
3275





