import torchvision
import torch
import torch.utils.data.dataloader as dataloader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
# Load and preprocess the CIFAR-10 data.
# ToTensor yields float tensors in [0, 1]; Normalize with mean 0.5 and std 0.5
# per channel then maps each channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_data = torchvision.datasets.CIFAR10(
    root="./data", train=True, transform=transform, download=True
)
test_data = torchvision.datasets.CIFAR10(
    root="./data", train=False, transform=transform, download=True
)
# Shuffle the training set so SGD sees batches in a fresh order every epoch;
# evaluation metrics do not depend on order, so the test set stays fixed.
# NOTE(review): with num_workers > 0 on platforms that start DataLoader
# workers via `spawn` (Windows, macOS), the training/evaluation code is
# expected to run under `if __name__ == "__main__":` — confirm if portability
# matters.
train_loader = dataloader.DataLoader(
    dataset=train_data, shuffle=True, batch_size=4, num_workers=2
)
test_loader = dataloader.DataLoader(
    dataset=test_data, shuffle=False, batch_size=4, num_workers=2
)
# Define the network.
class CNN(nn.Module):
    """Small LeNet-style convolutional classifier.

    The fully connected head expects 16 feature maps of size 5x5 after the
    two conv/pool stages, which is what 3-channel 32x32 CIFAR-10 images
    produce. The output is one unnormalised score per class (10 classes).
    """

    def __init__(self):
        super().__init__()
        # For 32x32 input: 3x32x32 -> conv(3x3) -> 6x30x30 -> pool -> 6x15x15
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # 6x15x15 -> conv(5x5) -> 16x11x11 -> pool -> 16x5x5
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Fully connected head. The ReLUs between the Linear layers are
        # essential: without them the three layers compose into a single
        # affine map and the extra layers add no modelling capacity.
        self.layer3 = nn.Sequential(
            nn.Linear(5 * 5 * 16, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Return class scores of shape (batch, 10) for the input batch ``x``."""
        x = self.layer1(x)
        x = self.layer2(x)
        # Flatten each sample while keeping the batch dimension. Unlike
        # view(-1, 400), this cannot silently re-split samples when the
        # spatial size is not 5x5; the first Linear layer raises instead.
        x = torch.flatten(x, 1)
        return self.layer3(x)
# Loss function: cross-entropy over the raw class scores produced by the model.
loss_f = nn.CrossEntropyLoss()
# Instantiate the model.
cnn = CNN()
# Optimizer: SGD with momentum over all model parameters.
# NOTE(review): this rebinds the name `optim`, shadowing the `torch.optim`
# module imported at the top of the file. The training loop refers to the
# optimizer as `optim`, so any rename has to be made there at the same time.
optim = optim.SGD(cnn.parameters(), lr=0.001, momentum=0.9)
class AverageMeter:
    """Computes and stores the average and current value.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted running total of ``val * n`` over all updates.
        count: total weight ``n`` accumulated so far.
        avg: ``sum / count``; left unchanged while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as a value covering ``n`` samples and refresh ``avg``.

        Args:
            val: the value for this step (e.g. a per-batch mean).
            n: how many samples ``val`` summarises; used as its weight.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An n=0 update before any data has been seen would otherwise divide
        # by zero; keep the previous average in that case.
        if self.count:
            self.avg = self.sum / self.count
# Top-k accuracy.
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: scores indexed as (batch, num_classes); top-k is taken
            along dim 1.
        target: ground-truth class indices, one per row of ``output``.
        topk: the k values to evaluate, e.g. (1,) for top-1 only or (1, 5)
            for top-1 and top-5.

    Returns:
        A list with one 0-dim tensor per k, in the order of ``topk``: the
        percentage of samples whose target is among the k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # Indices of the maxk largest scores along dim 1 (largest=True,
        # sorted=True), then transpose so row r holds every sample's
        # r-th-ranked guess.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        # correct[r, b] is True iff sample b's r-th guess equals its target.
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # `correct` may inherit the transposed (non-contiguous) layout of
            # `pred`, where view(-1) can fail for k > 1; reshape always works.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
# Train for 2 epochs.
for epoch in range(2):
    cnn.train()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    for j, (inputs, labels) in enumerate(train_loader):
        optim.zero_grad()
        # forward + backward + optimize
        outputs = cnn(inputs)
        loss = loss_f(outputs, labels)
        loss.backward()
        optim.step()

        # Measure accuracy and record loss. Plain Python floats go into the
        # meters so they do not accumulate tensors.
        batch_size = labels.size(0)
        acc1, acc5 = accuracy(output=outputs, target=labels, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(acc1.item(), batch_size)
        top5.update(acc5.item(), batch_size)
        if j % 2000 == 1999:
            # `.val` is the most recent batch's value; the parenthesised
            # `.avg` is the running average since the start of this epoch.
            print('Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      loss=losses, top1=top1, top5=top5))
# Evaluate on the test set.
cnn.eval()  # inference mode; matters once the model has dropout/batch-norm
top1 = AverageMeter()
top5 = AverageMeter()
# Gradients are not needed for evaluation; disabling autograd avoids building
# a computation graph for every forward pass.
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = cnn(inputs)
        prec1, prec5 = accuracy(outputs, labels, topk=(1, 5))
        top1.update(prec1.item(), labels.size(0))
        top5.update(prec5.item(), labels.size(0))
# `.val` is the last test batch's accuracy; `.avg` covers the whole test set.
print('Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(top1=top1, top5=top5))
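# 分类网络中的top1和top5 (Top-1 and top-5 accuracy in classification networks)
# 最新推荐文章于 2025-03-24 14:16:54 发布 (source page residue: "latest recommended article published 2025-03-24 14:16:54")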