PaddleDetection代码解析之训练部分解析

2021SC@SDUSC

以下是PaddleDetection训练部分的代码,这部分代码比较重要,相对来说也比较难,我在反复阅读后还是对他有所了解且在代码的关键部分和难理解的部分加上了备注:

train.py流程解析

从程序入口开始(if name == ‘main’:)

1.直接进入main函数

初始化训练参数:

  • ①.parser = ArgsParser() #读取命令行传递参数,加载yaml文件参数
  • ②.将参数整合在一起,检查参数配置是否正确
  • ③.是否使用GPU加速
  • ④.查看paddledet版本是否正确
  • ⑤.进入run()函数

配置阶段

  • a.系统变量配置、初始化、得到使用GPU数量等
  • b.创建数据读取类
  • c.创建网络结构类
  • d.创建学习率类
  • e.创建优化器类
  • f.初始化模型权重,加载预训练模型、模型与优化器整合,
  • g.是否是多卡,实例多模型并行训练
  • 开启训练
  • g.遍历数据,开始循环训练,根据时间戳计算一系列时间(剩余时间,平均训练时间)
  • h.模型前向推理,反向传播,(多卡模型并行,loss合并)
  • j.每个iter结束后输出日志
  • k.定期打印log,定期保存模型和优化器参数,(eval开启:最优 and 定时)
  • 直到迭代结束

 

#train.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os, sys
# add python path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
    sys.path.append(parent_path)

# ignore numba warning
import warnings
warnings.filterwarnings('ignore')
import random
import datetime
import time
import numpy as np

import paddle
from paddle.distributed import ParallelEnv

from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model

import ppdet.utils.cli as cli
import ppdet.utils.check as check
import ppdet.utils.stats as stats
from ppdet.utils.logger import setup_logger
logger = setup_logger('train')

#运行配置参数解析函数
def parse_args():
    parser = cli.ArgsParser()
    parser.add_argument(
        "--eval",
        action='store_true',
        default=False,
        help="Whether to perform evaluation in train")
    parser.add_argument(
        "-r", "--resume", default=None, help="weights path for resume")
    parser.add_argument(
        "--slim_config",
        default=None,
        type=str,
        help="Configuration file of slim method.")
    parser.add_argument(
        "--enable_ce",
        type=bool,
        default=False,
        help="If set True, enable continuous evaluation job."
        "This flag is only used for internal test.")
    parser.add_argument(
        "--fp16",
        action='store_true',
        default=False,
        help="Enable mixed precision training.")
    parser.add_argument(
        "--fleet", action='store_true', default=False, help="Use fleet or not")
    parser.add_argument(
        "--use_vdl",
        type=bool,
        default=False,
        help="whether to record the data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/scalar",
        help='VisualDL logging directory for scalar.')
    parser.add_argument(
        '--save_prediction_only',
        action='store_true',
        default=False,
        help='Whether to save the evaluation results only')
    args = parser.parse_args()
    return args



#run函数,detection套件执行的核心部分
def run(FLAGS, cfg):
    # init fleet environment  #初始化环境
    if cfg.fleet:
        init_fleet_env()
    else:
        # init parallel environment if nranks > 1  #是否采用模型并行(多卡)
        init_parallel_env()

    if FLAGS.enable_ce:   #随机参数
        set_random_seed(0)

    # build trainer   #建立模型
    trainer = Trainer(cfg, mode='train')

    # load weights    #加载预训练模型参数
    if FLAGS.resume is not None:
        trainer.resume_weights(FLAGS.resume)
    elif 'pretrain_weights' in cfg and cfg.pretrain_weights:
        trainer.load_weights(cfg.pretrain_weights)

    # training   #执行训练
    trainer.train(FLAGS.eval)






#主函数定义
def main():
    
    FLAGS = parse_args()   #加载运行参数
    cfg = load_config(FLAGS.config)   #加载yaml配置
    cfg['fp16'] = FLAGS.fp16    #是否采用半精度
    cfg['fleet'] = FLAGS.fleet
    cfg['use_vdl'] = FLAGS.use_vdl   #是否采用训练可视化
    cfg['vdl_log_dir'] = FLAGS.vdl_log_dir  #可视化文件路径
    cfg['save_prediction_only'] = FLAGS.save_prediction_only  #只保存预测结果
    merge_config(FLAGS.opt) #合并配置
    # 选择执行环境
    place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')
    # 是否采用同步BN
    if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
        cfg['norm_type'] = 'bn'
    # slim配置
    if FLAGS.slim_config:
        cfg = build_slim_model(cfg, FLAGS.slim_config)
    #检测配置文件
    check.check_config(cfg)
    check.check_gpu(cfg.use_gpu)
    check.check_version()
    #执行run函数
    run(FLAGS, cfg)


#程序入口
if __name__ == "__main__":
    main()#主函数入口

  • 2
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值