分类网络中的top1和top5

import torchvision
import torch
import torch.utils.data.dataloader as dataloader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
#  加载数据和处理数据
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)

train_loader = dataloader.DataLoader(dataset=train_data, shuffle=False, batch_size=4, num_workers=2)
test_loader = dataloader.DataLoader(dataset=test_data, shuffle=True, batch_size=4, num_workers=2)


#  定义网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__(),
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(5 * 5 * 16, 120),
            nn.Linear(120, 84),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(-1, 5 * 5 * 16)
        x = self.layer3(x)
        return x


#  实例化
cnn = CNN()

#  损失函数和优化
loss_f = nn.CrossEntropyLoss()
optim = optim.SGD(params=cnn.parameters(), lr=0.001, momentum=0.9)


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


#  准确率
def accuracy(output, target, topk=(1,)):
    maxk = max(topk) #topk=(1,)取top1准确率,topk=(1,5)取top1和top5准确率
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True) # topk参数中,maxk取得是top1准确率,dim=1是按行取值, largest=1是取最大值
    pred = pred.t()  # 转置
    correct = pred.eq(target.view(1, -1).expand_as(pred))  # 比较是否相等

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


#  train
for i in range(2):
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for j, data in enumerate(train_loader, 0):
        inputs, labels = data

        optim.zero_grad()

        #  forward + backward + optimize
        outputs = cnn(inputs)
        loss = loss_f(outputs, labels)
        loss.backward()
        optim.step()
        # measure accuracy and record loss
        prec1 = accuracy(output=outputs, target=labels, topk=(1, 5))
        losses.update(loss.item(), labels.size(0))
        top1.update(prec1[0], labels.size(0))
        top5.update(prec1[1], labels.size(0))
        if j % 2000 == 1999:
            """
            top1.val是一个batch中的准确率, avg为准确率
            """
            print('Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(loss=losses, top1=top1, top5=top5))

#  测试
top1 = AverageMeter()
top5 = AverageMeter()
for data in test_loader:
    inputs, labels = data
    outputs = cnn(inputs)
    prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))
    top1.update(prec1, labels.size(0))
    top5.update(prec5, labels.size(0))

print('Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(top1=top1, top5=top5))





评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值