AlexNet计算每个类别的精确率、召回率和 F1-Score代码

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet  # 假设你的模型定义在model.py中


def validate(model, dataloader, device):
    # 将模型设置为评估模式
    model.eval()

    # 定义总体准确率的累积变量
    total_correct = 0
    total_samples = 0

    # 定义类别准确率字典
    class_correct = {i: 0 for i in range(len(dataloader.dataset.classes))}
    class_total = {i: 0 for i in range(len(dataloader.dataset.classes))}
    class_precision = {i: 0 for i in range(len(dataloader.dataset.classes))}
    class_recall = {i: 0 for i in range(len(dataloader.dataset.classes))}
    class_f1 = {i: 0 for i in range(len(dataloader.dataset.classes))}

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            total_correct += c.sum().item()
            total_samples += labels.size(0)
            for i in range(len(labels)):
                label = labels[i].item()  # 将标签转换为整数类型
                class_correct[label] += c[i].item()
                class_total[label] += 1

        # 计算每个类别的精确率、召回率和 F1-Score
        for i in range(len(dataloader.dataset.classes)):
            if class_total[i] > 0:
                class_precision[i] = class_correct[i] / class_total[i]
                class_recall[i] = class_correct[i] / len(dataloader.dataset.targets)
                class_f1[i] = 2 * (class_precision[i] * class_recall[i]) / (class_precision[i] + class_recall[i])

    # 计算总体精确率、召回率和 F1-Score
    overall_precision = sum(class_precision.values()) / len(class_precision)
    overall_recall = sum(class_recall.values()) / len(class_recall)
    overall_f1 = sum(class_f1.values()) / len(class_f1)
    overall_accuracy = total_correct / total_samples

    print("Overall accuracy: {:.2f}%".format(overall_accuracy * 100))
    print("Overall precision: {:.2f}".format(overall_precision))
    print("Overall recall: {:.2f}".format(overall_recall))
    print("Overall F1-Score: {:.2f}".format(overall_f1))

    # 输出每个类别的精确率、召回率和 F1-Score
    for i in range(len(dataloader.dataset.classes)):
        print('Class: %5s Precision: %.2f%% Recall: %.2f%% F1-Score: %.2f' %
              (dataloader.dataset.classes[i], class_precision[i] * 100,
               class_recall[i] * 100, class_f1[i] * 100))





def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using {} device.".format(device))

    data_transform = {
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    image_path = "E:\\wafer_data\\wafer_27"  # 修改为你的数据集路径
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False)

    # 加载模型
    net = AlexNet(num_classes=8, init_weights=False)  # 注意:此处要设置为False,因为我们将加载预训练权重
    net.load_state_dict(torch.load('AlexNet.pth'))  # 修改为你的模型路径
    net.to(device)

    # 在验证集上验证每个类别的分类准确率
    print("Validation accuracy for each class:")
    validate(net, validate_loader, device)


if __name__ == '__main__':
    main()
  • 8
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值