pytorch中的数据输入和预处理

1. 数据载入类

pytorch数据载入使用的是torch.utils.data.DataLoader类,该类的的签名如下。

from torch.utils.data import DataLoader

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
           batch_sampler=None, num_workers=0, collate_fn=None, 
           pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

其中主要参数解释为,dataset是一个torch.utils.DataLoader类的实例,batch_size是迷你批次的大小,shuffle代表数据会不会被随机打乱,sampler是自定义的采样器(shuffle=True时会构造默认的采样器,如果想使用自定义采样方法,可以构造一个torch.utils.data.Sampler的实例来进行采样,并设置shuffle=False),其中采样器是一个Python迭代器,每次迭代的时候会返回一个数据的下标索引,而batch_sampler类似于sampler,但返回的是一个迷你批次的数据索引,而sampler返回的仅仅是一个下标索引。num_workers是数据载入使用的进程数目,默认为0,collate_fn定义如何把一批dataset的实例转换为包含迷你批次数据的张量。

2. 映射类型的数据集

为了能够使用DataLoader类,首先需要构造关于单个数据的torch.utils.data.Dataset类。该类的一种映射类型,对于每个类型,每个数据都有一个对应的索引。

from torch.utils.data import Dataset

class Dataset(object):
    def __getitem__(sell, index):
        # index:数据索引
        # ...
        # 返回数据张量
        
    def __len__(self):
        # 返回数据数目
        # ...

该类主要是重写两个方法,第一个是__getitem__。该方法是Python内置的操作符方法,对应的操作符是索引操作符[],通过输入整数数据索引(0到N-1之间),返回具体的一条数据记录。另一个是__len__,该方法返回数据的总数。

3. torchvision工具包的使用

以DatasetFolder类为例,如果数据集存储在一个目录下,每个目录有很多子目录,子目录的个数是图片类的数目,每个子类目都存储着多张图片,且这些图片都属于某一个类。DatasetFolder类继承了VisionDataset类,而Vision类的主要目的是存储数据集所在的根目录,以及训练数据和预测目标的变换函数(transform和target_transform)。该方法主要是__getitem__实现,该方法传入一个index,根据index从self.samples取得一条数据记录,得到数据记录的路径(path)和预测目标(target),然后使用loader来对数据进行载入,并使用self.transform和self.target_transform对数据进行转换,最后返回变换以后的数据和预测目标。

from torch.utils import data

class VisionDataset(data.Dataset):
    def __init__(self, root, transforms=Noen, transform=None, target_transform=None):
        # ...
    
    def __getitem__(self, index):
        raise NotImplementedError
        
    def __len__(self):
        raise NotImplementedError
        
# 重写VisionDataset类        
class DatasetFolder(VisionDataset):
    def __init__(self, root, loader, extensions=None, transform=None,
                target_transform=None, is_valid_file=None):
        super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)
        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        self.loader = loader
        # ...
        
    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            sample = self.target_transform(sample)
        return sample, target
    
    def __len__(self):
        return len(self.samples)

4. 可迭代类型的数据集

相比映射类型的数据集,可迭代类型数据集并不需要实现__getitem__方法或者__len__方法,它本身更像一个Python迭代器。
torch.utils.IterableDataset类的构造方法:

import torch

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
            assert end > start, \ 
    'this example code only works with end >= start'
            self.start = start
            self.end = end 
        
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # 单进程数据载入
            iter_start = self.start
            iter_end = self.end
        else:  # 多进程,分割数据
            per_worker = int(math.ceil((self.end = self.start) / float(worker_infor.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_endr))
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

饕餮&化骨龙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值