PyTorch系列 | correct += (predicted == labels).sum().item()的理解

各位小伙伴肯定看到过下面这段代码:

correct += (predicted == labels).sum().item()

这里面(predicted == labels)是布尔型,为什么可以接sum()呢?

我做了个测试,如果这里的predicted和labels是列表形式就会报错,如果是numpy的数组格式,会返回一个值,如果是tensor形式,就会返回一个张量。

举个例子:

import torch

a = torch.tensor([1,2,3])
b = torch.tensor([1,3,2])

print((a == b).sum())

上述代码的输出结果:

tensor(1)

如果将a和b改成numpy下的数组格式:

import numpy as np

a = np.array([1,2,3])
b = np.array([1,3,2])

print((a == b).sum())

上述代码的输出结果:

1

如果将a和b改成列表:

a = [1,2,3]
b = [1,3,2]

print((a == b).sum())

上述代码的输出结果:

Traceback (most recent call last):
  File "路径", line 4, in <module>
    print((a == b).sum())
AttributeError: 'bool' object has no attribute 'sum'

Process finished with exit code 1

Added:

.item()用于取出tensor中的值。

  • 48
    点赞
  • 94
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
这段代码的作用是在使用 PyTorch 定义的模型进行测试时,使用 `torch.no_grad()` 上下文管理器,以禁用梯度计算,从而提高代码效率。在测试过程中,我们不需要计算梯度,而且计算梯度会浪费计算资源,因此使用 `torch.no_grad()` 可以有效地提高测试速度。 `test_loader` 是一个 PyTorch 数据加载器,用于加载测试数据集。在每次迭代中,我们从加载器中获取一个批次的数据,这个批次包括 `images` 和 `labels` 两个变量。`images` 是一个张量,包含一个批次的测试图像数据,`labels` 是一个张量,包含相应的测试标签。 在使用模型进行测试时,我们将测试图像数据 `images` 作为模型的输入,并使用 `model` 对其进行前向传播。`outputs` 是模型的输出,它是一个张量,包含每个测试样本对应的预测结果。我们使用 `torch.max()` 函数找到每个测试样本的预测类别,并将其存储在 `predicted` 变量中。 我们通过比较每个测试样本的预测类别和真实类别来计算测试集的准确率。`total` 变量存储测试集的总样本数,`correct` 变量存储被正确预测的样本数。在每次迭代中,我们将当前批次的样本数量 `labels.size(0)` 添加到 `total` 中,将被正确预测的样本数量 `(predicted == labels).sum().item()` 添加到 `correct` 中。 最终,我们可以使用 `total` 和 `correct` 计算测试集的准确率。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值