03 pytorch 验证指标工具 torchmetrics

安装torchmetircs

pip install torchmetrics
conda install -c conda-forge torchmetrics

官方源码

Welcome to TorchMetrics — PyTorch-Metrics 1.3.1 documentation

https://lightning.ai/docs/torchmetrics/stable/pages/quickstart.html

基本函数

基本流程

在训练时我们都是使用微批次训练,对于TorchMetrics也是一样的,在一个批次前向传递完成后将目标值Y和预测值Y_PRED传递给torchmetrics的度量对象,度量对象会计算批次指标并保存它(在其内部被称为state)。

当所有的批次完成时(也就是训练的一个Epoch完成),我们就可以从度量对象返回最终结果(这是对所有批计算的结果)。这里的每个度量对象都是从metric类继承,它包含了4个关键方法

metrics.forward(pred,target)

更新度量状态并返回当前批次上计算的度量结果

metric.update(pred,target)

与forward相同,但是不会返回计算结果,相当于是只将结果存入了state。 如果不需要在当前批处理上计算出的度量结果,则优先使用这个方法,因为他不计算最终结果速度会很快

metric.compute()

 返回在所有批次上计算的最终结果。

也就是说其实forward相当于是update+compute。

metric.reset()

重置状态,以便为下一个验证阶段做好准备。

note: 在训练的当前批次,获得了模型的输出后可以forward或update(建议使用update)。 在批次完成后,调用compute以获取最终结果。最后,在验证轮次(Epoch)或者启用新的轮次进行训练时您调用reset重置状态指标。

使用指南

单个指标

API接口调用

import torch
# import our library
import torchmetrics

# initialize metric
metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# Resetting internal state such that metric ready for new data
metric.reset()

深度学习训练模型时,可参考下述流程

import torch 
import torchmetrics 

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
model = YourModel().to(device) 
metric = torchmetrics.Accuracy() 

for batch_idx, (data, target) in enumerate(val_dataloader): 
    data, target = data.to(device), target.to(device) 
    output = model(data) 
    # metric on current batch 
    batch_acc = metric.update(preds, target) 
    print(f"Accuracy on batch {i}: {batch_acc}") 

# metric on all batches using custom accumulation 
val_acc = metric.compute() 
print(f"Accuracy on all data: {val_acc}") 

# Resetting internal state such that metric is ready for new data 
metric.reset()

多个指标

Torchmetrics提供了MetricCollection可以将多个指标包装成单个可调用类,其接口与上面的基本用法相同。

import torch 
from torchmetrics import MetricCollection, Accuracy, Precision, Recall 

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
model = YourModel().to(device) 
# collection of all validation metrics 
metric_collection = MetricCollection({ 
    'acc': Accuracy(), 
    'prec': Precision(num_classes=10, average='macro'), 
    'rec': Recall(num_classes=10, average='macro') 
}) 

for batch_idx, (data, target) in enumerate(val_dataloader): 
    data, target = data.to(device), target.to(device) 
    output = model(data) 
    batch_metrics = metric_collection.forward(preds, target) 
    print(f"Metrics on batch {i}: {batch_metrics}") 

val_metrics = metric_collection.compute() 
print(f"Metrics on all data: {val_metrics}") 
metric.reset()

创建自定义矩阵评估指标

我们只需要继承 Metric 类并且实现 update 和 computing 方法就可以了,另外就是需要在类初始化的时候使用self.add_state(state_name, default)来初始化我们的对象。

  1. __init__ :self.add_state 用于度量计算所需的每个内部状态
  2. update :更新度量状态所需的所有逻辑
  3. compute:实现计算方法,最终的度量计算
import torch 
import torchmetrics 

class MyAccuracy(Metric): 
    def __init__(self, delta): 
        super().__init__() 
        # to count the correct predictions 
        self.add_state('corrects', default=torch.tensor(0), dist_reduce_fx="sum") 
        # to count the total predictions 
        self.add_state('total', default=torch.tensor(0), dist_reduce_fx="sum") 

    def update(self, preds, target): 
        # update correct predictions count 
        self.correct += torch.sum(preds == target) 
        # update total count, numel() returns the total number of elements  
        self.total += target.numel() 

    def compute(self): 
        # final computation 
        return self.correct / self.total

代码详解:官方解释更为清楚: https://lightning.ai/docs/torchmetrics/stable/pages/implement.html#implement

  • The dist_reduce_fx argument to add_state is used to specify how the metric states should be reduced between batches in distributed settings. In this case we use "sum" to sum the metric states across batches. A couple of build in options are available: "sum""mean""cat""min" or "max", but a custom reduction is also supported.

  • In update we do not return anything but instead update the metric states in-place.

  • In compute when running in distributed mode, the states would have been synced before the compute method is called. Thus self.correct and self.total will contain the sum of the metric states across all processes.

参考

使用Torchmetrics快速进行验证指标的计算 - 知乎

  • 41
    点赞
  • 47
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值