def run(settings):
# Build dataloaders
loader_train, loader_val = build_dataloaders(cfg, settings)
# Create network
net = build_mixformer(cfg)
# wrap networks to distributed one
settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")
# settings.save_every_epoch = True
# Loss functions and Actors
objective = {
'giou': giou_loss, 'l1': l1_loss}
loss_weight = {
'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT}
actor = MixFormerActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings)
super().__init__(net, objective)
self.loss_weight = loss_weight
self.settings = settings
self.bs = self.settings.batchsize # batch size 8
self.run_score_head = run_score_head #false
# Optimizer, parameters, and learning rates
optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
#AdamW
use_amp = getattr(cfg.TRAIN, "AMP", False)
trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
# Initialize statistics variables
# Initialize tensorboard
# train process
trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
for epoch in range(self.epoch+1, max_epochs+1):
self.train_epoch()
self.cycle_dataset(loader)
"""Do a cycle of training or validation."""
for i, data in enumerate(loader, 1):  # e.g. at "150/7500": i = 150 is the current iteration count, and 7500 is len(loader)
# forward pass
loss, stats = self.actor(data)
|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class MixFormerActor(BaseActor):
def __call__(self, data):
out_dict = self.forward_pass(data, run_score_head=self.run_score_head)
search_bboxes = box_xywh_to_xyxy(data['search_anno'][0].clone())  # observed shape: (8, 4)
out_dict, _ = self.net(data['template_images'][0], data['template_images'][1], data['search_images'],run_score_head=run_score_head, gt_bboxes=search_bboxes) # observed shapes: template (8, 3, 128, 128), online template (8, 3, 128, 128), search (1, 8, 3, 320, 320)
# out_dict: (B, N, C), outputs_coord: (1, B, N, C), target_query: (1, B, N, C)
-----------------------------------------------------------------------------------------------------------------------
def forward(self, template, online_template, search, run_score_head=False, gt_bboxes=None):
template, search = self.backbone(template, online_template, search)
# ----> observed shapes, in argument order: template (8, 3, 128, 128), online_template (8, 3, 128, 128), search (8, 3, 320, 320)
# Forward the corner head
def forward(self, template, online_template, search):  # i takes the values 0, 1, 2 in the stage call below
template, online_template, search = getattr(self, f'stage{i}')(template, online_template, search)
**
Mixformer train代码解读
于 2022-05-03 23:49:39 首次发布