MMdetection之train.py源码详解

本文详细介绍了如何使用MMDetection框架训练和优化一个目标检测模型,包括从命令行参数解析配置,构建模型(如Faster R-CNN)和数据集,以及核心的训练流程。通过train.py脚本,模型的训练过程涉及了构建detector、dataset和train_detector函数,这些函数在源码中被详细解释。此外,文章还讨论了设置随机种子、日志记录和分布式训练的相关细节。

目录

一、tools/train.py

二、源码详解

三、核心函数详解

(一)build_detector(mmdet/models/builder.py)

(二) build_dataset(mmdet/datasets/builder)

(三) train_detector(mmdet/apis/train.py)

(四)set_random_seed:

(五)get_root_logger:


一、tools/train.py

可选参数:

# =========== optional arguments ===========
# --work-dir        存储日志和模型的目录
# --resume-from     加载 checkpoint 的目录
# --no-validate     是否在训练的时候进行验证
# 互斥组:
#   --gpus          使用的 GPU 数量
#   --gpu_ids       使用指定 GPU 的 id
# --seed            随机数种子
# --deterministic   是否设置 cudnn 为确定性行为
# --options         其他参数
# --launcher        分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
#                   none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
# --local_rank      本地进程编号,此参数 torch.distributed.launch 会自动传入。

对于 tools/train.py 其主要的流程如下

        对于 train.py 来说,首先从命令行和配置文件读取配置,然后分别用 build_detector、build_dataset 构建模型和数据集,最后将模型和数据集传入 train_detector 进行训练。

(一)从命令行和配置文件获取参数配置

cfg = Config.fromfile(args.config)  

(二)构建模型

# 构建模型: 需要传入 cfg.model,cfg.train_cfg,cfg.test_cfg
model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

(三)构建数据集

# 构建数据集: 需要传入 cfg.data.train,表明是训练集
datasets = [build_dataset(cfg.data.train)]

(四)训练模型

# 训练检测器:需要传入模型、数据集、配置参数等
train_detector(
    model,
    datasets,
    cfg,
    distributed=distributed,
    validate=(not args.no_validate),
    timestamp=timestamp,
    meta=meta)

二、源码详解

import argparse
import copy
import os
import os.path as osp
import time

import mmcv
import torch
# Config 用于读取配置文件, DictAction 将命令行字典类型参数转化为 key-value 形式
from mmcv import Config, DictAction
from mmcv.runner import init_dist

from mmdet import __version__
from mmdet.apis import set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger


# python tools/train.py ${CONFIG_FILE} [optional arguments]

# =========== optional arguments ===========
# --work-dir        存储日志和模型的目录
# --resume-from     加载 checkpoint 的目录
# --no-validate     是否在训练的时候进行验证
# 互斥组:
#   --gpus          使用的 GPU 数量
#   --gpu_ids       使用指定 GPU 的 id
# --seed            随机数种子
# --deterministic   是否设置 cudnn 为确定性行为
# --options         其他参数
# --launcher        分布式训练使用的启动器,可以为:['none', 'pytorch', 'slurm', 'mpi']
#                   none:不启动分布式训练,dist_train.sh 中默认使用 pytorch 启动。
# --local_rank      本地进程编号,此参数 torch.distributed.launch 会自动传入。

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')

    # action: store (默认, 表示保存参数)
    # action: store_true, store_false (如果指定参数, 则为 True, False)
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')

    # --------- 创建一个互斥组. argparse 将会确保互斥组中的参数只能出现一个 ---------
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')

    # 可以使用 python train.py --gpu-ids 0 1 2 3 指定使用的 GPU id
    # 参数结果:[0, 1, 2, 3]
    # nargs = '*':参数个数可以设置0个或n个
    # nargs = '+':参数个数可以设置1个或n个
    # nargs = '?':参数个数可以设置0个或1个
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    # ------------------------------------------------------------------------

    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')

    # 其他参数: 可以使用 --options a=1,2,3 指定其他参数
    # 参数结果: {'a': [1, 2, 3]}
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help='arguments in dict')

    # 如果使用 dist_utils.sh 进行分布式训练, launcher 默认为 pytorch
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')

    # 本地进程编号,此参数 torch.distributed.launch 会自动传入。
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # 如果环境中没有 LOCAL_RANK,就设置它为当前的 local_rank
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args


def main():
    args = parse_args()

    # 从文件读取配置
    cfg = Config.fromfile(args.config)
    # 从命令行读取额外的配置
    if args.options is not None:
        cfg.merge_from_dict(args.options)

    # 设置 cudnn_benchmark = True 可以加速输入大小固定的模型. 如:SSD300
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir 的优先程度为: 命令行 > 配置文件
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    # 当 work_dir 为 None 的时候, 使用 ./work_dir/配置文件名 作为默认工作目录
    elif cfg.get('work_dir', None) is None:
        # os.path.basename(path)    返回文件名
        # os.path.splitext(path)    分割路径, 返回路径名和文件扩展名的元组
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    # 是否继续上次的训练
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    # gpu id
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # 如果 launcher 为 none,不启用分布式训练。不使用 dist_train.sh 默认参数为 none.
    if args.launcher == 'none':
        distributed = False
    # launcher 不为 none,启用分布式训练。使用 dist_train.sh,会传 ‘pytorch’
    else:
        distributed = True
        # 初始化 dist 里面会调用 ini
mmdetection是一个目标检测框架,train.py是用于训练模型的脚本。在运行mmdetectiontrain.py之前,需要先安装好mmdetection框架及其依赖环境。 train.py的运行需要通过命令行参数来指定模型配置文件和训练数据集文件。首先,需要确定好所使用的模型配置文件,该文件用于指定模型的结构、超参数等信息。其次,需要准备好训练数据集文件,包括训练图片、标注文件等。 在运行train.py之前,可以先配置一些训练参数,如学习率、训练轮数、批量大小等。这些参数可以在命令行中通过设置参数值来指定,也可以直接在train.py脚本中进行修改。可以根据实际需要调整这些参数的数值。 运行train.py的命令形式如下: ``` python train.py ${CONFIG_FILE} [--work-dir ${WORK_DIR}] ``` 其中,`${CONFIG_FILE}`是模型配置文件的路径,`--work-dir ${WORK_DIR}`是可选参数,指定训练结果的保存路径。如果未指定`--work-dir`参数,则默认保存在当前路径下的`work_dirs`目录中。 train.py的运行过程主要分为以下几个步骤:加载配置文件、构建模型、加载训练数据、定义优化器、定义学习率策略、开始训练。在训练过程中,会按照一定的周期迭代进行训练,每个周期结束时会进行验证,并根据验证结果保存最优模型。 训练完成后,可以使用训练好的模型进行目标检测任务。mmdetection还提供了其他功能,如测试模型、评估模型等,可以根据具体需求选择相应的方法进行。
评论 6
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值