Mixformer train代码解读

def run(settings):
    # Build dataloaders
    loader_train, loader_val = build_dataloaders(cfg, settings)

    # Create network
    net = build_mixformer(cfg)

    # wrap networks to distributed one
    settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
    settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
    settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")
    # settings.save_every_epoch = True
    # Loss functions and Actors
    objective = {
   'giou': giou_loss, 'l1': l1_loss}
    loss_weight = {
   'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT}
    actor = MixFormerActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings)
		super().__init__(net, objective)
        self.loss_weight = loss_weight
        self.settings = settings
        self.bs = self.settings.batchsize  # batch size   8
        self.run_score_head = run_score_head   #false
        
    # Optimizer, parameters, and learning rates
    optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
			#AdamW   
    use_amp = getattr(cfg.TRAIN, "AMP", False)
    trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer, settings, lr_scheduler, use_amp=use_amp)
    		# Initialize statistics variables
        	# Initialize tensorboard

    # train process
    trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)

		 for epoch in range(self.epoch+1, max_epochs+1):
			self.train_epoch()
			self.cycle_dataset(loader)
			"""Do a cycle of training or validation."""
				for i, data in enumerate(loader, 1):  150/7500,这里的i是150,指的是当前运行的次数  7500len(loader)
					# forward pass
					loss, stats = self.actor(data)
|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
						class MixFormerActor(BaseActor):
							def __call__(self, data):
								out_dict = self.forward_pass(data, run_score_head=self.run_score_head)
									search_bboxes = box_xywh_to_xyxy(data['search_anno'][0].clone())    8  4
        							out_dict, _ = self.net(data['template_images'][0], data['template_images'][1], data['search_images'],run_score_head=run_score_head, gt_bboxes=search_bboxes) #  8 3 128 128    8 3 128 128   1 8 3 320 320
        					# out_dict: (B, N, C), outputs_coord: (1, B, N, C), target_query: (1, B, N, C)
-----------------------------------------------------------------------------------------------------------------------
        							    def forward(self, template, online_template, search, run_score_head=False, gt_bboxes=None):
        									template, search = self.backbone(template, online_template, search)
        									---->8 3 128 128          8 3 128 128        8 3 320 320
        									# Forward the corner head
        											def forward(self, template, online_template, search):   i:0 1 2
        												 template, online_template, search = getattr(self, f'stage{
     i}')(template, online_template, search)
**
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
YOLOv8的训练代码解读如下: train.py是YOLOv8项目中的一个主要文件,它包含了训练模型所需的核心功能和逻辑。在train.py中,主要涉及的函数是train()函数。 train()函数是整个训练过程的入口函数,它接收一些参数,如模型、数据集、优化器等,并执行训练循环。在训练循环中,首先加载训练数据集,并通过数据增强方法对图像进行增强,以提高模型的泛化能力和鲁棒性。然后,将增强后的图像输入到模型中进行前向传播,得到预测结果。接着,计算预测结果与真实标签之间的损失,使用损失函数计算损失值,并根据损失值来更新模型的参数。这个过程不断迭代,直到达到预设的训练轮数或达到停止训练的条件。 train.py还涉及了其他一些模块和文件,如checkpoints、data、dataset、loss、utils等。checkpoints目录存储了训练过程中保存的模型权重文件,这些文件可以用于恢复训练或进行推理。data目录包含了存储类别信息和训练数据列表的文件,classes.txt文件存储了物体类别的名称,train.txt文件包含了训练数据集的文件路径列表。dataset模块提供了数据加载器,用于加载训练数据并进行预处理。loss模块包含了损失函数的实现,用于计算模型预测结果与真实标签之间的差异。utils模块包含了一些辅助函数和工具类,用于在训练过程中进行日志记录、模型保存等操作。 需要注意的是,由于train.py函数涉及的篇幅较大,本博客只提供了部分核心内容的讲解。如果你想详细了解train.py的完整代码,你可以查看网盘地址中的代码文件(提取码:wbqu)。 请注意,我在回答中使用了引用和引用中的相关内容来支持我的回答。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值