Detectron2 Detection Library Code Walkthrough (Part 1): The Training Script

detectron2, the detection library released by facebookresearch, covers implementations of nearly all of the latest detection algorithms.

Training code

train_net.py

1. Argument loading

args = default_argument_parser().parse_args()

For example, the arguments passed at training time:

python tools/train_net.py --config-file configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml \
    --num-gpus 1 SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025

where the file faster_rcnn_R_50_FPN_1x.yaml contains:

_BASE_: "../Base-RCNN-FPN.yaml"
MODEL:
  WEIGHTS: "/home/sharedir/industrial/pgchen/R-50.pkl" # path to the PKL weights file; downloaded automatically if it does not exist
  MASK_ON: False
  RESNETS:
    DEPTH: 50

The parsed args is:

Namespace(config_file='configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml',  # path to the config file
          dist_url='tcp://127.0.0.1:50263',  # URL used to set up the distributed job
          eval_only=False,
          machine_rank=0,
          num_gpus=1,
          num_machines=1,
          opts=['SOLVER.IMS_PER_BATCH', '2', 'SOLVER.BASE_LR', '0.0025'],
          resume=False)
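
Everything after the named flags is collected into `opts` (argparse's `REMAINDER`), which is why the `SOLVER.*` key/value pairs need no `--` prefix. A quick way to verify this interactively, assuming detectron2 is installed:

from detectron2.engine import default_argument_parser

# Reproduce the command line above programmatically.
args = default_argument_parser().parse_args([
    "--config-file", "configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml",
    "--num-gpus", "1",
    "SOLVER.IMS_PER_BATCH", "2", "SOLVER.BASE_LR", "0.0025",
])
print(args.opts)  # ['SOLVER.IMS_PER_BATCH', '2', 'SOLVER.BASE_LR', '0.0025']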

2. Multi-GPU distributed training

launch( # multi-GPU distributed training
    main,
    args.num_gpus,
    num_machines=args.num_machines,
    machine_rank=args.machine_rank,
    dist_url=args.dist_url,
    args=(args,),
)
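
`launch` wraps `torch.multiprocessing`: with more than one GPU it spawns one worker process per GPU, initializes the process group via `dist_url`, and calls `main(*args)` in each worker; with a single GPU it simply calls `main(*args)` directly. Below is a simplified single-machine sketch of the idea, not detectron2's actual implementation (which also handles `num_machines`, `machine_rank`, and error propagation):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(local_rank, main_func, world_size, dist_url, args):
    # Each spawned process joins the process group, binds to its GPU,
    # then runs the real entry point.
    dist.init_process_group("nccl", init_method=dist_url,
                            world_size=world_size, rank=local_rank)
    torch.cuda.set_device(local_rank)
    main_func(*args)

def simple_launch(main_func, num_gpus, dist_url, args=()):
    if num_gpus > 1:
        mp.spawn(_worker, args=(main_func, num_gpus, dist_url, args),
                 nprocs=num_gpus)
    else:
        main_func(*args)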

3. Building the config

cfg = setup(args)

def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg() 
    # load values from the given config file and the opts list, merging them into cfg
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    '''default_setup:
    Performs some basic common setups at the beginning of a job, including:
     1. Set up the detectron2 logger
     2. Log basic information about the environment, cmdline arguments, and config
     3. Back up the config to the output directory
    '''
    return cfg
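
After `setup` returns, `cfg` is a frozen `CfgNode`, and the merge order matters: the defaults from `get_cfg()` are overridden by the YAML file, which is in turn overridden by the command-line `opts`. A small sketch of reading the result:

cfg = setup(args)
print(cfg.SOLVER.BASE_LR)       # 0.0025 -- from the command-line opts
print(cfg.MODEL.RESNETS.DEPTH)  # 50     -- from faster_rcnn_R_50_FPN_1x.yaml
# cfg.SOLVER.MAX_ITER = 100     # would raise: the config is frozen after setup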

4. TrainerBase

The skeleton of the training loop is defined in this base class.

class TrainerBase:
    """
    带有hooks的迭代训练器的基类。

    我们在这里做出的唯一假设是:训练循环运行。
    子类可以实现循环是什么。
    我们没有对数据加载器、优化器、模型等的存在做任何假设。
    Attributes:
    iter(int): the current iteration.
    start_iter(int): The iteration to start with.
        			 By convention the minimum possible value is 0.
    max_iter(int): The iteration to end training.
    storage(EventStorage): An EventStorage that's opened during the course of training.
    """
    def __init__(self) -> None:
        self._hooks: List[HookBase] = []
        self.iter: int = 0
        self.start_iter: int = 0
        self.max_iter: int
        self.storage: EventStorage
        _log_api_usage("trainer." + self.__class__.__name__)
        
    def train(self, start_iter: int, max_iter: int):
        """Args:
            start_iter, max_iter (int): See docs above
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
                # self.iter == max_iter can be used by `after_train` to tell whether
                # training finished successfully or failed due to an exception.
                self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

The loop dispatches to five methods:

self.before_train()
self.before_step()
self.run_step()
self.after_step()
self.after_train()

Each of these methods is just a simple loop over the registered hooks; what actually happens in each step is defined by the HookBase subclasses (a concrete example follows the HookBase section below).

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        self.storage.iter = self.iter
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        # Maintain the invariant that storage.iter == trainer.iter
        # for the entire execution of each step
        self.storage.iter = self.iter
        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()

    def run_step(self):
        raise NotImplementedError

5. HookBase

HookBase is the most basic hook: it only declares the methods without implementing them, and serves as the base class from which other hooks inherit.

class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    Each hook can implement 4 methods. The way they are called is demonstrated
    in the following snippet:
    ::
        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        iter += 1
        hook.after_train()
    """
    trainer: "TrainerBase" = None
    """
    A weak reference to the trainer object. Set by the trainer when the hook is registered.
    """

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass

    def state_dict(self):
        """
        Hooks are stateless by default, but can be made checkpointable by
        implementing `state_dict` and `load_state_dict`.
        """
        return {}
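
Since `TrainerBase` invokes every registered hook at these four points, extending training behavior only requires subclassing `HookBase` and registering an instance. A minimal sketch (the hook name and period are made up for illustration):

class IterPrinterHook(HookBase):
    """Toy hook: print a message every `period` iterations."""

    def __init__(self, period=100):
        self._period = period

    def after_step(self):
        # `self.trainer` is set by the trainer when the hook is registered.
        if (self.trainer.iter + 1) % self._period == 0:
            print(f"finished iteration {self.trainer.iter}")

# Registered via TrainerBase.register_hooks, e.g.:
# trainer.register_hooks([IterPrinterHook(period=100)])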

6. DefaultTrainer

The DefaultTrainer class inherits from TrainerBase.

class DefaultTrainer(TrainerBase):
    """
    具有默认训练逻辑的训练器。它执行以下操作:

    1. 使用由给定配置定义的模型、优化器、数据加载器创建一个 :class:`SimpleTrainer`。创建由配置定义的 LR 调度程序。
    2. 加载最后一个checkpoint或者`cfg.MODEL.WEIGHTS`,如果存在,当`resume_or_load`被调用。
    3. 注册一些由配置定义的常用 hooks。
    """
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            trainer=weakref.proxy(self),
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
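
With `DefaultTrainer` in place, the `main` function that `launch` runs reduces to a few lines; this is essentially what `tools/train_net.py` does (eval-only branch omitted):

def main(args):
    cfg = setup(args)
    trainer = DefaultTrainer(cfg)
    # Load cfg.MODEL.WEIGHTS, or resume from the last checkpoint in
    # cfg.OUTPUT_DIR when --resume is passed.
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()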

6.1 Building the model

def build_model(cfg):
    """
    构建整个模型架构,由 ``cfg.MODEL.META_ARCHITECTURE`` 定义。
   	根据配置函数里面的内容,找到对应的函数,然后调用创建模型
    """
    meta_arch = cfg.MODEL.META_ARCHITECTURE
    # GeneralizedRCNN is used here
    model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
    model.to(torch.device(cfg.MODEL.DEVICE))
    _log_api_usage("modeling.meta_arch." + meta_arch)
    return model
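
`META_ARCH_REGISTRY` is a detectron2 `Registry` that maps a string to a model class, so the config string alone decides which architecture gets instantiated. A custom architecture is registered the same way; the class below is a hypothetical placeholder:

import torch.nn as nn
from detectron2.modeling import META_ARCH_REGISTRY

@META_ARCH_REGISTRY.register()
class MyToyArch(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, cfg):
        super().__init__()
        self.stem = nn.Conv2d(3, 64, 3)  # placeholder layer

    def forward(self, batched_inputs):
        # A real meta-architecture returns losses in training
        # and predictions in inference.
        raise NotImplementedError

# Selecting it is then purely a config change:
#   MODEL:
#     META_ARCHITECTURE: "MyToyArch"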

6.2 Optimizer

def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
):
    """
    Get a default parameter list for the optimizer, with support for a few types of
    overrides. If no overrides are needed, this is equivalent to `model.parameters()`.

    Args:
        base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
        weight_decay: weight decay for every group by default. Can be omitted to use the one
            in optimizer.
        weight_decay_norm: override weight decay for params in normalization layers
        bias_lr_factor: multiplier of lr for bias parameters.
        weight_decay_bias: override weight decay for bias parameters
        overrides: if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
            weight decay values for all module parameters named `embedding`.

    For common detection models, ``weight_decay_norm`` is the only option
    needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings
    from Detectron1 that are not found useful.

    Example:
    ::
        torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
                       lr=0.01, weight_decay=1e-4, momentum=0.9)
    """
    if overrides is None:
        overrides = {}
    defaults = {}
    if base_lr is not None:
        defaults["lr"] = base_lr
    if weight_decay is not None:
        defaults["weight_decay"] = weight_decay
    bias_overrides = {}
    if bias_lr_factor is not None and bias_lr_factor != 1.0:
        # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
        # exactly the same as regular weights.
        if base_lr is None:
            raise ValueError("bias_lr_factor requires base_lr")
        bias_overrides["lr"] = base_lr * bias_lr_factor
    if weight_decay_bias is not None:
        bias_overrides["weight_decay"] = weight_decay_bias
    if len(bias_overrides):
        if "bias" in overrides:
            raise ValueError("Conflicting overrides for 'bias'")
        overrides["bias"] = bias_overrides

    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    for module in model.modules():
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            hyperparams = copy.copy(defaults)
            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
                hyperparams["weight_decay"] = weight_decay_norm
            hyperparams.update(overrides.get(module_param_name, {}))
            params.append({"params": [value], **hyperparams})
    return params
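
Each trainable parameter ends up in its own one-element group, which is what makes per-parameter overrides possible; the list is then fed straight into a standard PyTorch optimizer. A sketch combining the options (the override below doubles the LR on every parameter literally named `bias`):

params = get_default_optimizer_params(
    model,
    base_lr=0.01,
    weight_decay_norm=0.0,             # no weight decay on norm-layer params
    overrides={"bias": {"lr": 0.02}},  # 2x LR for parameters named "bias"
)
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=1e-4)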

6.3 Learning-rate scheduling

def build_lr_scheduler(
    cfg: CfgNode, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
    """
    Build a LR scheduler from config.
    """
    name = cfg.SOLVER.LR_SCHEDULER_NAME

    if name == "WarmupMultiStepLR":
        steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER]
        if len(steps) != len(cfg.SOLVER.STEPS):
            logger = logging.getLogger(__name__)
            logger.warning(
                "SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. "
                "These values will be ignored."
            )
        sched = MultiStepParamScheduler(
            values=[cfg.SOLVER.GAMMA ** k for k in range(len(steps) + 1)],
            milestones=steps,
            num_updates=cfg.SOLVER.MAX_ITER,
        )
    elif name == "WarmupCosineLR":
        sched = CosineParamScheduler(1, 0)
    else:
        raise ValueError("Unknown LR scheduler: {}".format(name))

    sched = WarmupParamScheduler(
        sched,
        cfg.SOLVER.WARMUP_FACTOR,
        min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0),
        cfg.SOLVER.WARMUP_METHOD,
    )
    return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
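
Concretely, with the standard 1x schedule (SOLVER.MAX_ITER=90000, SOLVER.STEPS=(60000, 80000), SOLVER.GAMMA=0.1), the multiplier applied to BASE_LR is 1.0 until iteration 60000, 0.1 until 80000, and 0.01 afterwards, with a linear warmup prepended over the first WARMUP_ITERS iterations. A minimal sketch of the step part (warmup omitted):

# Reproduce the piecewise multiplier of WarmupMultiStepLR by hand.
steps, gamma = (60000, 80000), 0.1

def lr_multiplier(it):
    # gamma**k, where k = number of milestones already passed
    return gamma ** sum(it >= s for s in steps)

assert lr_multiplier(0) == 1.0
assert lr_multiplier(60000) == 0.1
assert abs(lr_multiplier(80000) - 0.01) < 1e-12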

The next few articles will walk through the architectures of the classic networks in detail, along with the remaining components.
