目录
训练笔记
def train_seg(self, true_img, true_seg, log_label):
    """Run one optimisation step on a segmentation batch.

    Moves the inputs to ``self.rank``, applies ``self.random_crop`` to image
    and target together, and runs a forward pass of ``self.model_ddp`` in
    segmentation mode inside ``autocast``. It then performs a scaled
    backward pass, an optimizer step, a scaler update and a gradient reset
    through ``self.scaler``.

    Args:
        true_img: Input image batch. NOTE(review): shape/dtype are not
            established here; presumably whatever ``self.model_ddp``
            expects. Confirm against the data loader.
        true_seg: Ground-truth segmentation paired with ``true_img``.
        log_label: Not referenced in this body. NOTE(review): presumably a
            prefix for loss/image logging that followed this code in the
            original source. Confirm before removing.
    """
    # self.rank is used directly as the target device (presumably the local
    # GPU index of this DDP process). non_blocking=True allows the copy to
    # overlap with compute when the source tensor is pinned.
    true_img = true_img.to(self.rank, non_blocking=True)
    true_seg = true_seg.to(self.rank, non_blocking=True)

    # Image and target go through the same crop call, presumably so they
    # stay spatially aligned.
    true_img, true_seg = self.random_crop(true_img, true_seg)

    # Mixed precision is enabled unless args.disable_mixed_precision is set.
    with autocast(enabled=not self.args.disable_mixed_precision):
        # Segmentation pass of the model; only its first output is used.
        pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
        loss = segmentation_loss(pred_seg, true_seg)

    # Backward/step run outside the autocast region. self.scaler presumably
    # is a torch GradScaler: scale the loss, step through the scaler (which
    # may skip the step on inf/NaN grads), then adjust the scale factor.
    self.scaler.scale(loss).backward()
    self.scaler.step(self.optimizer)
    self.scaler.update()
    self.optimizer.zero_grad()