Top1ACC

这段代码展示了如何在PyTorch中计算模型在测试集上的Top1和Top5准确率。首先,定义了评估Top1和Top5准确率的函数,然后加载预处理的测试数据集,接着加载模型并进行评估。最后,打印出模型的Top1和Top5准确率。
摘要由CSDN通过智能技术生成

T o p 1 A C C Top1ACC Top1ACC

准确率(accuracy): (TP + TN )/( TP + FP + TN + FN)

Acc:所有预测正确的/所有

#  for major_test
import torch
import major_config
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from major_dataset import LoadDataset

def evaluteTop1(model, loader):
    model.eval()

    correct = 0
    total = len(loader.dataset)

    for x, y in loader:
        #x, y = x.to(major_config.device), y.to(major_config.device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
        # correct += torch.eq(pred, y).sum().item()
    return correct / total


def evaluteTop5(model, loader):
    model.eval()
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        #x, y = x.to(major_config.device), y.to(major_config.device)
        with torch.no_grad():
            logits = model(x)
            maxk = max((1, 5))
            y_resize = y.view(-1, 1)
            _, pred = logits.topk(maxk, 1, True, True)
            correct += torch.eq(pred, y_resize).sum().float().item()
    return correct / total

if __name__ == "__main__":
    # 1.加载测试数据
    # 1.1 预处理
    test_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(major_config.norm_mean, major_config.norm_std),
    ])
    # 1.2 数据加载
    test_data = LoadDataset(data_dir=major_config.test_image, transform=test_transform)
    test_loader = DataLoader(dataset=test_data, batch_size=10, shuffle=True)  # shuffle训练时打乱样本

    # 2.加载模型
    net = major_config.model  # 对应修改模型 net = se_resnet50(num_classes=5,pretrained=True)
    path_model_state_dict = major_config.path_test_model
    net.load_state_dict(torch.load(path_model_state_dict))

    # 3.评测
    res_top1 = evaluteTop1(net,test_loader)
    print(res_top1)
    res_top5 = evaluteTop5(net,test_loader)
    print(res_top5)
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值