1. 数据集简介
CIFAR10数据集共有6W张彩色图像,图像大小是32*32*3的,共计10个类,每类6K张图片。
其中训练集5W张,构成了5个训练批,每一批1W张,但一个训练批中的各类图像并不一定数量相同,总的来看训练集,每一类都有5K张;测试集1W张单独构成一批,其来自10个分类,每类随机取1K张。
2. 数据加载
2.1 数据集下载
只下载一次, 批量迭代读取
def load_data(batch_size):
# 1. 构建数据转换器,进行数据增强
transform = get_transform()
# 2. 下载数据集
train_set = datasets.CIFAR10(root='../data/', train=True, download=True, transform=transform)
test_set = datasets.CIFAR10(root='../data/', train=False, download=True, transform=transform)
# 3. 生成数据迭代器
if torch.cuda.is_available():
# 使用GPU时,需要设置num_workers、pin_memory
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
else: # 使用cpu
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, pin_memory=True)
return train_loader, test_loader
2.2 数据增强
定义数据转换器
def get_transform():
transform1 = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5)) # 归一化[-1,1]
])
transform2 = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
transforms.RandomHorizontalFlip(0.3),
transforms.RandomVerticalFlip(0.3),
transforms.RandomRotation(10),
transforms.ColorJitter(0.25, 0.25, 0.25, 0.25)
])
return transform1
2.3 分割验证集
数据加载,从训练集中分割出验证集(占比30%), 根据索引进行采样