ResNet模型计算每个类别的准确率与总准确率

import os
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.optim import Adam
from tqdm import tqdm

from model import resnet34  # 假设你的ResNet模型定义在model.py中
from model import resnet50
def validate(model, dataloader, device):
    # 将模型设置为评估模式
    model.eval()

    # 定义总体准确率的累积变量
    total_correct = 0
    total_samples = 0

    # 定义类别准确率字典
    class_correct = {i: 0 for i in range(len(dataloader.dataset.classes))}
    class_total = {i: 0 for i in range(len(dataloader.dataset.classes))}

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            total_correct += c.sum().item()
            total_samples += labels.size(0)
            for i in range(len(labels)):
                label = labels[i].item()  # 将标签转换为整数类型
                class_correct[label] += c[i].item()
                class_total[label] += 1

    # 计算总体准确率
    overall_accuracy = total_correct / total_samples
    print("Overall accuracy: {:.2f}%".format(overall_accuracy * 100))

    # 输出每个类别的准确率
    for i in range(len(dataloader.dataset.classes)):
        if class_total[i] > 0:
            print('Accuracy of %5s : %2d %%' % (
                dataloader.dataset.classes[i], 100 * class_correct[i] / class_total[i]))
        else:
            print('Accuracy of %5s : N/A (no validation samples)' % (dataloader.dataset.classes[i]))

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using {} device.".format(device))

    data_transform = {
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])}

    image_path = "E:\\wafer_data\\wafer_27"  # 修改为你的数据集路径
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False)

    # 加载ResNet模型
    num_classes = len(validate_dataset.classes)  # 获取数据集中的类别数
    net = resnet50(num_classes=num_classes)  # 使用你的ResNet模型
    # 加载预训练权重
    net.load_state_dict(torch.load('resNet50.pth'))  # 修改为你的模型路径
    net.to(device)

    # 在验证集上验证每个类别的分类准确率
    print("Validation accuracy for each class:")
    validate(net, validate_loader, device)


if __name__ == '__main__':
    main()
  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值