#测试集 testset = torchvision.datasets.CIFAR10(root='./datacifar',train=False,download=True,transform=transform) #数据加载 testloader = torch.utils.data.DataLoader(testset,batch_size=10000,shuffle=True,num_workers=0) # 迭代测试集的数据 test_data_iter = iter(testloader) test_image,test_label = next(test_data_iter) accuracy = (test_label == predict_y).sum().item() / test_label.size(0)
这里的test_label.size(0)是什么意思呢?首先test_label是来自上面的测试集的标签,是一个Tensor数据类型,test_label.size()方法返回一个元组,表示张量的尺寸。对于标签张量,通常有两个维度:批处理大小(batch size)和类别数量(number of classes)
test_label.size(0)返回的是第一维的大小,即批处理大小,也就是当前批次中样本的数量
test_label.size(1)返回的是第二维的大小,即类别数量
具体参考原博主传送门:如何理解 labels.size(0) ? - 茴香豆的茴 - 博客园 (cnblogs.com)