引言
之前在写mmdetection源码的解读过程时,觉得train_detector()这部分很重要,对于理解整个的训练过程应该时起着非常大的理解作用。
然后最近研究工作一直在看和修改mmdetection的其他模块的代码这一块。感觉train_detector()这块内容其实也不是特别重要来着,可能就是一个加强理解的过程。这次还是花了点时间,大致的看了一下,顺便加上自己的一些理解,解释了一下整个过程,如果有错的话,希望各路大佬指出,互相学习哈。
train_detector()
下面的代码出现在tools/train.py中,也是main函数的结尾,也就是说,我们训练的时候,到这就是真正的开始训练了。
train_detector(
model,
datasets,
cfg,
distributed=distributed,
validate=args.validate,
logger=logger)
那到底怎么训练的呢?
下面代码是train_detector()函数的定义,在mmdet/api/train.py文件中
def train_detector(model,
dataset,
cfg,
distributed=False,
validate=False,
logger=None):
if logger is None:
logger = get_root_logger(cfg.log_level)
# start training
if distributed:
_dist_train(model, dataset, cfg, validate=validate