import os
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.optim import Adam
from tqdm import tqdm
from model import resnet34  # assumes your ResNet models are defined in model.py
from model import resnet50


def validate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader``; print overall and per-class accuracy.

    Args:
        model: Classifier whose output is reduced with ``torch.max(outputs, 1)``
            to obtain the predicted class index for each input.
        dataloader: DataLoader whose ``dataset`` exposes a ``classes`` list
            (e.g. ``torchvision.datasets.ImageFolder``). Labels are assumed to
            be integer indices into that list.
        device: Device that inputs and labels are moved to before inference.

    Returns:
        ``(overall_accuracy, per_class_accuracy)``. ``overall_accuracy`` is a
        fraction in [0, 1], or ``None`` if the loader yielded no samples.
        ``per_class_accuracy`` maps each class name to its accuracy fraction,
        or to ``None`` for classes with no validation samples. (Previously the
        function returned ``None``; callers that ignore the result are
        unaffected.)
    """
    # Switch layers such as dropout / batch norm to inference behaviour.
    model.eval()

    class_names = dataloader.dataset.classes
    total_correct = 0
    total_samples = 0
    # Per-class counters, keyed by class index.
    class_correct = {i: 0 for i in range(len(class_names))}
    class_total = {i: 0 for i in range(len(class_names))}

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            # No .squeeze() here: a final batch of size 1 would collapse to a
            # 0-dim tensor, and indexing it per sample raises IndexError.
            correct = predicted == labels
            total_correct += correct.sum().item()
            total_samples += labels.size(0)
            # Convert the whole batch to Python once instead of calling
            # .item() per sample (each .item() on a CUDA tensor syncs).
            for label, hit in zip(labels.tolist(), correct.tolist()):
                class_correct[label] += int(hit)
                class_total[label] += 1

    # Overall accuracy (guard against an empty loader).
    if total_samples == 0:
        overall_accuracy = None
        print("Overall accuracy: N/A (no validation samples)")
    else:
        overall_accuracy = total_correct / total_samples
        print("Overall accuracy: {:.2f}%".format(overall_accuracy * 100))

    # Per-class accuracy; classes without samples are reported as N/A.
    per_class_accuracy = {}
    for i, name in enumerate(class_names):
        if class_total[i] > 0:
            accuracy = class_correct[i] / class_total[i]
            # Two decimals, matching the overall line (%2d used to truncate).
            print('Accuracy of %5s : %5.2f %%' % (name, 100 * accuracy))
        else:
            accuracy = None
            print('Accuracy of %5s : N/A (no validation samples)' % (name))
        per_class_accuracy[name] = accuracy

    return overall_accuracy, per_class_accuracy


def main(image_path="E:\\wafer_data\\wafer_27", weights_path='resNet50.pth',
         batch_size=4):
    """Load the validation split and a trained ResNet-50, then report accuracy.

    Args:
        image_path: Dataset root; images are read from its ``val`` subfolder
            using ``datasets.ImageFolder``. Change the default to your dataset
            location, or pass it explicitly.
        weights_path: Path to the saved ``state_dict`` for ``resnet50``.
            Change the default to your model path, or pass it explicitly.
        batch_size: Batch size of the validation DataLoader.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using {} device.".format(device))

    data_transform = {
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # Commonly used ImageNet mean/std; presumably matches the
            # normalisation used at training time — verify.
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
    }

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False)

    # Build the model with one output per class found in the dataset.
    num_classes = len(validate_dataset.classes)
    net = resnet50(num_classes=num_classes)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net.to(device)

    # Report per-class and overall accuracy on the validation set.
    print("Validation accuracy for each class:")
    validate(net, validate_loader, device)


if __name__ == '__main__':
    main()
ResNet模型计算每个类别的准确率与总准确率
最新推荐文章于 2024-07-12 19:06:42 发布