文件路径:mmsegmentation\docs\zh_cn\user_guides\2_dataset_prepare.md
一、训练
1、tools/train.py 文件提供了在单GPU上部署训练任务的方法。
基础用法如下:
python tools/train.py ${配置文件} [可选参数]
--work-dir ${工作路径}
: 重新指定工作路径--amp
: 使用自动混合精度计算--resume
: 从工作路径中保存的最新检查点文件(checkpoint)恢复训练--cfg-options ${需更覆盖的配置}
: 覆盖已载入的配置中的部分设置,并且 以 xxx=yyy 格式的键值对 将被合并到配置文件中。 比如: '--cfg-option model.encoder.in_channels=6'
注意: 命令行参数 --resume 和在配置文件中的参数 load_from 的不同之处:
--resume 只决定是否继续使用工作路径中最新的检查点,它常常用于恢复被意外打断的训练。
load_from 会明确指定被载入的检查点文件,且训练迭代器将从0开始,通常用于微调模型。
如果要从指定的检查点上恢复训练可以用以下命令:
python tools/train.py ${配置文件} --resume --cfg-options load_from=${检查点}
2、train.py代码解析
# Copyright (c) OpenMMLab. All rights reserved. import argparse import logging import os import os.path as osp from mmengine.config import Config, DictAction from mmengine.logging import print_log from mmengine.runner import Runner from mmseg.registry import RUNNERS # 命令行参数解析 def parse_args(): parser = argparse.ArgumentParser(description='Train a segmentor') parser.add_argument('config', help='train config file path') # 必须提供的训练配置文件路径 parser.add_argument('--work-dir', help='the dir to save logs and models') # 可选参数,指定保存日志和模型的目录 parser.add_argument( '--resume', action='store_true', default=False, help='resume from the latest checkpoint in the work_dir automatically') # 是否从最新的检查点恢复训练(默认为False) parser.add_argument( '--amp', action='store_true', default=False, help='enable automatic-mixed-precision training') # 是否启用自动混合精度训练(默认为False) # amp允许在浮点运算中混合使用不同精度的数值以提高训练速度并减少显存占用。 parser.add_argument( '--cfg-options', nargs='+', action=DictAction, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file. If the value to ' 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') # 允许在命令行中覆盖配置文件中的某些设置,以键值对的形式传递。 # 如: --cfg-options keys=${values} parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') # 选择任务启动器,支持 none、pytorch、slurm 和 mpi 等分布式训练方式。 # 当使用PyTorch版本>=2.0.0时,“torch.distributed.launch”将把“--local rank”参数传递给“tools/train.py”,而不是“--local_rank”。 parser.add_argument('--local_rank', '--local-rank', type=int, default=0) # 在分布式训练中使用,默认值为 0 args = parser.parse_args() # 使用 argparse 解析命令行参数,并将结果存储在 args 中 if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) return args# 主函数 def main(): args = parse_args() # args变量将包含所有用户通过命令行传递的参数 # load config cfg = Config.fromfile(args.config) # 读取训练所需的配置文件 cfg.launcher = args.launcher if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) # 设置工作目录优先级,命令行参数work_dir>配置文件work_dir>默认使用配置文件名创建目录 if args.work_dir is not None: # update configs according to CLI args if args.work_dir is not None cfg.work_dir = args.work_dir elif cfg.get('work_dir', None) is None: # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) # 实现自动混合精度训练 if args.amp is True: optim_wrapper = cfg.optim_wrapper.type if optim_wrapper == 'AmpOptimWrapper': print_log( 'AMP training is already enabled in your config.', logger='current', level=logging.WARNING) else: assert optim_wrapper == 'OptimWrapper', ( '`--amp` is only supported when the optimizer wrapper type is ' f'`OptimWrapper` but got {optim_wrapper}.') cfg.optim_wrapper.type = 'AmpOptimWrapper' cfg.optim_wrapper.loss_scale = 'dynamic' # 恢复训练,根据 --resume 参数来决定是否从最新检查点恢复训练 cfg.resume = args.resume # 根据配置文件构建runner,继承OpenMMLab框架中的训练逻辑 if 'runner_type' not in cfg: # 默认runner # build the default runner runner = Runner.from_cfg(cfg) else: # build customized runner from the registry # if 'runner_type' is set in the cfg runner = RUNNERS.build(cfg) # 开始训练 runner.train() # 训练过程包含:数据加载、模型初始化、优化器和调度器的设置、前向传播(损失计算)反向传播、日志记录与检查点保存、分布式训练支持 if __name__ == '__main__': main()
这里需要注意的是,原始的训练流程是默认的,如果需要按照自己的流程对模型进行训练,就需要使用build_runner()函数自定义构建训练器。
build_runner()的参数是从命令行参数解析函数parse_args()中读取得到的,所以目前我认为应该是在配置文件中完善训练流程,具体我会在后续笔记中进行补充。
除了runner可以自己构建,修改mmseg文件夹下的基本模型能够修改网络本身,修改configs文件夹下的配置文件能够修改参数(数据集设置与加载、训练策略、网络参数)。
二、测试
1、tools/test.py 文件提供了在单 GPU 上启动测试任务的方法。
基础用法如下:
python tools/test.py ${配置文件} ${模型权重文件} [可选参数]
--work-dir
: 如果指定了路径,结果会保存在该路径下。如果没有指定则会保存在work_dirs/{配置文件名}
路径下.--show
: 当--show-dir
没有指定时,可以使用该参数,在程序运行过程中显示预测结果。--show-dir
: 绘制了分割掩膜图片的存储文件夹。如果指定了该参数,则可视化的分割掩膜将被保存到work_dir/timestamp/{指定路径}
.--wait-time
: 多次可视化结果的时间间隔。当--show
为激活状态时发挥作用。默认为2。--cfg-options
: 如果被具体指定,以 xxx=yyy 形式的键值对将被合并入配置文件中。
2、test.py代码解析
# Copyright (c) OpenMMLab. All rights reserved. import argparse import os import os.path as osp from mmengine.config import Config, DictAction from mmengine.runner import Runner # TODO: support fuse_conv_bn, visualization, and format_only # 解析命令行参数 def parse_args(): parser = argparse.ArgumentParser( description='MMSeg test (and eval) a model') parser.add_argument('config', help='train config file path') # 必需,指定路径 parser.add_argument('checkpoint', help='checkpoint file') # 必需,指定模型的检查点文件(训练中保存的模型权重) parser.add_argument( '--work-dir', help=('if specified, the evaluation metric results will be dumped' 'into the directory as json')) # 可选,指定保存评估结果(如指标)的目录 parser.add_argument( '--out', type=str, help='The directory to save output prediction for offline evaluation') # 可选,指定用于保存输出预测结果的目录,便于离线评估 parser.add_argument( '--show', action='store_true', help='show prediction results') # 可选,是否显示预测结果 parser.add_argument( '--show-dir', help='directory where painted images will be saved. ' 'If specified, it will be automatically saved ' 'to the work_dir/timestamp/show_dir') # 可选,指定保存绘制图像的目录 parser.add_argument( '--wait-time', type=float, default=2, help='the interval of show (s)') # 可选,指定显示图像的时间间隔 parser.add_argument( '--cfg-options', nargs='+', action=DictAction, help='override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file. If the value to ' 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') # 可选,允许在命令行中覆盖配置文件中的部分设置 parser.add_argument( '--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') parser.add_argument( '--tta', action='store_true', help='Test time augmentation') # 当使用PyTorch版本>=2.0.0时,“torch.distributed.launch”将把“--local rank”参数传递给“tools/train.py”,而不是“--local_rank”。 parser.add_argument('--local_rank', '--local-rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) return args# 启用可视化hook # 检查配置文件中的visualization hook,并根据命令行参数show和show-dir控制可视化行为 def trigger_visualization_hook(cfg, args): default_hooks = cfg.default_hooks if 'visualization' in default_hooks: visualization_hook = default_hooks['visualization'] # Turn on visualization visualization_hook['draw'] = True if args.show: # show显示预测结果 visualization_hook['show'] = True visualization_hook['wait_time'] = args.wait_time if args.show_dir: # show-dir将可视化结果保存到指定目录 visualizer = cfg.visualizer visualizer['save_dir'] = args.show_dir else: raise RuntimeError( 'VisualizationHook must be included in default_hooks.' 'refer to usage ' '"visualization=dict(type=\'VisualizationHook\')"') return cfgdef main(): args = parse_args() # 解析命令行参数 # load config cfg = Config.fromfile(args.config) cfg.launcher = args.launcher if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) # work_dir is determined in this priority: CLI > segment in file > filename if args.work_dir is not None: # update configs according to CLI args if args.work_dir is not None cfg.work_dir = args.work_dir elif cfg.get('work_dir', None) is None: # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) cfg.load_from = args.checkpoint # 处理可视化,调用trigger_visualization_hook() if args.show or args.show_dir: cfg = trigger_visualization_hook(cfg, args) # 测试时增强(tta)如果指定了--tta,配置文件中的数据处理管道会被替换为tta_pipeline,并将模型配置为TTA模式 # 是一种在推理时用多种数据增强方法,以提高在测试集上的泛化能力 if args.tta: cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline cfg.tta_model.module = cfg.model cfg.model = cfg.tta_model # 如果指定了--out,将预测结果保存到指定目录 if args.out is not None: cfg.test_evaluator['output_dir'] = args.out cfg.test_evaluator['keep_results'] = True # build the runner from config runner = Runner.from_cfg(cfg) # start testing runner.test() # runner.test() 是启动测试的关键,它会加载模型、数据集并执行推理,然后根据指定的评估方式输出结果 if __name__ == '__main__': main()