detectron2, the detection codebase released by facebookresearch, covers nearly all of the current state-of-the-art detection models.
Training code
train_net.py
1. Argument parsing
args = default_argument_parser().parse_args()
For example, pass the following arguments when training:
python tools/train_net.py --config-file configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml \
--num-gpus 1 SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025
The file faster_rcnn_R_50_FPN_1x.yaml contains:
_BASE_: "../Base-RCNN-FPN.yaml"
MODEL:
  WEIGHTS: "/home/sharedir/industrial/pgchen/R-50.pkl"  # path to the PKL weights file; downloaded automatically if it does not exist
  MASK_ON: False
  RESNETS:
    DEPTH: 50
Then args is:
Namespace(
    config_file='configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml',  # path to the config file
    dist_url='tcp://127.0.0.1:50263',  # URL used to set up distributed jobs
    eval_only=False,
    machine_rank=0,
    num_gpus=1,
    num_machines=1,
    opts=['SOLVER.IMS_PER_BATCH', '2', 'SOLVER.BASE_LR', '0.0025'],
    resume=False)
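The same Namespace can be reproduced programmatically, which is convenient for experimenting in a notebook. A minimal sketch (the paths are the ones from the example above; note that dist_url is auto-generated with a random port, so it will differ from run to run):

from detectron2.engine import default_argument_parser

args = default_argument_parser().parse_args(
    [
        "--config-file", "configs/COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml",
        "--num-gpus", "1",
        "SOLVER.IMS_PER_BATCH", "2", "SOLVER.BASE_LR", "0.0025",
    ]
)
print(args.opts)  # ['SOLVER.IMS_PER_BATCH', '2', 'SOLVER.BASE_LR', '0.0025']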
2. Multi-GPU distributed training
launch(  # multi-GPU distributed training
    main,
    args.num_gpus,
    num_machines=args.num_machines,
    machine_rank=args.machine_rank,
    dist_url=args.dist_url,
    args=(args,),
)
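Under the hood, launch spawns one worker process per GPU and only calls main directly in the single-process case. The real implementation lives in detectron2/engine/launch.py; the following is a simplified sketch (the _distributed_worker callable is a stand-in for the real worker, which also initializes the process group via dist_url):

import torch.multiprocessing as mp

def launch_sketch(main_func, num_gpus_per_machine, num_machines=1,
                  machine_rank=0, dist_url=None, args=()):
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        # one worker process per GPU; each worker sets up torch.distributed
        # using dist_url and then calls main_func(*args)
        mp.spawn(
            _distributed_worker,  # stand-in for the real worker function
            nprocs=num_gpus_per_machine,
            args=(main_func, world_size, num_gpus_per_machine,
                  machine_rank, dist_url, args),
        )
    else:
        main_func(*args)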
3. Building the config
cfg = setup(args)
def setup(args):
    """
    Create the config and perform basic setup.
    """
    cfg = get_cfg()
    # load contents from the given config file and list, and merge them into the config
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    # default_setup performs some basic common setups at the beginning of a job, including:
    # 1. setting up the detectron2 logger
    # 2. logging basic information about the environment, cmdline arguments, and config
    # 3. backing up the config to the output directory
    return cfg
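The effect of merge_from_list on the opts parsed above can be checked directly. A minimal sketch; 16 is the default value of SOLVER.IMS_PER_BATCH in detectron2's base config:

from detectron2.config import get_cfg

cfg = get_cfg()                   # the default config (a CfgNode)
print(cfg.SOLVER.IMS_PER_BATCH)   # 16 (the default)
cfg.merge_from_list(["SOLVER.IMS_PER_BATCH", "2", "SOLVER.BASE_LR", "0.0025"])
print(cfg.SOLVER.IMS_PER_BATCH)   # 2 -- strings are coerced to the type of the existing default
cfg.freeze()                      # the config becomes immutable from here on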
4. TrainerBase
The basic trainer is defined in the following base class.
class TrainerBase:
    """
    Base class for an iterative trainer with hooks.

    The only assumption we make here is that the training runs in a loop;
    a subclass can implement what the loop is.
    We make no assumptions about the existence of a dataloader, optimizer, model, etc.

    Attributes:
        iter (int): the current iteration.
        start_iter (int): the iteration to start with.
            By convention the minimum possible value is 0.
        max_iter (int): the iteration at which training ends.
        storage (EventStorage): an EventStorage that's opened during the course of training.
    """

    def __init__(self) -> None:
        self._hooks: List[HookBase] = []
        self.iter: int = 0
        self.start_iter: int = 0
        self.max_iter: int
        self.storage: EventStorage
        _log_api_usage("trainer." + self.__class__.__name__)

    def train(self, start_iter: int, max_iter: int):
        """
        Args:
            start_iter, max_iter (int): see docs above
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
                # self.iter == max_iter can be used by `after_train` to tell
                # whether the training completed successfully or failed due to
                # an exception.
                self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()
The training loop mainly involves five functional methods:

self.before_train()
self.before_step()
self.run_step()
self.after_step()
self.after_train()

Apart from run_step, each of these contains nothing but a simple loop over the registered hooks; what actually happens inside the loop is defined by the hooks themselves, whose interface is the HookBase class.
def before_train(self):
    for h in self._hooks:
        h.before_train()

def after_train(self):
    self.storage.iter = self.iter
    for h in self._hooks:
        h.after_train()

def before_step(self):
    # Maintain the invariant that storage.iter == trainer.iter
    # for the entire execution of each step
    self.storage.iter = self.iter
    for h in self._hooks:
        h.before_step()

def after_step(self):
    for h in self._hooks:
        h.after_step()

def run_step(self):
    raise NotImplementedError
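To make the division of labor concrete: a subclass only has to provide run_step, and it inherits the hook machinery and the training loop. A minimal, hypothetical sketch (ToyTrainer, and the assumption that the model returns a scalar loss, are illustrative, not part of detectron2):

class ToyTrainer(TrainerBase):
    def __init__(self, model, data_loader, optimizer):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self._data_iter = iter(data_loader)

    def run_step(self):
        data = next(self._data_iter)
        loss = self.model(data)   # assumes the model returns a scalar loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# trainer = ToyTrainer(model, data_loader, optimizer)
# trainer.train(start_iter=0, max_iter=90000)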
5. HookBase
HookBase is the most basic hook: it only declares the methods and leaves their bodies empty, serving as the base class that other hooks inherit from.
class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    Each hook can implement 4 methods. The way they are called is demonstrated
    in the following snippet:
    ::
        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        iter += 1
        hook.after_train()
    """

    trainer: "TrainerBase" = None
    """
    A weak reference to the trainer object. Set by the trainer when the hook is registered.
    """

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass

    def state_dict(self):
        """
        Hooks are stateless by default, but can be made checkpointable by
        implementing `state_dict` and `load_state_dict`.
        """
        return {}
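As an illustration of how hooks are written, here is a minimal, hypothetical hook that prints the wall-clock time of every iteration (IterTimerHook is an illustrative name; detectron2 ships its own, more complete hooks in detectron2/engine/hooks.py):

import time

class IterTimerHook(HookBase):
    def before_step(self):
        self._start = time.perf_counter()

    def after_step(self):
        # self.trainer is set by the trainer when the hook is registered
        elapsed = time.perf_counter() - self._start
        print(f"iter {self.trainer.iter}: {elapsed:.3f}s")

# trainer.register_hooks([IterTimerHook()])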
6. DefaultTrainer
The DefaultTrainer class inherits from the TrainerBase class:
class DefaultTrainer(TrainerBase):
    """
    A trainer with default training logic. It does the following:

    1. Create a :class:`SimpleTrainer` using the model, optimizer and dataloader
       defined by the given config, and create an LR scheduler defined by the config.
    2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if it exists, when
       `resume_or_load` is called.
    3. Register a few common hooks defined by the config.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        super().__init__()
        logger = logging.getLogger("detectron2")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())

        # Assume these objects must be constructed in this order.
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)
        data_loader = self.build_train_loader(cfg)

        model = create_ddp_model(model, broadcast_buffers=False)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )

        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
        self.checkpointer = DetectionCheckpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            trainer=weakref.proxy(self),
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())
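For reference, this is roughly how main in train_net.py uses DefaultTrainer (the eval_only branch is omitted here):

def main(args):
    cfg = setup(args)
    trainer = DefaultTrainer(cfg)
    # load cfg.MODEL.WEIGHTS, or the last checkpoint if resume=True
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()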
6.1 Building the model
def build_model(cfg):
    """
    Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
    Look up the class named in the config in the registry and call it to create the model.
    """
    meta_arch = cfg.MODEL.META_ARCHITECTURE
    # GeneralizedRCNN is used here
    model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
    model.to(torch.device(cfg.MODEL.DEVICE))
    _log_api_usage("modeling.meta_arch." + meta_arch)
    return model
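The registry lookup works because every meta-architecture registers itself under its class name. A minimal sketch of registering a custom one (MyMetaArch is hypothetical; GeneralizedRCNN is registered the same way inside detectron2):

import torch
from detectron2.modeling import META_ARCH_REGISTRY

@META_ARCH_REGISTRY.register()
class MyMetaArch(torch.nn.Module):   # hypothetical meta-architecture
    def __init__(self, cfg):
        super().__init__()
        # build the backbone/heads from cfg here

# Setting cfg.MODEL.META_ARCHITECTURE = "MyMetaArch" then makes
# META_ARCH_REGISTRY.get("MyMetaArch") return the class above.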
6.2 Optimizer
def get_default_optimizer_params(
    model: torch.nn.Module,
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    bias_lr_factor: Optional[float] = 1.0,
    weight_decay_bias: Optional[float] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
):
    """
    Get the default parameter list for the optimizer, with support for a few
    types of overrides. If no overrides are needed, this is equivalent to
    `model.parameters()`.

    Args:
        base_lr: lr for every group by default. Can be omitted to use the one
            in the optimizer.
        weight_decay: weight decay for every group by default. Can be omitted
            to use the one in the optimizer.
        weight_decay_norm: override weight decay for params in normalization layers
        bias_lr_factor: multiplier of lr for bias parameters.
        weight_decay_bias: override weight decay for bias parameters
        overrides: if not `None`, provides values for optimizer hyperparameters
            (LR, weight decay) for module parameters with a given name; e.g.
            ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
            weight decay values for all module parameters named `embedding`.

    For common detection models, ``weight_decay_norm`` is the only option that
    needs to be set. ``bias_lr_factor, weight_decay_bias`` are legacy settings
    from Detectron1 that have not been found useful.

    Example:
    ::
        torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
                        lr=0.01, weight_decay=1e-4, momentum=0.9)
    """
    if overrides is None:
        overrides = {}
    defaults = {}
    if base_lr is not None:
        defaults["lr"] = base_lr
    if weight_decay is not None:
        defaults["weight_decay"] = weight_decay
    bias_overrides = {}
    if bias_lr_factor is not None and bias_lr_factor != 1.0:
        # NOTE: unlike Detectron v1, we now by default make bias hyperparameters
        # exactly the same as regular weights.
        if base_lr is None:
            raise ValueError("bias_lr_factor requires base_lr")
        bias_overrides["lr"] = base_lr * bias_lr_factor
    if weight_decay_bias is not None:
        bias_overrides["weight_decay"] = weight_decay_bias
    if len(bias_overrides):
        if "bias" in overrides:
            raise ValueError("Conflicting overrides for 'bias'")
        overrides["bias"] = bias_overrides

    norm_module_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
        # NaiveSyncBatchNorm inherits from BatchNorm2d
        torch.nn.GroupNorm,
        torch.nn.InstanceNorm1d,
        torch.nn.InstanceNorm2d,
        torch.nn.InstanceNorm3d,
        torch.nn.LayerNorm,
        torch.nn.LocalResponseNorm,
    )
    params: List[Dict[str, Any]] = []
    memo: Set[torch.nn.parameter.Parameter] = set()
    for module in model.modules():
        for module_param_name, value in module.named_parameters(recurse=False):
            if not value.requires_grad:
                continue
            # Avoid duplicating parameters
            if value in memo:
                continue
            memo.add(value)

            hyperparams = copy.copy(defaults)
            if isinstance(module, norm_module_types) and weight_decay_norm is not None:
                hyperparams["weight_decay"] = weight_decay_norm
            hyperparams.update(overrides.get(module_param_name, {}))
            params.append({"params": [value], **hyperparams})
    return params
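A quick way to see what this produces: on a toy model, the parameters of the normalization layer get the overridden weight decay while everything else keeps the defaults (a minimal sketch; the toy Sequential model is illustrative):

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
params = get_default_optimizer_params(
    model, base_lr=0.01, weight_decay=1e-4, weight_decay_norm=0.0
)
for p in params:
    print(p["lr"], p["weight_decay"])
# the conv weight and bias keep weight_decay=1e-4,
# while both BatchNorm parameters get weight_decay=0.0
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9)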
6.3 LR scheduling
def build_lr_scheduler(
    cfg: CfgNode, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
    """
    Build a LR scheduler from config.
    """
    name = cfg.SOLVER.LR_SCHEDULER_NAME

    if name == "WarmupMultiStepLR":
        steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER]
        if len(steps) != len(cfg.SOLVER.STEPS):
            logger = logging.getLogger(__name__)
            logger.warning(
                "SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. "
                "These values will be ignored."
            )
        sched = MultiStepParamScheduler(
            values=[cfg.SOLVER.GAMMA ** k for k in range(len(steps) + 1)],
            milestones=steps,
            num_updates=cfg.SOLVER.MAX_ITER,
        )
    elif name == "WarmupCosineLR":
        sched = CosineParamScheduler(1, 0)
    else:
        raise ValueError("Unknown LR scheduler: {}".format(name))

    sched = WarmupParamScheduler(
        sched,
        cfg.SOLVER.WARMUP_FACTOR,
        min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0),
        cfg.SOLVER.WARMUP_METHOD,
    )
    return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
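The effective learning rate at iteration t is cfg.SOLVER.BASE_LR times the scheduler's multiplier. Ignoring the warmup phase, the multi-step part is easy to reproduce in pure Python; the milestones below assume the values the 1x schedule above uses (STEPS=(60000, 80000), GAMMA=0.1, MAX_ITER=90000):

def multistep_multiplier(it, steps=(60000, 80000), gamma=0.1):
    # multiplier is 1.0 before 60k iters, 0.1 until 80k, 0.01 afterwards
    return gamma ** sum(it >= s for s in steps)

assert multistep_multiplier(0) == 1.0
assert multistep_multiplier(70000) == 0.1
assert abs(multistep_multiplier(85000) - 0.01) < 1e-12

During the first WARMUP_ITERS iterations, WarmupParamScheduler additionally ramps the multiplier up from WARMUP_FACTOR toward this value.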
The next few articles will walk through the network architectures of the classic models in detail, along with the details of the remaining components.