4.1 数据处理工具箱概述
torch.utils.data工具包,包括以下4个类:
1)Dataset: 是一个抽象类,其他数据集需要继承这个类,并且覆盖其中的两个方法(getitem, len)
2) DataLoader: 定义一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)并提供并行加速等功能
3)random_split: 把数据集随机拆分为给定长度的非重叠的新数据集
4)*sampler:多种采样函数
4.2 utils.data简介
utils.data包括Dataset和DataLoader。torch.utils.data.Dataset为抽象类。自定义数据集需要继承这个类,并实现两个函数,一个是__len__, 另一个是__getitem__, 前者听数据的大小(size),后者通过给定索引获取数据和标签。__getitem__一次只能获取一个数据,所以需要通过torch.utils.data.DataLoader来定义一个新的迭代器,实现batch读取。
4.3 torchvision简介
torchvision有4个功能模块:model, datasets, transforms和utils。
4.3.1 transforms
transforms提供了对PIL Image对象和Tensor对象的常用操作
1)对PIL Image的常见操作如下
- Scale/Resize: 调整尺寸,长宽比保持不变
- CenterCrop, RandomCrop, RandomSizedCrop: 裁剪图片,CenterCrop和RandomCrop在crop时是- 固定size,RandomResizedCrop则是random size的crop
- Pad: 填充
- ToTensor:把一个取值范围是[0, 255]的PIL.Image转换成Tensor, 形状为(H, W, C)的Numpy.ndarray转换成形状为[C, H, W], 取值范围是[0, 1.0]的torch.FloatTensor
- RandomHorizontalFlip: 图像随机水平翻转,翻转概率为0.5
- RandomVerticalFlip: 图像随机垂直翻转
- ColorJitter: 修改亮度,对比度和饱和度
2)对Tensor的常见操作如下: - Normalize:标准化,即,减均值,除以标准差
- ToPILImage: 将Tensor转为PIL Image