import os import json import torch import torch.nn as nn from torchvision import transforms, datasets from torch.optim import Adam from tqdm import tqdm from model import resnet34 # 假设你的ResNet模型定义在model.py中 from model import resnet50 def validate(model, dataloader, device): # 将模型设置为评估模式 model.eval() # 定义总体准确率的累积变量 total_correct = 0 total_samples = 0 # 定义类别准确率字典 class_correct = {i: 0 for i in range(len(dataloader.dataset.classes))} class_total = {i: 0 for i in range(len(dataloader.dataset.classes))} with torch.no_grad(): for inputs, labels in dataloader: inputs, labels = inputs.to(device), labels.to(device) outputs = model(inputs) _, predicted = torch.max(outputs, 1) c = (predicted == labels).squeeze() total_correct += c.sum().item() total_samples += labels.size(0) for i in range(len(labels)):
05-06
1538

08-01
324

03-30
3327
