pytorch训练cifar10数据集查看各个种类图片的准确率

以vgg19为例

import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn,optim
#from lenet5 import Lenet5
from vgg import VGG19
from vgg import VGG34
 
def main():
    batchsz = 32
 
    cifar_train = datasets.CIFAR10('dataset/', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406],
        					 std=[0.229,0.224,0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('dataset/', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406],
        					 std=[0.229,0.224,0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
 
    x, label = iter(cifar_train).next()
    print('x:', x.shape, 'label:', label.shape)
 
    device = torch.device('cuda')
    model = VGG19().to(device)
 
    criton = nn.CrossEntropyLoss().to(device)  #包含了softmax
    optimizer = optim.Adam(model.parameters(),lr=1e-3)
    print(model)
    
    if os.path.exists('model.pkl'):
        model.load_state_dict(torch.load('model.pkl'))
        print('model loaded from model.pkl')
 
    
    classes = ('plane', 'car', 'bird', 'cat',
        'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    N_CLASSES = 10
    class_correct = list(0. for i in range(N_CLASSES))
    class_total = list(0. for i in range(N_CLASSES))

    model.eval()
    #test
    total_correct = 0
    total_num = 0
    for x, label in cifar_test:
        x,label = x.to(device) ,label.to(device)
        logits = model(x)
        pred = logits.argmax(dim=1)
        total_correct += torch.eq(pred,label).float().sum().item()
        total_num += x.size(0)  #即batch_size
      
        c = (pred == label).squeeze()
        
        for i in range(len(label)):
            _label = label[i]
            class_correct[_label] += c[i].item()
            class_total[_label] += 1
        
    acc = total_correct / total_num
    print('acc: ',acc)

    for i in range(N_CLASSES):
        print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))

 
 
if __name__ == '__main__':
    main()

 

而训练了vgg34则会是这样

 

 

参考:

https://blog.csdn.net/Arctic_Beacon/article/details/85068188

  • 2
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值