目录
1、安装
官网地址:Welcome to TorchMetrics — PyTorch-Metrics 1.0.1 documentation
GitHub:Torchmetrics
pip install torchmetrics -i https://pypi.tuna.tsinghua.edu.cn/simple
2、基本流程介绍
在训练时我们都是使用 batch_size 批次训练,对于TorchMetrics也是一样的,在一个批次前向传递完成后将目标值Y和预测值Y_pre传递给torchmetrics的评价指标对象,评价指标对象会计算该批次评价指标并保存它(在其内部被称为state)。
当所有的批次完成时(也就是训练的一个Epoch完成),我们就可以从评价指标对象返回最终结果(这是对所有批计算的结果)。这里的每个度量对象都是从metric类继承,它包含了4个关键方法:
- metric.forward(pred, target):更新度量状态并返回当前批次上计算的度量结果。如果您愿意,也可以使用metric(pred, target),没有区别。
- metric.update(pred,target) :与forward相同,但是不会返回计算结果,相当于是只将结果存入了state。如果不需要在当前批处理上计算出的度量结果,则优先使用这个方法,因为他不计算最终结果速度会很快。
- metric.compute():返回在所有批次上计算的最终结果。也就是说其实forward相当于是update+compute。
- metric.reset():重置状态,以便为下一个验证阶段做好准备。
也就是说:在我们训练的当前批次,获得了模型的输出后可以forward或update(建议使用update)。在批次完成后,调用compute以获取最终结果。最后,在验证轮次(Epoch)或者启用新的轮次进行训练时您调用reset重置状态指标。
例如:
import torch
import torchmetrics
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model().to(device)
metric = torchmetrics.Accuracy().to(device)
for batch_idx, (data, target) in enumerate(val_dataloader):
data, target = data.to(device), target.to(device)
outp