前言:按照深度学习项目的流程,最初的步骤就是组织数据集,pytorch中提供了常用的深度学习图像数据集,cifar10,coco,imagenet等等,也提供了处理输入数据的工具DataLoader, transforms等工具,非常之方便。本篇将详细介绍使用pytorch加载、处理数据集,并使用nn.Module搭建简单cifar10图像分类模型。
之所以选择cifar10数据集,是因为它比较小,好操作,不要求大量资源。
1、数据集的加载
import torch.utils.data
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
cifar_data = torchvision.datasets.CIFAR10('./data', train=False, transform=transforms.ToTensor(), download=True)
print(len(cifar_data), type(cifar_data))
target_classes = cifar_data.classes
使用torchvision中datasets加载对应数据集,需要指定数据集存放文件夹,下载训练集还是验证集,下载的图像是PIL类型的文件,可以在这一步进行类型转换为Tensor,并进行下载。对于数据加载这种I/O密集形任务,可设置num_workers