前言
pytorch cifar-10识别
文章是学习记录性质的。本文使用ResNet实现cifar-10识别。共训练100次。训练过程,和测试过程,是函数,读者可以更换不同的model来进行训练和测试。用tensorboard来保存损失。最后,使用torch.save()保存网络,计算结果。
1.cifar-10数据集介绍
CIFAR-10(Canadian Institute for Advanced Research - 10)是一个常用的计算机视觉数据集,由60000张32x32像素的彩色图像组成,分为10个类别,每个类别有6000张图像。这个数据集被广泛用于图像分类、目标检测、图像生成等计算机视觉任务的模型训练和评估。
CIFAR-10包含以下10个类别:
飞机 (airplane)
汽车 (automobile)
鸟类 (bird)
猫 (cat)
鹿 (deer)
狗 (dog)
青蛙 (frog)
马 (horse)
船 (ship)
卡车 (truck)
每个类别都包含6000张图像,其中5000张用于训练,1000张用于测试。图像的大小为32x32像素,彩色图像,包括红、绿、蓝三个通道。
CIFAR-10是一个相对较小的数据集,可以用于快速原型设计和实验。由于图像尺寸较小,训练模型相对较快,因此研究人员经常在这个数据集上进行实验,以验证算法的有效性和性能。
可以使用一下代码查看部分数据
from matplotlib import pyplot as plt
import numpy as np
import torch
cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
torch.set_printoptions(edgeitems=2, linewidth=75)
torch.manual_seed(123)
class_names = ['airplane','automobile','bird','cat','deer',
'dog','frog','horse','ship','truck']
fig = plt.figure(figsize=(8,3))
num_classes = 10
for i in range(num_classes):
ax = fig.add_subplot(2, 5, 1 + i, xticks=[], yticks=[])
ax.set_title(class_names[i])
img = next(img for img, label in cifar10 if label == i)
plt.imshow(img)
plt.show()
结果
2.训练过程
引入库,加载数据
import datetime
from matplotlib import pyplot as plt
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
#类名 这是CIFAR-10数据集的类别名称,对应着数据集中的10个不同类别。
class_names = ['airplane','automobile','bird','cat','deer',
'dog','frog','horse','ship','truck']
#这段代码定义了一个数据预处理管道,将图像数据转换为张量,并进行标准化。
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 下载并加载训练集,这里使用torchvision库中的CIFAR10类加载CIFAR-10训练集,将其进行预处理并使用DataLoader封装成一个可以迭代的数据加载器,以便在训练时批量获取数据。
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=64,
shuffle=True)
# 下载并加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=64,
shuffle=False)
device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
训练函数
def training_loop(n_epochs,optimizer,loss_fn,model,train_loader):
for epoch in range(1,n_epochs+1):
train_loss = 0.0
for images,labels in train_loader:
images = images.to(device = device)
labels = labels.to(device= device)
optimizer.zero_grad()
outputs = model(images)
loss = loss_fn(outputs,labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
if epoch == 1 or epoch%10 == 0:
avg_train_loss = train_loss / len(train_loader)
print('{} Epoch {}, Training loss: {:.4f}'.format(
datetime.datetime.now(), epoch, avg_train_loss))
# 使用SummaryWriter记录训练损失
writer.add_scalar('Training Loss', avg_train_loss, epoch)
writer.close()
for epoch in range(1, n_epochs+1)::循环每个epoch。
for images, labels in train_loader::遍历训练数据集的每个批次。
images = images.to(device=device)和labels = labels.to(device=device):将数据移到GPU(如果可用)。
optimizer.zero_grad(): 清零梯度,以防止梯度累积。
outputs = model(images): 模型的前向传播。
loss = loss_fn(outputs, labels): 计算模型输出与真实标签之间的损失。
loss.backward(): 反向传播,计算梯度。
optimizer.step(): 使用梯度更新模型参数。
train_loss += loss.item(): 累计每个批次的训练损失。
最后,打印平均训练损失,并使用SummaryWriter记录训练损失。也可以记录一下精确度之类的。
测试函数
def validate(model, train_loader, val_loader):
accdict = {}
for name, loader in [("train", train_loader), ("val", val_loader)]:
correct = 0
total = 0
with torch.no_grad():
for imgs, labels in loader:
imgs = imgs.to(device=device)
labels = labels.to(device=device)
outputs = model(imgs)
_, predicted = torch.max(outputs, dim=1)
total += labels.shape[0]
correct += int((predicted == labels).sum())
print("Accuracy {}: {:.2f}".format(name , correct / total))
accdict[name] = correct / total
return accdict
for name, loader in [(“train”, train_loader), (“val”, val_loader)]::循环遍历训练集和验证集。
correct = 0 和 total = 0:初始化正确预测的样本数和总样本数。
with torch.no_grad()::在这个上下文中,不计算梯度,以提高验证过程的效率。
for imgs, labels in loader::遍历数据加载器的每个批次。
outputs = model(imgs): 模型的前向传播。
_, predicted = torch.max(outputs, dim=1): 获取预测结果中的最大值及其索引。
total += labels.shape[0] 和 correct += int((predicted == labels).sum()): 统计总样本数和正确预测的样本数。
accuracy = correct / total: 计算准确率。
print(“Accuracy {}: {:.2f}”.format(name, accuracy)): 打印准确率。
accdict[name] = accuracy: 将准确率保存到字典中。
建立模型
模型可以换成其他的,这里使用的是ResNet,简单的ResNet风格的卷积神经网络(Convolutional Neural Network,CNN)。
class ResNet(nn.Module):
def __init__(self, n_chans1=32):
super(ResNet, self).__init__()
self.n_chans1 = n_chans1
# 第一个卷积层,输入通道数为3(RGB图像),输出通道数为n_chans1,卷积核大小为3x3,padding为1
self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
# 第二个卷积层,输入通道数为n_chans1,输出通道数为n_chans1//2,卷积核大小为3x3,padding为1
self.conv2 = nn.Conv2d(n_chans1, n_chans1//2, kernel_size=3, padding=1)
# 第三个卷积层,输入和输出通道数都为n_chans1//2,卷积核大小为3x3,padding为1
self.conv3 = nn.Conv2d(n_chans1//2, n_chans1//2, kernel_size=3, padding=1)
# 全连接层1,输入大小为4*4*n_chans1//2,输出大小为32
self.fc1 = nn.Linear(4 * 4 * n_chans1//2, 32)
# 全连接层2,输入大小为32,输出大小为10(对应10个类别)
self.fc2 = nn.Linear(32, 10)
def forward(self, x):
# 第一个卷积层后接ReLU激活函数和最大池化层
out = F.max_pool2d(torch.relu(self.conv1(x)), 2)
# 第二个卷积层后接ReLU激活函数和最大池化层
out = F.max_pool2d(torch.relu(self.conv2(out)), 2)
out1 = out
# 第三个卷积层后接ReLU激活函数和最大池化层,同时加上第二个卷积层的输出(残差连接)
out = F.max_pool2d(torch.relu(self.conv3(out)) + out1, 2)
# 将输出展平为一维向量
out = out.view(-1, 4 * 4 * self.n_chans1//2)
# 全连接层1后接ReLU激活函数
out = torch.relu(self.fc1(out))
# 全连接层2,最终输出
out = self.fc2(out)
return out
这个模型使用了ResNet的一些基本思想,包括残差连接(residual connection)来帮助网络更好地学习。
加载模型进行训练
model = ResNet(n_chans1=32).to(device=device)
optimizer = optim.SGD(model.parameters(),lr=1e-2)
loss_fn = nn.CrossEntropyLoss()
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
training_loop(
n_epochs=100,
optimizer = optimizer,
model = model,
loss_fn = loss_fn,
train_loader=train_loader,
)
torch.save(model,'cifar-10-1')
validate(model, train_loader, test_loader)
可以使用torch.save()保存训练的结果。不过网络比较简单,效果不太好。