图像分类:基于 CIFAR-10 数据集的深度学习模型构建与实现
1. 引言
在计算机视觉领域,图像分类是一项基础性任务,要求模型根据输入图像将其归类到预定义的类别中。它被广泛应用于自动驾驶、安防监控、医疗影像分析等领域。本篇博客将详细介绍如何基于 CIFAR-10 数据集,使用 PyTorch 框架构建深度学习模型进行图像分类。我们会从数据加载和预处理开始,逐步搭建模型,训练和验证模型,最终展示结果并进行混淆矩阵分析。
2. CIFAR-10 数据集简介
CIFAR-10 是一个用于图像分类的经典数据集,包含 60000 张 32x32 像素的彩色图像,分为 10 个类。每个类别有 6000 张图像,这些类分别是飞机(airplane)、汽车(automobile)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)和卡车(truck)。其中,50000 张图像用于训练,10000 张用于测试。
3. 数据预处理
在深度学习任务中,数据的预处理非常重要。良好的数据预处理可以帮助模型更快地收敛并提高准确率。我们将进行以下两种预处理:
- 标准化:将图像像素值从
[0, 255]
缩放到[0, 1]
,并使用 CIFAR-10 的均值和标准差进行归一化。 - 数据增强:为了提升模型的泛化能力,采用随机水平翻转和随机裁剪等增强技术。
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.ToTensor(), # 转换为Tensor
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # 标准化
])
4. 快速验证
在实际任务中,训练完整模型往往需要耗费大量时间和资源。因此,在正式训练前,我们可以使用较小批量数据和较短的训练周期来快速验证代码是否正确。下面我们设置 batch_size=8
来加载数据,并只训练 1 个 epoch 来进行验证。
4.1. 数据加载与预处理
import torch
import torchvision
batch_size = 8 # 小批量数据
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
4.2. 模型构建
为了进行快速验证,我们使用 ResNet18
预训练模型,并修改最后一层为 10 类输出,适应 CIFAR-10 数据集。
import torch.nn as nn
import torchvision.models as models
# 使用 ResNet18 模型,并修改输出层
net = models.resnet18(pretrained=True)
net.fc = nn.Linear(net.fc.in_features, 10) # 修改最后一层为10类输出
4.3. 模型训练与验证
我们将使用交叉熵损失函数和 Adam 优化器进行训练,并只训练 1 个 epoch,方便快速验证模型的可行性。
import torch.optim as optim
criterion = nn.CrossEntropyLoss() # 定义损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001) # Adam 优化器
# 训练模型(仅1个epoch)
for epoch in range(1): # 快速验证,训练1个epoch
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad() # 梯度清零
outputs = net(inputs) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新权重
running_loss += loss.item()
if i % 100 == 99:
print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
running_loss = 0.0
print("训练完成")
4.4. 模型评估
我们使用测试集来验证模型的分类准确率,以确保模型在训练结束后能正确地对测试数据进行分类。
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%')
5. 完整的模型实现
在快速验证过后,我们可以运行完整的模型训练过程,增加训练轮数和批次大小,从而获得更高的分类准确率。以下是完整代码,展示数据加载、模型构建、训练、评估以及混淆矩阵可视化的全过程。
5.1. 数据加载与预处理
# 加载CIFAR-10数据集并进行预处理
batch_size = 64 # 增大batch size
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
5.2. 模型构建
# 使用 ResNet18 模型,并修改输出层
net = models.resnet18(pretrained=True)
net.fc = nn.Linear(net.fc.in_features, 10) # 修改最后一层为10类输出
5.3. 模型训练
我们将训练 10 个 epoch 来提高模型性能,并定期打印损失值来监控训练进度。
# 训练模型(训练10个epoch)
for epoch in range(10):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad() # 梯度清零
outputs = net(inputs) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新权重
running_loss += loss.item()
if i % 100 == 99:
print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
running_loss = 0.0
print("训练完成")
5.4. 模型评估与混淆矩阵可视化
在训练完成后,我们使用测试集评估模型的分类效果,并通过绘制混淆矩阵来分析模型的性能。
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
# 定义 CIFAR-10 类别名称
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 模型评估并生成混淆矩阵
correct = 0
total = 0
y_true = []
y_pred = []
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
y_true.extend(labels.cpu().numpy())
y_pred.extend(predicted.cpu().numpy())
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 打印测试集上的准确率
print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%')
# 绘制混淆矩阵
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(
10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()
6. 结论
通过本篇博客的介绍,我们展示了如何基于 CIFAR-10 数据集使用深度学习进行图像分类。首先通过快速验证确保模型可以正常运行,然后进行完整的模型训练,进一步优化性能。最终,我们使用混淆矩阵分析模型在不同类别上的表现,并根据分类结果进行调整和优化。在实际应用中,可以通过增加训练轮数、数据增强和超参数调整来进一步提升模型效果。
如果你对更多关于算法、深度学习和人工智能的内容感兴趣,欢迎关注微信公众号 “算法最TOP”。我们会定期发布高质量的技术文章和最新的行业动态,助你在技术领域不断精进,保持在行业的最前沿!希望这篇博客对你有所帮助,期待你的关注!