PyTorch中的数据输入和预处理

PyTorch中的数据输入和预处理

数据载入类

在使用PyTorch构建和训练模型的过程中,经常需要将原始的数据转换为张量。为了能够方便地批量处理图片数据,PyTorch引入了一系列工具来对这个过程进行包装。

PyTorch数据的载入使用torch.utils.data.DataLoader类。该类的签名如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
		  batch_sampler=None, num_workers=0, collate_fn=None, 
		  pin_memory=False, drop_last=False, timeout=0, 
		  worker_init_fn=None)

其中,dataset是一个torch.utils.data.Dataset类的实例,batch_size是Mini-batch的大小,shuffle代表数据会不会被随机打乱,sampler是自定义的采样器,每次迭代的时候会返回一个数据的下标索引,batch_sampler类似于sampler,不过返回的是一个Mini-batch的数据索引,而sampler仅仅返回下标索引。num_workers是数据载入器使用的进程数目。默认为0,即使用单进程来处理输入数据,collate_fn定义如何把一批dataset的实例转化为包含Mini-batch的张量。pin_memory参数会把数据转移到和GPU相关联的CPU内存中,从而加快GPU载入数据的速度,drop_last的设置决定了是否要把最后一个Mini-batch的数据丢弃掉,加入最后一个MIni-batch的数据数目小于预先设置的batch_size参数,timeout值如果大于0,就会决定在多进程情况下对数据的等待时间,worker_init_fn决定了每个数据载入的子进程开始时运行的函数,这个函数运行在随机种子设置之后、数据载入之前。

映射类型的数据集

为了能够使用DataLoader类,首先需要构造关于单个数据的torch.utils.data.Dataset类。这个类有两种,第一种是映射类型的,对于这个类型,每个数据有一个对应的索引,通过输入具体的索引,就能得到对应的数据,其构造方法如下所示:

class Dataset(object):
    def __getitm__(self, index):
        # index: 数据索引
        # ...
        # 返回数据张量

    def __len__(self):
        # 返回数据的数目
        # ...

对于这个类,主要需要重写两个方法,第一个方法是__geitem__,该方法是Python内置的操作符方法,对应的操作符是索引操作符[],通过输入整数数据索引,其大小在0至N-1之间,返回具体的某一条数据记录。另一个方法是__len__,该方法返回数据的总数,若是一个Dataset类重写了该方法可以通过使用len内置函数来获取数据的数目。

torchvision工具包的使用

一个简单torch.utils.data.Dataset类的实现如下:

class VisionDataset(data.Dataset):
    def __init__(self, root, transforms=None, transform=None, target_transform=None):
        # ...

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

class DatasetFolder(VisionDataset):
    def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
        super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)
        self.samples = make_dataset(self.root, class_to-idx, extensions, is_valid_file)
        self.loader = loader
        # ...

    def __getitem__(self, index):
        path, target = self.samples[inex]
        sample = self.loader(path)
        
        if self.transform is not None:
            sample = self.transform(sample)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target
    
    def __len__(self):
        
        return len(self.samples)

从DataFolder类开始看,该类的使用情形是数据集存储在一个目录下,每个目录有很多子目录,子目录的个数是图片类的数目,每个子目录下都存储有很多图片,且这些图片都属于一类。DataFolder类继承了VisionDataset类。在DataFolder类的构造函数中一开始调用了类内部的_find_classes来找到具体的预测目标的类别和类别对应的class_to_idx,得到包含所有数据记录的一个列表。这个列表记录着数据的路径和数据的预测目标。另外这个构造函数还传入了一个参数loader,用来载入数据。

__ getitem__这个方法会传入一个index,根据index从self.samples取得一条数据记录,得到数据记录的路径和预测目标,然后使用loader来对数据进行载入,并使用self.transform和self.target_transform对数据进行变换。最后返回变换以后的数据和预测目标。

torchvision包中有一些内置的转换函数,有一类主要作用于PyTorch张量。首先将张量转换为图片的类。其次在生成深度学习训练模型的时候,转换图片为张量以后,使用torchvision.transforms.Nomalize类标准化。这个类需要传入两个参数,第一个参数是所有图片的平均值张量,另一个是所有图片的标准差张量,输出的结果是输入图片张量减去平均值张量,然后除以标准差张量。最后,前文所述所有的转换类可以组成一个大的转换类,构造一个整体的包含所有列表按次序转换的转换类,这个类的调用效果是输出这些转换一次作用后的结果。

可迭代类型的数据集

可迭代类型的数据集相比于映射类型的数据集,不需要实现__getitem__方法和__len__方法,它本身更像一个Python迭代器。

对于不同的映射类型,因为索引之间相互独立,在使用多个进程载入数据的情况下,多个进程可以独立分配索引,迭代器在使用过程中,因为索引之间有先后顺序关系,需要考虑如何分割数据,使得不同的进程可以得到不同的数据。对这一类型的数据,可以根据不同工作进程的序号worker_id,设定不同数据迭代器的取值范围,保证不同的进程获取不同的迭代器,而且迭代器返回的数据各不相同。

总结

在进行深度学习的过程中,数据的输入和预处理十分重要。PyTorch提供的数据抽象类以及数据载入器的类,通过继承数据的抽象类,可以构造出针对某一个特殊数据的实例,然后输入数据载入器中,数据载入器可以自动对数据进行多进程处理,最后输出数据的张量供深度学习模型使用。

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch提供了许多方便的工具和类来进行数据读取和预处理。下面是一个基本的数据读取和预处理流程的示例: 1. 导入必要的库: ```python import torch from torchvision import transforms from torch.utils.data import DataLoader ``` 2. 定义数据集类: ```python class CustomDataset(torch.utils.data.Dataset): def __init__(self, data, targets, transform=None): self.data = data self.targets = targets self.transform = transform def __getitem__(self, index): x = self.data[index] y = self.targets[index] if self.transform: x = self.transform(x) return x, y def __len__(self): return len(self.data) ``` 在上面的代码,`CustomDataset` 是一个自定义的数据集类,其 `data` 是输入数据,`targets` 是对应的标签。`transform` 是一个可选的数据预处理函数。 3. 数据预处理: 可以使用 `torchvision.transforms` 的函数来对数据进行常见的预处理操作,例如缩放、裁剪、标准化等。下面是一个示例: ```python transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5]) ]) ``` 在上面的代码,我们使用 `transforms.Compose` 将多个预处理操作连接在一起。示例使用了 `ToTensor` 将数据转换为张量,然后使用 `Normalize` 进行标准化。 4. 创建数据集实例: 使用定义的数据集类和预处理操作,创建数据集实例: ```python dataset = CustomDataset(data, targets, transform=transform) ``` 其 `data` 和 `targets` 是输入数据和标签,`transform` 是之前定义的数据预处理操作。 5. 创建数据加载器: 使用 `torch.utils.data.DataLoader` 创建数据加载器,可以指定批次大小、是否打乱数据等参数。 ```python dataloader = DataLoader(dataset, batch_size=32, shuffle=True) ``` 6. 迭代数据集: 现在可以使用数据加载器来迭代数据集,并进行模型训练或评估。 ```python for batch_data, batch_targets in dataloader: # 在这里执行模型训练或评估的操作 pass ``` 以上是一个基本的PyTorch数据读取和预处理的示例。你可以根据实际需求进行调整和扩展。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值