CIFAR-10图像分类

本文介绍了如何利用PyTorch解决CIFAR-10图像分类问题,包括加载数据集、构建ResNet18模型、训练模型和可选的可视化过程。通过数据增强、模型构建和训练,展示了深度学习在图像分类任务中的应用。
摘要由CSDN通过智能技术生成

我们将运用在前面几节中学到的知识来参加Kaggle竞赛,该竞赛解决了CIFAR-10图像分类问题。比赛网址是https://www.kaggle.com/c/cifar-10

基本思路

  1. 加载数据集
  2. 构建ResNet18模型
  3. 训练模型
  4. 可视化效果(可选)

基于pytorch的代码

使用的是CIFAR-10数据集

日常导入需要用到的python库

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

据上几节所学知识都是torchvision.dataset.CIFAR10加载数据, 这次虽然也可以,但我们可以学一些新知识(使用torchvision.dataset.ImageFolder)

加载数据集

这里使用到数据增强(将图像扩成40 * 40 再随机裁剪32 * 32, 水平翻转图片, 对图像进行均值归一化操作)
ImageFolder都是自己已经下载好的数据集
然后和往常一样加载就可以了

transform_train = transforms.Compose([
    # 随机裁剪成32 * 32, 四周填充边长为4
    transforms.RandomCrop(32, padding=4),
    # 随机水平翻转 p=.5
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # 均值,归一化
    transforms.Normalize((0.4731, 0.4822, 0.4465), (0.2212, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4731, 0.4822, 0.4465), (0.2212, 0.1994, 0.2010))
])

train_data = datasets.ImageFolder("/home/kesci/input/CIFAR102891/cifar-10/train",
            transform=transform_train)
valid_data = datasets.ImageFolder("/home/kesci/input/CIFAR102891/cifar-10/valid",
            transform=transform_test)
test_data = datasets.ImageFolder("/home/kesci/input/CIFAR102891/cifar-10/test",
            transform=transform_test)

train_iter = torch.utils.data.DataLoader(train_data, batch_size=128,
            shuffle=True, 
            num_workers=4)
valid_iter = torch.utils.data.DataLoader(valid_data, batch
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值