PyTorch 数据处理工具箱
文章目录
1、数据处理工具箱概述
Pytorch 涉及数据处理(数据装载、数据预处理、数据增强等)主要工具包及相互关系如图:
它主要包含 4 个类:
- Dataset:是一个抽象类,其它数据集需要继承这个类,并且覆写其中的两个方法(_getitem_、_len_);
- DataLoader:定义一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)并提供并行加速等功能
- random_split:把数据集随机拆分为给定长度的非重叠新数据集
- *sampler:多种采样函数
中间是 Pytorch 可视化处理工具(torchvision),Pytorch 的一个视觉处理工具包,独立于 Pytorch,需要另外安装。它包括 4 个类,各类的主要功能如下:
- datasets:提供常用的数据集加载,设计上都是继承 torch.utils.data.Dataset,主要包括 MMIST、CIFAR10/100、ImageNet、COCO 等;
- models:提供深度学习中各种经典的网络结构以及训练好的模型(如果选择 pretrained=True),包括 AlexNet, VGG系列、ResNet 系列、Inception 系列等;
- transforms:常用的数据预处理操作,主要包括对 Tensor 及 PIL Image 对象的操作
- utils:含两个函数,一个是 make_grid,它能将多张图片拼接在一个网格中;另一个是 save_img,它能将 Tensor 保存成图片
2、utils.data 简介
-
utils.data 包括 Dataset 和 DataLoader:
-
torch.utils.data.Dataset 为抽象类。自定义数据集需要继承这个类,并实现两个函数。一个是__len__,另一个是__getitem__,前者提供数据的大小(size),后者通过给定索引获取数据和标签。
-
_getitem_ 一次只能获取一个数据,所以通过 torch.utils.data.DataLoader 来定义一个新的迭代器,实现 batch 读取。
DataLoader 的格式为:
data.DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, )
- dataset:加载的数据集;
- batch_size:批大小;
- shuffle:是否将数据打乱;
- sampler:样本抽样;
- num_workers:使用多进程加载的进程数,0 代表不使用多进程;
- collate_fn:如何将多个样本数据拼接成一个 batch,一般使用默认的拼接方式即可;
- pin_memory:是否将数据保存在 pin memory 区,pin memory 中的数据转到 GPU 会快一些;
- drop_last:dataset 中的数据个数可能不是 batch_size 的整数倍,drop_last 为 True 会将多出来不足一个batch 的数据丢弃。
2.1、自定义一个数据集
-
导入需要的模块
import torch from torch.utils import data import numpy as np
-
定义获取数据集的类
类继承基类 Dataset,自定义一个数据集及对应标签。
class TestDataset(data.Dataset):#继承Dataset def __init__(self): self.Data=np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])#一些由2维向量表示的数据集 self.Label=np.asarray([0,1,0,1,2])#这是数据集对应的标签 def __getitem__(self, index): #把numpy转换为Tensor txt=torch.from_numpy(self.Data[index]) label=torch.tensor(self.Label[index]) return txt,label def __len__(self): return len(self.Data)
-
获取数据集中数据
Test=TestDataset() print(Test[2]) #相当于调用__getitem__(2) print(Test.__len__()) #輸出: #(tensor([2, 1]), tensor(0)) #5
以上数据以 tuple 返回,每次只返回一个样本。实际上,Dateset 只负责数据的抽取,一次调用__getitem__只返回一个样本。如果希望批量处理(batch),同时还要进行 shuffle 和并行加速等操作,可选择 DataLoader。
test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=2) for i,traindata in enumerate(test_loader): print('i:',i) Data,Label=traindata print('data:',Data) print('Label:',Label)
从这个结果可以看出,这是批量读取。我们可以像使用迭代器一样使用它,如对它进行循环操作。不过它不是迭代器,我们可以通过 iter 命令转换为迭代器。
dataiter=iter(test_loader) imgs,labels=next(dataiter) #imgs.size()
一般用 data.Dataset 处理同一个目录下的数据。如果数据在不同目录下,不同目录代表不同类别(这种情况比较普遍),使用 data.Dataset 来处理就不很方便。不过,可以使用 Pytorch 另一种可视化数据处理工具(即 torchvision)就非常方便,不但可以自动获取标签,还提供很多数据预处理、数据增强等转换函数。
-
3、torchvision 简介
- torchvision 有 4 个功能模块,model、datasets、transforms 和 utils:
- 利用 datasets 下载一些经典数据集;
- 提供深度学习中各种经典的网络结构以及训练好的模型(如果选择 pretrained=True);
- datasets 的 ImageFolder处理自定义数据集;
- transforms 对源数据进行预处理、增强。
3.1、transforms
transforms 提供了对 PIL Image 对象和 Tensor 对象的常用操作:
-
对 PIL Image 的常见操作如下:
- Scale/Resize:调整尺寸,长宽比保持不变;
- CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片,CenterCrop 和 RandomCrop 在 crop 时是固定size,RandomResizedCrop 则是 random size 的 crop;
- Pad:填充;
- ToTensor:把一个取值范围是 [0,255] 的 PIL.Image 转换成 Tensor。形状为 (H,W,C) 的 numpy.ndarray,转换成形状为 [C,H,W],取值范围是 [0,1.0] 的 torch.FloatTensor;
- RandomHorizontalFlip:图像随机水平翻转,翻转概率为 0.5;
- RandomVerticalFlip:图像随机垂直翻转;
- ColorJitter:修改亮度、对比度和饱和度。
-
对 Tensor 的常见操作如下:
- Normalize:标准化,即减均值,除以标准差;
- ToPILImage:将 Tensor 转为 PIL Image。
-
如果要对数据集进行多个操作,可通过 Compose 将这些操作像管道一样拼接起来,类似于 nn.Sequential。以下为示例代码:
transforms.Compose([ #将给定的 PIL.Image 进行中心切割,得到给定的 size, #size 可以是 tuple,(target_height, target_width)。 #size 也可以是一个 Integer,在这种情况下,切出来的图片形状是正方形。 transforms.CenterCrop(10), #切割中心点的位置随机选取 transforms.RandomCrop(20, padding=0), #把一个取值范围是 [0, 255] 的 PIL.Image 或者 shape 为 (H, W, C) 的 numpy.ndarray, #转换为形状为 (C, H, W),取值范围是 [0, 1] 的 torch.FloatTensor transforms.ToTensor(), #规范化到[-1,1] transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)) ])
3.2、ImageFolder
-
当文件依据标签处于不同文件下时,可以利用 torchvision.datasets.ImageFolder 来直接构造出 dataset。
ImageFolder 会将目录中的文件夹名自动转化成序列,那么 DataLoader 载入时,标签自动就是整数序列了。
示例代码:
from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt
my_trans=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
train_data = datasets.ImageFolder('./data/torchvision_data', transform=my_trans)
train_loader = data.DataLoader(train_data,batch_size=8,shuffle=True,)
for i_batch, img in enumerate(train_loader):
if i_batch == 0:
print(img[1])
fig = plt.figure()
grid = utils.make_grid(img[0])
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.show()
utils.save_image(grid,'test01.png')
break
4、可视化工具
TensorboardX 是 Google TensorFlow 的可视化工具,它可以记录训练数据、评估数据、网络结构、图像等,并且可以在 web 上展示,对于观察神经网路训练的过程非常有帮助。
-
使用 tensorboardX 的一般步骤为:
-
导入 tensorboardX,实例化 SummaryWriter 类,指明记录日志路径等信息:
from tensorboardX import SummaryWriter #实例化SummaryWriter,并指明日志存放路径。在当前目录没有logs目录将自动创建。 writer = SummaryWriter(log_dir='logs') #调用实例 writer.add_xxx() #关闭writer writer.close()
SummaryWriter(log_dir=None, comment='', **kwargs) #其中comment在文件命名加上comment后缀
-
调用相应的 API 接口,接口一般格式为:
add_xxx(tag-name, object, iteration-number) #即add_xxx(标签,记录的对象,迭代次数)
-
启动 tensorboard 服务
cd 到 logs 目录所在的同级目录,在命令行输入如下命令,logdir 等式右边可以是相对路径或绝对路径。
tensorboard --logdir=logs --port 6006 #如果是windows环境,要注意路径解析,如 #tensorboard --logdir=r'D:\myboard\test\logs' --port 6006
-
web 展示
在浏览器输入:
http://服务器IP或名称:6006 #如果是本机,服务器名称可以使用localhost
-