mmsegmentation基础使用学习笔记(三)——使用现有模型在单GPU上进行训练和测试&代码解析

文件路径:mmsegmentation\docs\zh_cn\user_guides\2_dataset_prepare.md 

官方教程:mmsegmentation/docs/zh_cn/user_guides/4_train_test.md at dev-1.x · open-mmlab/mmsegmentation · GitHub

一、训练

1、tools/train.py 文件提供了在单GPU上部署训练任务的方法。

基础用法如下:

python tools/train.py  ${配置文件} [可选参数]

  • --work-dir ${工作路径}: 重新指定工作路径
  • --amp: 使用自动混合精度计算
  • --resume: 从工作路径中保存的最新检查点文件(checkpoint)恢复训练
  • --cfg-options ${需更覆盖的配置}: 覆盖已载入的配置中的部分设置,并且 以 xxx=yyy 格式的键值对 将被合并到配置文件中。 比如: '--cfg-option model.encoder.in_channels=6'

注意: 命令行参数 --resume 和在配置文件中的参数 load_from 的不同之处:

--resume 只决定是否继续使用工作路径中最新的检查点,它常常用于恢复被意外打断的训练。

load_from 会明确指定被载入的检查点文件,且训练迭代器将从0开始,通常用于微调模型。

如果要从指定的检查点上恢复训练可以用以下命令:

python tools/train.py ${配置文件} --resume --cfg-options load_from=${检查点}

2、train.py代码解析

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.runner import Runner
from mmseg.registry import RUNNERS

# 命令行参数解析
def parse_args():
    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('config', help='train config file path') 
    # 必须提供的训练配置文件路径
    parser.add_argument('--work-dir', help='the dir to save logs and models') 
    # 可选参数,指定保存日志和模型的目录
    parser.add_argument(
        '--resume',
        action='store_true',
        default=False,
        help='resume from the latest checkpoint in the work_dir automatically')
    # 是否从最新的检查点恢复训练(默认为False)
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')
    # 是否启用自动混合精度训练(默认为False)
    # amp允许在浮点运算中混合使用不同精度的数值以提高训练速度并减少显存占用。
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # 允许在命令行中覆盖配置文件中的某些设置,以键值对的形式传递。
    # 如: --cfg-options keys=${values}
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # 选择任务启动器,支持 none、pytorch、slurm 和 mpi 等分布式训练方式。
    # 当使用PyTorch版本>=2.0.0时,“torch.distributed.launch”将把“--local rank”参数传递给“tools/train.py”,而不是“--local_rank”。
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    # 在分布式训练中使用,默认值为 0
    args = parser.parse_args()
    # 使用 argparse 解析命令行参数,并将结果存储在 args 中
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
# 主函数
def main():
    args = parse_args() # args变量将包含所有用户通过命令行传递的参数

    # load config
    cfg = Config.fromfile(args.config) # 读取训练所需的配置文件
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # 设置工作目录优先级,命令行参数work_dir>配置文件work_dir>默认使用配置文件名创建目录
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    # 实现自动混合精度训练
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.type
        if optim_wrapper == 'AmpOptimWrapper':
            print_log(
                'AMP training is already enabled in your config.',
                logger='current',
                level=logging.WARNING)
        else:
            assert optim_wrapper == 'OptimWrapper', (
                '`--amp` is only supported when the optimizer wrapper type is '
                f'`OptimWrapper` but got {optim_wrapper}.')
            cfg.optim_wrapper.type = 'AmpOptimWrapper'
            cfg.optim_wrapper.loss_scale = 'dynamic'

    # 恢复训练,根据 --resume 参数来决定是否从最新检查点恢复训练
    cfg.resume = args.resume

    # 根据配置文件构建runner,继承OpenMMLab框架中的训练逻辑
    if 'runner_type' not in cfg: # 默认runner
        # build the default runner
        runner = Runner.from_cfg(cfg)
    else:
        # build customized runner from the registry
        # if 'runner_type' is set in the cfg
        runner = RUNNERS.build(cfg)

    # 开始训练
    runner.train()
    # 训练过程包含:数据加载、模型初始化、优化器和调度器的设置、前向传播(损失计算)反向传播、日志记录与检查点保存、分布式训练支持


if __name__ == '__main__':
    main()

这里需要注意的是,原始的训练流程是默认的,如果需要按照自己的流程对模型进行训练,就需要使用build_runner()函数自定义构建训练器。

build_runner()的参数是从命令行参数解析函数parse_args()中读取得到的,所以目前我认为应该是在配置文件中完善训练流程,具体我会在后续笔记中进行补充。

除了runner可以自己构建,修改mmseg文件夹下的基本模型能够修改网络本身,修改configs文件夹下的配置文件能够修改参数(数据集设置与加载、训练策略、网络参数)。

二、测试

1、tools/test.py 文件提供了在单 GPU 上启动测试任务的方法。

基础用法如下:

python tools/test.py ${配置文件} ${模型权重文件} [可选参数]

  • --work-dir: 如果指定了路径,结果会保存在该路径下。如果没有指定则会保存在 work_dirs/{配置文件名} 路径下.
  • --show: 当 --show-dir 没有指定时,可以使用该参数,在程序运行过程中显示预测结果。
  • --show-dir: 绘制了分割掩膜图片的存储文件夹。如果指定了该参数,则可视化的分割掩膜将被保存到 work_dir/timestamp/{指定路径}.
  • --wait-time: 多次可视化结果的时间间隔。当 --show 为激活状态时发挥作用。默认为2。
  • --cfg-options: 如果被具体指定,以 xxx=yyy 形式的键值对将被合并入配置文件中。

2、test.py代码解析

 

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.runner import Runner


# TODO: support fuse_conv_bn, visualization, and format_only
# 解析命令行参数
def parse_args():
    parser = argparse.ArgumentParser(
        description='MMSeg test (and eval) a model')
    parser.add_argument('config', help='train config file path') # 必需,指定路径
    parser.add_argument('checkpoint', help='checkpoint file') 
    # 必需,指定模型的检查点文件(训练中保存的模型权重)
    parser.add_argument(
        '--work-dir',
        help=('if specified, the evaluation metric results will be dumped'
              'into the directory as json'))
    # 可选,指定保存评估结果(如指标)的目录
    parser.add_argument(
        '--out',
        type=str,
        help='The directory to save output prediction for offline evaluation')
    # 可选,指定用于保存输出预测结果的目录,便于离线评估
    parser.add_argument(
        '--show', action='store_true', help='show prediction results')
    # 可选,是否显示预测结果
    parser.add_argument(
        '--show-dir',
        help='directory where painted images will be saved. '
        'If specified, it will be automatically saved '
        'to the work_dir/timestamp/show_dir')
    # 可选,指定保存绘制图像的目录
    parser.add_argument(
        '--wait-time', type=float, default=2, help='the interval of show (s)')
    # 可选,指定显示图像的时间间隔
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    # 可选,允许在命令行中覆盖配置文件中的部分设置
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--tta', action='store_true', help='Test time augmentation')
    # 当使用PyTorch版本>=2.0.0时,“torch.distributed.launch”将把“--local rank”参数传递给“tools/train.py”,而不是“--local_rank”。
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
# 启用可视化hook
# 检查配置文件中的visualization hook,并根据命令行参数show和show-dir控制可视化行为
def trigger_visualization_hook(cfg, args):
    default_hooks = cfg.default_hooks
    if 'visualization' in default_hooks:
        visualization_hook = default_hooks['visualization']
        # Turn on visualization
        visualization_hook['draw'] = True
        if args.show: # show显示预测结果
            visualization_hook['show'] = True
            visualization_hook['wait_time'] = args.wait_time
        if args.show_dir: # show-dir将可视化结果保存到指定目录
            visualizer = cfg.visualizer
            visualizer['save_dir'] = args.show_dir
    else:
        raise RuntimeError(
            'VisualizationHook must be included in default_hooks.'
            'refer to usage '
            '"visualization=dict(type=\'VisualizationHook\')"')

    return cfg
def main():
    args = parse_args() # 解析命令行参数

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    cfg.load_from = args.checkpoint

    # 处理可视化,调用trigger_visualization_hook()
    if args.show or args.show_dir:
        cfg = trigger_visualization_hook(cfg, args)

    # 测试时增强(tta)如果指定了--tta,配置文件中的数据处理管道会被替换为tta_pipeline,并将模型配置为TTA模式
    # 是一种在推理时用多种数据增强方法,以提高在测试集上的泛化能力
    if args.tta:
        cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
        cfg.tta_model.module = cfg.model
        cfg.model = cfg.tta_model

    # 如果指定了--out,将预测结果保存到指定目录
    if args.out is not None:
        cfg.test_evaluator['output_dir'] = args.out
        cfg.test_evaluator['keep_results'] = True

    # build the runner from config
    runner = Runner.from_cfg(cfg)

    # start testing
    runner.test()
    # runner.test() 是启动测试的关键,它会加载模型、数据集并执行推理,然后根据指定的评估方式输出结果


if __name__ == '__main__':
    main()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

还会长的桔子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值