Top1ACC
准确率 (accuracy): $\mathrm{Acc} = \dfrac{TP + TN}{TP + FP + TN + FN}$
即 Acc = 预测正确的样本数 / 样本总数
import torch
import major_config
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from major_dataset import LoadDataset
def evaluteTop1(model, loader):
    """Return the top-1 accuracy of ``model`` over every sample in ``loader``.

    The model is put into eval mode. Each batch's forward pass runs without
    gradient tracking, and a sample counts as correct when the arg-max over
    dim 1 of the model output equals its label.

    Args:
        model: module called as ``model(x)``; its output is arg-maxed over dim 1.
        loader: iterable of ``(x, y)`` batches that also exposes ``.dataset``.

    Returns:
        float: number of correct predictions divided by ``len(loader.dataset)``.
    """
    model.eval()
    n_samples = len(loader.dataset)
    hits = 0
    for inputs, labels in loader:
        with torch.no_grad():
            scores = model(inputs)
        # argmax is non-differentiable, so computing it outside no_grad is equivalent.
        hits += (scores.argmax(dim=1) == labels).sum().float().item()
    return hits / n_samples
def evaluteTop5(model, loader, k=5):
    """Return the top-k (top-5 by default) accuracy of ``model`` over ``loader``.

    A sample counts as correct when its label is among the ``k`` highest-scoring
    class indices (dim 1 of the model output). ``k`` is clamped to the number of
    output classes, so a model with fewer than ``k`` classes does not make
    ``topk`` raise. In that case every label is within the top-k by definition.

    Args:
        model: module called as ``model(x)``; its output is ranked over dim 1.
        loader: iterable of ``(x, y)`` batches that also exposes ``.dataset``.
        k: how many top-scoring classes to accept (default 5).

    Returns:
        float: number of correct predictions divided by ``len(loader.dataset)``.
        Raises ZeroDivisionError if the dataset is empty.
    """
    model.eval()
    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():
        for x, y in loader:
            logits = model(x)
            # topk raises if k exceeds the size of dim 1, so clamp it.
            maxk = min(k, logits.size(1))
            _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
            # pred is (batch, maxk) with distinct indices per row, so comparing
            # against the labels as a column matches at most once per sample.
            correct += torch.eq(pred, y.view(-1, 1)).sum().item()
    return correct / total
if __name__ == "__main__":
    # Evaluation preprocessing: resize to 32x32, convert to tensor, then
    # normalize with the mean/std configured in major_config.
    test_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(major_config.norm_mean, major_config.norm_std),
    ])
    test_data = LoadDataset(data_dir=major_config.test_image, transform=test_transform)
    # Accuracy does not depend on sample order; shuffling during evaluation
    # only makes runs nondeterministic.
    test_loader = DataLoader(dataset=test_data, batch_size=10, shuffle=False)

    net = major_config.model
    path_model_state_dict = major_config.path_test_model
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a
    # CPU-only machine; load_state_dict then copies the values into net's
    # existing parameters.
    state_dict = torch.load(path_model_state_dict, map_location="cpu")
    net.load_state_dict(state_dict)

    res_top1 = evaluteTop1(net, test_loader)
    print(res_top1)
    res_top5 = evaluteTop5(net, test_loader)
    print(res_top5)