ResNet模型计算每个类别的准确率与总准确率

import os
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.optim import Adam
from tqdm import tqdm

from model import resnet34  # 假设你的ResNet模型定义在model.py中
from model import resnet50
def validate(model, dataloader, device):
    # 将模型设置为评估模式
    model.eval()

    # 定义总体准确率的累积变量
    total_correct = 0
    total_samples = 0

    # 定义类别准确率字典
    class_correct = {i: 0 for i in range(len(dataloader.dataset.classes))}
    class_total = {i: 0 for i in range(len(dataloader.dataset.classes))}

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            total_correct += c.sum().item()
            total_samples += labels.size(0)
            for i in range(len(labels)):
  
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值