数据处理
数据预处理
PyTorch使用torchvision来完成数据的处理,其只实现了一些数据集的处理,如果处理自己的工程则需要修改增加内容
把原始数据处理为模型使用的数据
分为三步
数据处理格式的定义
transforms.Compose()
Compose() 代码
Composes several transforms together.
通过Compose把一些对图像处理的方法集中起来。比如先中心化,然后转换为张量(PyTorch的数据结构)
代码为:
transforms.Compose([transform.CenterCrop(10), transofrms.ToTensor()])
比如先转换为张量,然后正则化
代码为:
transforms.Compose([transofrms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
还要注意
def __call_(self, img):
for t in self.transforms:
img = t(img)
return img
把输入到Compose的操作按顺序进行执行。先执行第一个,然后第二个……。如果需要处理自己的数据,可以把具体的操纵放在这个类中实现
数据处理
torchvision.datasets
实现了不同针对数据集的处理方法,主要用来加载数据和处理数据。比如在mnist.py 和cifar.py 中用来处理mnist和cifar数据集。类的实现需要继承父类data.Dataset
主要方法有2个:
初始化类和对数据进行加载
有时需要定义一些开关来防止重复处理。数据的加载就是针对不同的数据,把其data和label(分为训练数据和测试数据)读入到内存中。
__init__(self, root, train=Ture, transform=None, traget_transform=None, download=False):
把读入的输出传给PyTorch(迭代器的方式)
transform.Compose在次数进行调用,通过index确定需要访问的数据,然后对其格式进行转换,最后返回处理后的数据。数据在定义时只是定义了一个类,其具体的数据传出在需要使用时使用该方法完成
--getitem__(self, index):
对数据进行加载,然后处理传给PyTorch已经完成,如果需要对自己的数据进行处理,也是通过修改和增加此部分完成。接下来需要对训练的数据进行处理,比如分批次的大小,十分随机处理等等。
数据加载
torch.utils.data.DataLoader()
把合成数据并且提供迭代访问
输入参数有:
dataset(Dataset)
输入加载的数据,就是上面的torchvision.datasets.myData()的实现,所以需要继承data.Dataset,满足此接口。
batch-size, shuffle, sampler, num_workers, collate_fn, pin_memory, drop_last
名称 | 作用 |
---|---|
batch-size | 样本每个batch的大小,默认为1 |
shuffle | 是否打乱数据,默认为False |
sampler | 定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle |
num_workers | 数据分为几批处理(对于大数据) |
collate_fn | 整理数据,把每个batch数据整理为tensor。(一般使用默认调用default_collate(batch)) |
pin_memory | 针对不同类型的batch进行处理。比如为Map或者Squence等类型,需要处理为tensor类型 |
drop_last | 用于处理最后一个batch的数据。因为最后一个可能不能够被整除,如果设置为True,则舍弃最后一个,为False则保留最后一个,但是最后一个可能很小 |
迭代器(DataLoaderIter)的具体处理就是根据这些参数的设置,分别进行不同的处理
补充
torch.utils.data.DataLoader类主要使torch.utils.data.sampler实现
sampler是所有采样器的基础类,提供了迭代器的迭代(iter)和长度(len)接口实现,同时sampler也是通过索引对数据进行洗牌(shuffle)等操作。因此,如果DataLoader不适用于数据,需要重新设计数据的分批次,可以充分使用所提供的smapler。