engine.py详解

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import math
import os
import sys
from typing import Iterable   #类型提示相关的模块,用于增强代码的可读性和可维护性。

import torch

import util.misc as utils    #自定义的辅助函数模块
from datasets.coco_eval import CocoEvaluator
from datasets.panoptic_eval import PanopticEvaluator


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,#模型、损失函数
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0):#可选参数,用于梯度裁剪的最大范数。
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")#用于记录训练过程中的指标和信息。
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))#添加两个指标:学习率(lr)和分类错误率(class_error)
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)#用于打印训练过程中的提示信息
    print_freq = 10#表示每隔多少批次打印一次训练信息
    #首先,使用 metric_logger.log_every 方法来遍历数据加载器,并设置打印频率和提示信息。
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)#在每个批次的开始,将输入数据和目标值移动到指定设备上。
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        #,调用模型进行前向传播,得到预测输出 outputs。
        outputs = model(samples)
        #criterion.weight_dict 是一个字典,包含各个损失项的权重。
        loss_dict = criterion(outputs, targets)#计算损失值。criterion 是一个损失函数,它接受 outputs 和 targets 作为输入,返回一个字典 loss_dict,其中包含不同损失项的损失值。
        weight_dict = criterion.weight_dict#通过遍历 loss_dict 字典的键,并根据 weight_dict 权重相乘,计算出总的损失值 losses。
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
        #接下来,为了记录和打印目的,对损失值进行一些处理。
        # reduce losses over all GPUs for logging purposes
        #首先,使用 utils.reduce_dict 函数将损失字典 loss_dict 中的损失值在多个 GPU 上进行求和。
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        #loss_dict_reduced_unscaled 字典中的键是原始损失字典 loss_dict_reduced 的键加上 '_unscaled' 后缀,对应的值是 loss_dict_reduced 中对应键的值
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        #loss_dict_reduced_scaled 字典中的键是 loss_dict_reduced 中的键,对应的值是 loss_dict_reduced 中对应键的值乘以 weight_dict 中对应键的值。
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        #最后,计算经过缩放后的总损失值 losses_reduced_scaled,并调用 item() 方法将其转换为标量值 loss_value。
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        #在训练过程中,会使用 metric_logger 对象记录和打印损失值。
        loss_value = losses_reduced_scaled.item()
        #这段代码是在每个批次训练结束后检查损失值是否为有限数。如果损失值不是有限数,则会输出一个错误信息,打印当前损失值和损失字典 loss_dict_reduced,并退出程序。
        #这是一个异常情况,通常表示模型训练出现了问题或者数据质量存在问题。停止训练并打印错误信息可以帮助用户及时发现问题,并解决问题。
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)
        '''首先,调用 optimizer.zero_grad() 方法将模型参数的梯度置零,以便进行新一轮的反向传播。

        然后,调用 losses.backward() 方法进行反向传播,计算参数的梯度。
        
        接下来,如果 max_norm 大于 0,则使用 torch.nn.utils.clip_grad_norm_ 方法对参数的梯度进行裁剪,以防止梯度爆炸的问题。该方法会将参数的梯度按照指定的最大范数进行缩放。
        
        最后,调用 optimizer.step() 方法来更新模型的参数,根据计算得到的梯度与优化算法进行参数更新。
        
        通过这些步骤,完成了一轮训练的反向传播和参数更新操作。接下来,会回到训练循环的开始,继续下一个批次的训练。'''
        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()
        '''首先,调用 metric_logger.update 方法,将损失值 loss_value、经过缩放后的损失字典 loss_dict_reduced_scaled、未经缩放的损失字典 loss_dict_reduced_unscaled 传递给 metric_logger 对象,以便记录和打印相应的指标值。

        然后,调用 metric_logger.update 方法,将分类错误的损失值 loss_dict_reduced['class_error'] 传递给 metric_logger 对象,以便记录和打印分类错误的指标值。
        
        最后,调用 metric_logger.update 方法,将当前优化器的学习率 optimizer.param_groups[0]["lr"] 传递给 metric_logger 对象,以便记录和打印学习率的指标值。
        
        通过更新 metric_logger 对象中的指标值,可以在整个训练过程中记录并打印各种指标值,有助于了解模型的训练情况,并进行优化和调试。'''
        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    '''这段代码是在多个进程之间同步统计信息,并输出平均统计结果。

        首先,调用 `metric_logger.synchronize_between_processes()` 方法,该方法会在多个进程之间同步统计信息,确保每个进程的统计结果一致。
        
        然后,通过打印 `metric_logger` 对象,输出平均统计结果。这里使用 `print("Averaged stats:", metric_logger)` 来将平均统计结果打印出来。
        
        最后,使用字典推导式遍历 `metric_logger.meters` 中的每个统计指标,将全局平均值 `meter.global_avg` 存入一个字典中,并将该字典作为函数的返回值。
        
        通过同步统计信息并输出平均结果,可以得到整个训练过程中各种指标的平均值,便于进行模型性能评估和比较。'''
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
    model.eval()
    criterion.eval()
    '''首先,它将模型和损失函数设置为评估模式:
    model.eval()
    criterion.eval()
    然后,它初始化一个MetricLogger对象来记录评估指标,并添加一个名为class_error的指标来跟踪分类错误率。 
    接下来,它定义了一个变量iou_types,该变量包含在后处理器中可以使用的评估指标类型('segm'和'bbox')。  
    然后,它创建了一个CocoEvaluator对象,用于计算COCO评估指标。base_ds是基本数据集,iou_types是评估指标类型(即分割和边界框),用于初始化CocoEvaluator对象。 
    最后,它返回了一个CocoEvaluator对象。'''
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)
    #注释部分可能是为了设置特定的IoU阈值来计算COCO评估指标
    # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]

    panoptic_evaluator = None
    if 'panoptic' in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )
    '''这段代码是一个评估函数中的循环。它遍历数据加载器(`data_loader`)中的样本和目标,并对它们进行评估。

    在每次循环中,代码将输入数据(`samples`)和目标数据(`targets`)移动到指定的设备上。
    
    接下来,使用模型对输入数据进行前向传播,得到输出结果(`outputs`)。
    
    然后,使用损失函数(`criterion`)计算模型的损失值。`loss_dict`是一个包含各种损失组成部分的字典。
    
    通过`utils.reduce_dict`将损失字典的值在所有GPU上进行合并,以便记录和日志输出。`weight_dict`是一个权重字典,用于加权损失值。
    
    将合并后的损失字典乘以权重字典中的对应权重,并创建一个未加权的损失字典。
    
    通过`metric_logger.update`更新评估指标记录器(`metric_logger`)。其中包括总损失(`loss`)和减小比例后的各个损失部分。
    
    还更新了分类错误率(`class_error`)指标。
    
    最后,提取目标的原始尺寸,并使用后处理器(`postprocessors`)中的'bbox'评估指标类型对输出结果进行后处理,得到最终的评估结果(`results`)。'''
    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)
        if 'segm' in postprocessors.keys():
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
        '''这段代码将每个目标(`target`)和输出结果(`output`)进行配对,并将它们组合成一个字典`res`,其中键是目标的图像ID(通过`target['image_id'].item()`获取),值是对应的输出结果。
        如果存在`coco_evaluator`(COCO评估器),则使用`coco_evaluator.update(res)`方法将`res`传递给评估器,以更新评估结果。这样可以在评估过程中逐步计算COCO评估指标,如平均精度(mAP)等。'''
        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

        if panoptic_evaluator is not None:
            res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
            for i, target in enumerate(targets):
                image_id = target["image_id"].item()
                file_name = f"{image_id:012d}.png"
                res_pano[i]["image_id"] = image_id
                res_pano[i]["file_name"] = file_name

            panoptic_evaluator.update(res_pano)
    '''这段代码用于在多个进程之间收集统计数据。
    首先,调用`metric_logger.synchronize_between_processes()`方法来确保所有进程之间的同步。这将使每个进程的指标记录器(`metric_logger`)中的统计信息保持一致。
    然后,通过访问`metric_logger`对象来打印平均统计信息,即输出平均统计信息的日志。
    如果存在`coco_evaluator`(COCO评估器),则调用`coco_evaluator.synchronize_between_processes()`方法来确保所有进程之间的同步。这是为了在多个进程中使用COCO评估器时,将评估结果进行合并和同步。'''
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    if coco_evaluator is not None:
        #累积所有图像的预测结果。这将将每个图像的预测结果添加到评估器中
        coco_evaluator.accumulate()
        #来生成评估汇总信息。这将计算并打印出各种COCO评估指标的值,如平均精度(mAP)等。
        coco_evaluator.summarize()
    panoptic_res = None
    if panoptic_evaluator is not None:
        panoptic_res = panoptic_evaluator.summarize()
    '''这段代码通过遍历指标记录器(`metric_logger`)中的计量器(`meters`),创建一个字典`stats`,其中包含每个计量器的全局平均值。
    对于每个计量器,使用`meter.global_avg`获取其全局平均值,并将其存储在与计量器键匹配的`stats`字典中。最终得到的`stats`字典将包含每个计量器的全局平均值。'''
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    '''这段代码用于将COCO评估器的bounding box(bbox)指标统计信息添加到`stats`字典中。
    首先,检查`postprocessors`字典的键是否包含'bbox',即检查后处理器中是否包含针对bounding box的评估指标。
    如果存在COCO评估器且后处理器中包含bounding box的评估指标,则通过访问`coco_evaluator.coco_eval['bbox'].stats.tolist()`获取bounding box的统计信息,
    并存储在`stats`字典中,以键'coco_eval_bbox'为标识。该统计信息可能包括精度、召回率、F1分数等指标。'''
    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
        if 'segm' in postprocessors.keys():
            stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
    if panoptic_res is not None:
        stats['PQ_all'] = panoptic_res["All"]
        stats['PQ_th'] = panoptic_res["Things"]
        stats['PQ_st'] = panoptic_res["Stuff"]
    #首先,返回包含所有指标的字典stats。然后,如果存在COCO评估器,则同时返回该评估器对象coco_evaluator,否则只返回字典stats
    return stats, coco_evaluator

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

毕竟是shy哥

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值