介绍
欢迎阅读这篇博客!在这里,我们将使用 PyTorch 框架来构建、训练和测试一个动物图像10分类模型,并通过 CIFAR-10 数据集来验证模型的性能。我们还将演示如何使用预训练模型对新的图像进行分类。(适合新手小白的教程)pytorch-cifar10.zip资源-CSDN文库
1. 数据准备
首先,我们需要下载并准备 CIFAR-10 数据集。你可以通过以下步骤进行:
# 安装 torchvision
pip install torchvision
# 下载 CIFAR-10 数据集
原下载地址:http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz,http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz。(如果下载不下来,可以评论区发邮箱号找我领取数据集)
# 加载数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)
2. 模型选择
在这个示例中,我们选择了 ResNet18 作为我们的图像分类模型。你可以根据需求选择其他模型。
# 模型选择代码
from models import *
# 创建 ResNet18 模型(还可以使用vgg19、GoogLeNet等网络训练)
net = ResNet18()
# net = GoogLeNet()
# net = VGG('VGG19')
net = net.to(device)
if device == 'cuda':
net = torch.nn.DataParallel(net)
cudnn.benchmark = True
3. 训练
我们将使用 SGD 优化器和交叉熵损失函数进行模型训练。你可以通过调整学习率和其他参数来自定义训练过程。
# 恢复训练
if args.resume:
checkpoint = torch.load('./checkpoint/model.pth')
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']
# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=end_epoch)
# 训练和测试
for epoch in range(start_epoch, start_epoch + end_epoch):
print(f'\nEpoch: {epoch}')
# 训练
with tqdm(total=len(trainloader), unit='batch', leave=False) as pbar_train:
net.train()
train_loss, correct, total = 0, 0, 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
pbar_train.update(1)
pbar_train.set_postfix({'Train_Loss': train_loss / (batch_idx + 1), 'Train_Acc': 100. * correct / total})
# 保存 checkpoint
scheduler.step()
# 保存最佳模型
acc = 100. * correct / total
if acc > best_acc:
print('Saved_model')
state = {'net': net.state_dict(), 'acc': acc, 'epoch': epoch}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/model.pth')
best_acc = acc
4. 测试
在测试阶段,我们将加载预训练的模型,并使用它对新的图像进行分类。
import cv2
import torchvision.transforms as transforms
from PIL import Image
from models import * # 根据你的模型导入相应的模型
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 载入训练好的模型
# 加载 checkpoint 时,添加 map_location 参数
checkpoint = torch.load('./checkpoint/model.pth', map_location='cpu')
# 创建模型时,使用 DataParallel
model = torch.nn.DataParallel(ResNet18())
model.load_state_dict(checkpoint['net'])
model.eval()
# 图像预处理
transform = transforms.Compose([
transforms.Resize((32, 32)), # 调整图像大小
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# 读取测试图像
image_path = 'images/1.jpg' # 替换为你的图像路径
image = Image.open(image_path)
input_image = transform(image)
img = cv2.imread(image_path)
img = cv2.resize(img, (300, 300))
cv2.imshow('Image', img) # 调整通道顺序
input_image = input_image.unsqueeze(0) # 添加 batch 维度
# 使用模型进行预测
with torch.no_grad():
output = model(input_image)
# 获取预测结果
_, predicted = output.max(1)
class_index = predicted.item()
# 打印预测结果
print(f'The predicted class is: {classes[class_index]}')
cv2.waitKey(0) # 保持窗口打开,直到按下任意键
cv2.destroyAllWindows()
5. 结果分析
我们将在博客中分析训练和测试结果,包括准确率、损失等指标的变化。
6. 图像分类应用
最后,我们将展示如何使用模型对新的图像进行分类。你只需提供图像路径,模型将返回其预测结果。