yolov8逐步分解(2)_DetectionTrainer类初始化过程
yolov8逐步分解(3)_trainer训练之模型加载_yolov8 加载模型-CSDN博客
在上述文章逐步分解(3)和(4)中主要讲解了模型训练初始设置中self.setup_model()函数模型的加载及构建过程,本章将讲解混合精度训练AMP的相关代码。
下面是_setup_train()函数的详细代码。
def _setup_train(self, world_size):
""" Builds dataloaders and optimizer on correct rank process. """
# Model
self.run_callbacks('on_pretrain_routine_start')
ckpt = self.setup_model()#加载模型
self.model = self.model.to(self.device)
self.set_model_attributes()
# Check AMP
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
if self.amp and RANK in (-1, 0):
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
self.amp = torch.tensor(check_amp(self.model), device=self.device) #使用 check_amp 函数检查模型是否支持混合精度
callbacks.default_callbacks = callbacks_backup # restore callbacks, 恢复回之前备份的回调函数
if RANK > -1 and world_size > 1: # DDP 使用 dist.broadcast 将 self.amp 张量从rank 0广播到其他所有rank(返回None)。
dist.broadcast(self.amp, src=0) # broadcast the tensor