Pytorch 多分类结果测试

在模型训练过程中需要对当前的效果进行验证,或者训练结束后需要在测试集上对模型进行测试。比如多分类问题,网络的前向传播的结果是一个概率值Tensor,如果是一个10分类问题,并且batch=4,结果是一个4*10的Tensor,Tensor的每一行表示某张图片分别在10分类下的预测概率值。

Pytorch中的argmax()函数可以返回Tensor中每一行最大值的索引,torch.eq()函数可以比较两个Tensor对应位置处的值是否相等,返回一个Tensor的结果,0表示不相等,1表示相等。

我们可以使用argmax()与torch.eq()进行多分类问题准确率的计算:

下面的代码是在训练完成后,在测试集上进行测试准确率的代码片段:

相信代码可参考https://blog.csdn.net/weicao1990/article/details/98754647

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
你可以按照以下步骤保存测试集结果: 1. 加载测试集数据 首先,你需要加载测试集数据。你可以使用 PyTorch 的 DataLoader 类和自定义的数据集类来完成这个任务。你需要确保你的测试集数据和训练集数据的预处理方式相同。 2. 加载模型 然后,你需要加载你的训练好的模型。你可以使用 PyTorch 的 torch.load() 函数来加载模型。 3. 开始测试 接下来,你需要使用测试集数据对模型进行测试。你可以使用 PyTorch 的 torch.no_grad() 上下文管理器来关闭梯度计算。对于每个测试数据,你需要使用模型进行预测,并将预测结果保存到一个列表中。 4. 保存测试结果 最后,你可以将测试结果保存到一个文件中。你可以使用 Python 的 csv 模块来将结果保存到 CSV 文件中。你可以将每个测试数据的预测结果和真实标签写入文件中。 下面是一个示例代码,展示了如何保存测试集结果: ``` import csv import torch from torchvision import transforms from dataset import CustomDataset # 自定义的数据集类 # 加载测试集数据 test_data = CustomDataset('test', transform=transforms.ToTensor()) test_loader = torch.utils.data.DataLoader(test_data, batch_size=32) # 加载模型 model = torch.load('model.pth') # 开始测试 predictions = [] with torch.no_grad(): for inputs, _ in test_loader: outputs = model(inputs) _, preds = torch.max(outputs, 1) predictions.extend(preds.tolist()) # 保存测试结果 with open('test_results.csv', 'w', newline='') as csvfile: writer = csv.writer(csvfile) writer.writerow(['ImageName', 'Label']) for i, (_, label) in enumerate(test_data): writer.writerow([f'image_{i}.jpg', predictions[i]]) ``` 这个示例代码假设你已经实现了一个名为 CustomDataset 的自定义数据集类。你需要将 'test' 参数传递给 CustomDataset 类的构造函数来加载测试集数据。你还需要将一个名为 model.pth 的模型文件存放在当前目录中。最后,测试结果将被保存到一个名为 test_results.csv 的文件中。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

洪流之源

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值