我们将运用在前面几节中学到的知识来参加Kaggle竞赛,该竞赛解决了CIFAR-10图像分类问题。比赛网址是https://www.kaggle.com/c/cifar-10
基本思路
- 加载数据集
- 构建ResNet18模型
- 训练模型
- 可视化效果(可选)
基于pytorch的代码
使用的是CIFAR-10数据集
日常导入需要用到的python库
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
据上几节所学知识都是torchvision.dataset.CIFAR10加载数据, 这次虽然也可以,但我们可以学一些新知识(使用torchvision.dataset.ImageFolder)
加载数据集
这里使用到数据增强(将图像扩成40 * 40 再随机裁剪32 * 32, 水平翻转图片, 对图像进行均值归一化操作)
ImageFolder都是自己已经下载好的数据集
然后和往常一样加载就可以了
transform_train = transforms.Compose([
# 随机裁剪成32 * 32, 四周填充边长为4
transforms.RandomCrop(32, padding=4),
# 随机水平翻转 p=.5
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# 均值,归一化
transforms.Normalize((0.4731, 0.4822, 0.4465), (0.2212, 0.1994, 0.2010))
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4731, 0.4822, 0.4465), (0.2212, 0.1994, 0.2010))
])
train_data = datasets.ImageFolder("/home/kesci/input/CIFAR102891/cifar-10/train",
transform=transform_train)
valid_data = datasets.ImageFolder("/home/kesci/input/CIFAR102891/cifar-10/valid",
transform=transform_test)
test_data = datasets.ImageFolder("/home/kesci/input/CIFAR102891/cifar-10/test",
transform=transform_test)
train_iter = torch.utils.data.DataLoader(train_data, batch_size=128,
shuffle=True,
num_workers=4)
valid_iter = torch.utils.data.DataLoader(valid_data, batch