使用Torchmetrics快速进行验证指标的计算

本文介绍了TorchMetrics,一个用于处理验证指标的Python库,支持多种常见指标,如Accuracy、Dice和F1Score。文章详细讲解了如何安装、使用和自定义指标,以及MetricCollection的便利性。通过实例展示了如何在训练和验证过程中轻松计算和重置指标。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TorchMetrics可以为我们提供一种简单、干净、高效的方式来处理验证指标。TorchMetrics提供了许多现成的指标实现,如Accuracy, Dice, F1 Score, Recall, MAE等等,几乎最常见的指标都可以在里面找到。torchmetrics目前已经包好了80+任务评价指标。

TorchMetrics安装也非常简单,只需要PyPI安装最新版本:

 pip install torchmetrics

基本流程介绍

在训练时我们都是使用微批次训练,对于TorchMetrics也是一样的,在一个批次前向传递完成后将目标值Y和预测值Y_PRED传递给torchmetrics的度量对象,度量对象会计算批次指标并保存它(在其内部被称为state)。

当所有的批次完成时(也就是训练的一个Epoch完成),我们就可以从度量对象返回最终结果(这是对所有批计算的结果)。这里的每个度量对象都是从metric类继承,它包含了4个关键方法:

  • metric.forward(pred,target) - 更新度量状态并返回当前批次上计算的度量结果。如果您愿意,也可以使用metric(pred, target),没有区别。
  • metric.update(pred,target) - 与forward相同,但是不会返回计算结果,相当于是只将结果存入了state。如果不需要在当前批处理上计算出的度量结果,则优先使用这个方法,因为他不计算最终结果速度会很快。
  • metric.compute() - 返回在所有批次上计算的最终结果。也就是说其实forward相当于是update+compute。
  • metric.reset() - 重置状态,以便为下一个验证阶段做好准备。

也就是说:在我们训练的当前批次,获得了模型的输出后可以forward或update(建议使用update)。在批次完成后,调用compute以获取最终结果。最后,在验证轮次(Epoch)或者启用新的轮次进行训练时您调用reset重置状态指标

例如下面的代码:

 import torch
 import torchmetrics
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = YourModel().to(device)
 metric = torchmetrics.Accuracy()
 
 for batch_idx, (data, target) in enumerate(val_dataloader):
     data, target = data.to(device), target.to(device)
     output = model(data)
     # metric on current batch
     batch_acc = metric.update(preds, target)
     print(f"Accuracy on batch {i}: {batch_acc}")
 
 # metric on all batches using custom accumulation
 val_acc = metric.compute()
 print(f"Accuracy on all data: {val_acc}")
 
 # Resetting internal state such that metric is ready for new data
 metric.reset()

MetricCollection

在上面的示例中,使用了单个指标进行计算,但一般情况下可能会包含多个指标。Torchmetrics提供了MetricCollection可以将多个指标包装成单个可调用类,其接口与上面的基本用法相同。这样我们就无需单独处理每个指标。

代码如下:

 import torch
 from torchmetrics import MetricCollection, Accuracy, Precision, Recall
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = YourModel().to(device)
 # collection of all validation metrics
 metric_collection = MetricCollection({
     'acc': Accuracy(),
     'prec': Precision(num_classes=10, average='macro'),
     'rec': Recall(num_classes=10, average='macro')
 })
 
 for batch_idx, (data, target) in enumerate(val_dataloader):
     data, target = data.to(device), target.to(device)
     output = model(data)
     batch_metrics = metric_collection.forward(preds, target)
     print(f"Metrics on batch {i}: {batch_metrics}")
 
 val_metrics = metric_collection.compute()
 print(f"Metrics on all data: {val_metrics}")
 metric.reset()

也可以使用列表而不是字典,但是使用字典会更加清晰。

自定义指标

虽然Torchmetrics包含了很多常见的指标,但是有时我们还需要自己定义一些不常用的特定指标。我们只需要继承 Metric 类并且实现 updatecomputing 方法就可以了,另外就是需要在类初始化的时候使用self.add_state(state_name, default)来初始化我们的对象。

代码也很简单:

 import torch
 import torchmetrics
 
 class MyAccuracy(Metric):
     def __init__(self, delta):
         super().__init__()
         # to count the correct predictions
         self.add_state('corrects', default=torch.tensor(0))
         # to count the total predictions
         self.add_state('total', default=torch.tensor(0))
 
     def update(self, preds, target):
         # update correct predictions count
         self.correct += torch.sum(preds == target)
         # update total count, numel() returns the total number of elements 
         self.total += target.numel()
 
     def compute(self):
         # final computation
         return self.correct / self.total

总结

就是这样,Torchmetrics为我们指标计算提供了非常简单快速的处理方式,如果你想更多的了解它的用法,请参考官方文档:

https://avoid.overfit.cn/post/bdedfe4229e04da49049c4e7d56152d1

作者:Mattia Gatti

### DeepLabV3+ 模型精度评估指标 DeepLabV3+模型在多个标准数据集上展示了卓越的性能,特别是在PASCAL VOC 2012和Cityscapes数据集上。对于这些数据集,常用的精度评估指标主要包括交并比(Intersection over Union, IoU)、平均交并比(Mean Intersection over Union, mIoU),以及像素准确率(Pixel Accuracy)。具体表现如下: #### 1. 平均交并比 (mIoU) 这是衡量语义分割任务中最常用的一个评价标准之一。它计算的是预测结果与真实标签之间的重叠区域占两者总和的比例,再取各类别的平均值。DeepLabV3+在这项指标上有出色的表现,在PASCAL VOC 2012测试集中达到了89.0%的mIoU得分[^2]。 #### 2. 像素准确率 (Pixel Accuracy) 此指标表示被正确分类的像素数占全部像素总数的比例。虽然这是一个直观易懂的测量方法,但在多类别不平衡的情况下可能不够全面。尽管如此,高像素准确率仍然是模型良好泛化能力的一种体现。 #### 3. 类别平衡下的加权交并比 (Weighted IoU) 考虑到不同类别的样本数量可能存在较大差异,因此引入了权重因子来调整各个类别的贡献度,使得最终的结果更加公平合理。这种方法能够更好地反映模型对稀有类别的识别效果。 为了验证上述提到的各项性能指标,可以通过开源库如`torchmetrics`中的相应函数来进行快速便捷地计算。下面给出一段简单的Python代码片段展示如何利用该库完成这一操作: ```python import torch from torchmetrics import JaccardIndex, Accuracy # 初始化度量工具 iou_metric = JaccardIndex(task="multiclass", num_classes=NUM_CLASSES) accuracy_metric = Accuracy(task="multiclass", num_classes=NUM_CLASSES) # 计算IoU和Accuracy iou_value = iou_metric(preds, target).item() pixel_accuracy = accuracy_metric(preds, target).item() print(f"IOU Score: {iou_value:.4f}") print(f"Pixel Accuracy: {pixel_accuracy:.4f}") ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值