使用 PyTorch 进行图像分类:CIFAR-10 数据集和预训练模型

介绍

欢迎阅读这篇博客!在这里,我们将使用 PyTorch 框架来构建、训练和测试一个动物图像10分类模型,并通过 CIFAR-10 数据集来验证模型的性能。我们还将演示如何使用预训练模型对新的图像进行分类。(适合新手小白的教程)pytorch-cifar10.zip资源-CSDN文库

1. 数据准备

首先,我们需要下载并准备 CIFAR-10 数据集。你可以通过以下步骤进行:

# 安装 torchvision
pip install torchvision

# 下载 CIFAR-10 数据集

原下载地址:http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz,http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz。(如果下载不下来,可以评论区发邮箱号找我领取数据集)

# 加载数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)

2. 模型选择

在这个示例中,我们选择了 ResNet18 作为我们的图像分类模型。你可以根据需求选择其他模型。

# 模型选择代码
from models import *

# 创建 ResNet18 模型(还可以使用vgg19、GoogLeNet等网络训练)
net = ResNet18()
# net = GoogLeNet()
# net = VGG('VGG19') 
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

3. 训练

我们将使用 SGD 优化器和交叉熵损失函数进行模型训练。你可以通过调整学习率和其他参数来自定义训练过程。

# 恢复训练
if args.resume:
    checkpoint = torch.load('./checkpoint/model.pth')
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=end_epoch)

# 训练和测试
for epoch in range(start_epoch, start_epoch + end_epoch):
    print(f'\nEpoch: {epoch}')

    # 训练
    with tqdm(total=len(trainloader), unit='batch', leave=False) as pbar_train:
        net.train()
        train_loss, correct, total = 0, 0, 0

        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar_train.update(1)
            pbar_train.set_postfix({'Train_Loss': train_loss / (batch_idx + 1), 'Train_Acc': 100. * correct / total})

    # 保存 checkpoint
    scheduler.step()

    # 保存最佳模型
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saved_model')
        state = {'net': net.state_dict(), 'acc': acc, 'epoch': epoch}
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/model.pth')
        best_acc = acc

4. 测试

在测试阶段,我们将加载预训练的模型,并使用它对新的图像进行分类。

import cv2
import torchvision.transforms as transforms
from PIL import Image
from models import *  # 根据你的模型导入相应的模型

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 载入训练好的模型
# 加载 checkpoint 时,添加 map_location 参数
checkpoint = torch.load('./checkpoint/model.pth', map_location='cpu')

# 创建模型时,使用 DataParallel
model = torch.nn.DataParallel(ResNet18())
model.load_state_dict(checkpoint['net'])
model.eval()

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 调整图像大小
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 读取测试图像
image_path = 'images/1.jpg'  # 替换为你的图像路径
image = Image.open(image_path)
input_image = transform(image)

img = cv2.imread(image_path)
img = cv2.resize(img, (300, 300))
cv2.imshow('Image', img)  # 调整通道顺序

input_image = input_image.unsqueeze(0)  # 添加 batch 维度

# 使用模型进行预测
with torch.no_grad():
    output = model(input_image)

# 获取预测结果
_, predicted = output.max(1)
class_index = predicted.item()

# 打印预测结果
print(f'The predicted class is: {classes[class_index]}')
cv2.waitKey(0)  # 保持窗口打开,直到按下任意键
cv2.destroyAllWindows()

5. 结果分析

我们将在博客中分析训练和测试结果,包括准确率、损失等指标的变化。

6. 图像分类应用

最后,我们将展示如何使用模型对新的图像进行分类。你只需提供图像路径,模型将返回其预测结果。

  • 9
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
### 回答1: PyTorch的ResNet-18在CIFAR-10数据集预训练模型是指在经过大规模的图像数据集进行训练后的ResNet-18模型,以便在CIFAR-10数据集进行更好的图像分类任务。 ResNet-18是一个由18个卷积层和全连接层组成的深度神经网络。预训练模型是指在大规模数据上进行训练得到的模型参数,因此具有更好的泛化性能。CIFAR-10是一个包含10个类别的图像分类数据集,用于在小尺寸图像上进行模型训练和评估。 通过使用训练的ResNet-18模型,在CIFAR-10数据集进行图像分类任务时,我们可以利用预训练模型的权重参数来加快训练过程并提高准确率。预训练模型的好处是可以从大规模数据中学习到更多的特征表示,这些特征表示通常具有更高的鉴别性,因此可以更好地捕捉图像的关键特征。 对于CIFAR-10数据集预训练模型可以有效地缩短训练时间并提高模型的收敛速度,因为在预训练模型中已经包含了对图像的一些共享特征的学习。通过在CIFAR-10数据集进行微调,即在预训练模型的基础上进行进一步的训练,可以逐步调整模型参数以适应CIFAR-10数据集的特定要求,从而提高最终的图像分类性能。 总而言之,PyTorch的ResNet-18在CIFAR-10预训练模型是通过在大规模数据上进行训练,在CIFAR-10数据集进行图像分类任务时使用预训练模型。这个预训练模型可以帮助提高训练速度和分类准确率,并且在模型训练和微调时起到了重要作用。 ### 回答2: PyTorch的ResNet-18是一种在CIFAR-10数据集进行训练的深度神经网络模型。CIFAR-10是一个包含10个类别的图像分类数据集,包括飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。 ResNet-18是指由18个卷积层和全连接层组成的深度残差网络。该网络的设计思想是通过残差连接(即跳过连接)来解决深度网络中的梯度消失问题,使得网络具有更好的训练效果。这意味着在每个卷积层之后,输入信号可以通过两条路径传递:一条直接连接到后续层,另一条通过卷积操作后再进行连接。这种设计可以使网络更加容易学习输入和输出之间的映射关系。 在CIFAR-10上预训练的ResNet-18模型具有多个优点。首先,这个模型具有较小的参数量和计算复杂度,适合在资源有限的环境下使用。其次,该模型经过在CIFAR-10数据集上的预训练,可以直接用于图像分类任务。通过在CIFAR-10进行训练,模型可以学习到一般的图像特征和模式,使其能够更好地泛化到其他类似的图像分类任务中。 通过使用训练的ResNet-18模型,我们可以利用其已经学到的特征和知识,节省训练时间,并为我们的具体图像分类任务提供一个良好的起点。此外,该模型可以通过微调(fine-tuning)进一步优化,以适应特定任务的需求。 综上所述,PyTorch的ResNet-18在CIFAR-10预训练模型是一个有价值的工具,可以用于图像分类任务,具有较小的参数量和计算复杂度,预先学习了一般的图像特征和模式,并可以通过微调进一步适应特定任务的需求。 ### 回答3: PyTorch预训练模型ResNet-18在CIFAR-10数据集上表现出色。首先,CIFAR-10是一个包含10个不同类别的图像数据集,每个类别有6000个图像,共计60000个图像。ResNet-18是一个基于深度残差网络的模型,它具有18个卷积层和全连接层。该模型在ImageNet数据集进行了预训练,其中包含了1000个类别的图像。 当我们将预训练的ResNet-18模型应用于CIFAR-10数据集时,可以得到很好的结果。因为CIFAR-10数据集的图像尺寸较小(32x32),相对于ImageNet数据集中的图像(224x224),所以ResNet-18模型在CIFAR-10上的训练速度更快。此外,ResNet-18模型通过残差连接解决了深度网络中的梯度消失问题,这使得它在CIFAR-10数据集上的表现也非常稳定。 通过使用预训练模型,我们可以通过迁移学习的方式节省训练时间。我们可以先将ResNet-18加载到内存中,然后只需针对CIFAR-10数据集的最后一层或几层进行微调即可。这样可以有效地提高模型在CIFAR-10上的性能。 总之,PyTorch中的预训练模型ResNet-18在CIFAR-10数据集上表现优秀。它通过残差连接解决了深度网络中的梯度消失问题,具有较快的训练速度和较好的稳定性。使用预训练模型可以节省训练时间,并通过微调模型的方式进一步提高性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Anger、破晓

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值