pytorch的DataSet类数据加载自我实现

原理

我们之所以要自己写DataSet,是因为原本提供的DataSet无法满足我们的实际需求,此时就需要我们自定义。通过继承 torch.utils.data.Dataset来实现,在继承的时候,需要 override 三个方法:

  • __init__: 用来初始化一些有关操作数据集的参数
  • __getitem__:定义数据获取的方式(包括读取数据,对数据进行变换等),该方法支持从 0 到 len(self)-1的索引。obj[index]等价于obj.__getitem__
  • __len__:获取数据集的大小。len(obj)等价于obj.__len__()

dataset中应尽量只包含只读对象,避免修改任何可变对象。如下面例子中的self.num可能在多进程下出问题:

class BadDataset(Dataset):
	def __init__(self):
		self.datas = range(100)
		self.num = 0 # read data times
	def __getitem__(self, index):
		self.num += 1
		return self.datas[index]

自定义DataSet的框架:

class CustomDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        #这里需要注意的是,第一步:read one data,是一个data
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0

DataSet创建及使用的完整流程如下:

  • 首先创建好自己的DataSet类,然后创建一个Dataset的对象
  • 将创建好的DataSet对象传入DataLoader中,创建一个 DataLoader的对象
  • 遍历这个DataLoader对象,从而取出自己的数据。

代码示例如下:

dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for data in dataloader:
        ....
# 以下两个代码是等价的
for data in dataloader:
    ...
# 等价与
iterr = iter(dataloader)
while True:
    try:
        next(iterr)
    except StopIteration:
        break

 在DataLoader 中,iter(dataloader) 返回的是一个 DataLoaderIter 对象, 这个才是我们一直 next的 对象。

这个DataLoaderIter其实就是DataLoader类的__iter__()方法的返回值:

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 
                 batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, 
                 drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    # dataset.__len__() 在 Sampler 中被使用。
                    # 目的是生成一个 长度为 len(dataset) 的 序列索引(随机的)。
                    sampler = RandomSampler(dataset)
                else:
                    # dataset.__len__() 在 Sampler 中被使用。
                    # 目的是生成一个 长度为 len(dataset) 的 序列索引(顺序的)。
                    sampler = SequentialSampler(dataset)
            # Sampler 是个迭代器,一次之只返回一个 索引
            # BatchSampler 也是个迭代器,但是一次返回 batch_size 个 索引
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

DataLoader()的各个参数含义如下:

1.  dataset:加载的数据集,这个从DataSet()函数而来。
2.  batch_size:batch size,设定每次训练迭代时加载的数据量。
3.  shuffle::是否将数据打乱
4.  sampler: 样本抽样
5.  num_workers:使用多进程加载的进程数,0代表不使用多进程,设定多进程可以使得加载数据时更加快速。
6.  collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可

7.  pin_memory:是否将数据(tensor)保存在pin memory区,pin memory中的数据转到GPU中会快一些
8.  drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃,False表示不丢弃。

注意,这个DataLoaderIter中*init(self, loader)*中的loader就是对应的DataLoader类的实例。

class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        # loader 是 DataLoader 对象
        self.dataset = loader.dataset
        # 这个留在最后一个部分介绍
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        # 表示 开 几个进程。
        self.num_workers = loader.num_workers
        # 是否使用 pin_memory
        self.pin_memory = loader.pin_memory
        self.done_event = threading.Event()

        # 这样就可以用 next 操作 batch_sampler 了
        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            # 用来放置 batch_idx 的队列,其中元素的是 一个 list,其中放了一个 batch 内样本的索引
            self.index_queue = multiprocessing.SimpleQueue()
            # 用来放置 batch_data 的队列,里面的 元素的 一个 batch的 数据
            self.data_queue = multiprocessing.SimpleQueue()

            # 当前已经准备好的 batch 的数量(可能有些正在 准备中)
            # 当为 0 时, 说明, dataset 中已经没有剩余数据了。
            # 初始值为 0, 在 self._put_indices() 中 +1,在 self.__next__ 中减一
            self.batches_outstanding = 0 
            self.shutdown = False
            # 用来记录 这次要放到 index_queue 中 batch 的 idx
            self.send_idx = 0
            # 用来记录 这次要从的 data_queue 中取出 的 batch 的 idx
            self.rcvd_idx = 0
            # 因为多线程,可能会导致 data_queue 中的 batch 乱序
            # 用这个来保证 batch 的返回 是 idx 升序出去的。
            self.reorder_dict = {}
            # 这个地方就开始 开多进程了,一共开了 num_workers 个进程
            # 执行 _worker_loop , 下面将介绍 _worker_loop
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
                for _ in range(self.num_workers)]

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            if self.pin_memory:
                in_data = self.data_queue
                self.data_queue = queue.Queue()
                self.pin_thread = threading.Thread(
                    target=_pin_memory_loop,
                    args=(in_data, self.data_queue, self.done_event))
                self.pin_thread.daemon = True
                self.pin_thread.start()

            # prime the prefetch loop
            # 初始化的时候,就将 2*num_workers 个 (batch_idx, sampler_indices) 放到 index_queue 中。
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            # 说明没有 剩余 可操作数据了, 可以停止 worker 了
            self._shutdown_workers()
            raise StopIteration

        while True:
            # 这里的操作就是 给 乱序的 data_queue 排一排 序
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self.data_queue.get()
            # 一个 batch 被 返回,batches_outstanding -1
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            # 返回的时候,再向 indice_queue 中 放下一个 (batch_idx, sample_indices)
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        # 放下一个 (batch_idx, sample_indices)
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        if not self.shutdown:
            self.shutdown = True
            self.done_event.set()
            for _ in self.workers:
                # shutdown 的时候, 会将一个 None 放到 index_queue 中
                # 如果 _worker_loop 获得了这个 None, _worker_loop 将会跳出无限循环,将会结束运行
                self.index_queue.put(None)

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()

DatasetDataloaderDataLoaderIter是层层封装的关系,最终在内部使用DataLoaderIter进行迭代。

实现

本人使用自己的数据集实现的一个DataSet类,如下所图所示:

class TrainsetLoader(Dataset):
    def __init__(self, trainset_dir_hr, trainset_dir_lr,upscale_factor, patch_size, n_iters):
        super(TrainsetLoader).__init__()
        self.trainset_dir_hr = trainset_dir_hr
        self.trainset_dir_lr = trainset_dir_lr
        self.upscale_factor = upscale_factor
        self.patch_size = patch_size
        self.n_iters = n_iters
        self.video_list = os.listdir(trainset_dir_hr)
    def __getitem__(self, idx):
        idx_video = random.randint(0, self.video_list.__len__()-1)
        idx_frame = random.randint(1, 98)    #1-98之间的随机数,包含左右边界
        hr_dir = self.trainset_dir_hr + '/' + self.video_list[idx_video]
        lr_dir = self.trainset_dir_lr + '/' + self.video_list[idx_video]
        # read HR & LR frames
        LR0 = Image.open(lr_dir + '/' + str(idx_frame-1).rjust(8,'0') + '.png')
        LR1 = Image.open(lr_dir + '/' + str(idx_frame).rjust(8,'0') + '.png')
        LR2 = Image.open(lr_dir + '/' + str(idx_frame+1).rjust(8,'0') + '.png')
        HR0 = Image.open(hr_dir + '/' + str(idx_frame-1).rjust(8,'0')+ '.png')
        HR1 = Image.open(hr_dir + '/' + str(idx_frame).rjust(8,'0') + '.png')
        HR2 = Image.open(hr_dir + '/' + str(idx_frame+1).rjust(8,'0') + '.png')

        LR0 = np.array(LR0, dtype=np.float32) / 255.0
        LR1 = np.array(LR1, dtype=np.float32) / 255.0
        LR2 = np.array(LR2, dtype=np.float32) / 255.0
        HR0 = np.array(HR0, dtype=np.float32) / 255.0
        HR1 = np.array(HR1, dtype=np.float32) / 255.0
        HR2 = np.array(HR2, dtype=np.float32) / 255.0
        # extract Y channel for LR inputs
        HR0 = rgb2y(HR0)
        HR1 = rgb2y(HR1)
        HR2 = rgb2y(HR2)
        LR0 = rgb2y(LR0)
        LR1 = rgb2y(LR1)
        LR2 = rgb2y(LR2)
        # crop patchs randomly
        HR0, HR1, HR2, LR0, LR1, LR2 = random_crop(HR0, HR1, HR2, LR0, LR1, LR2, self.patch_size, self.upscale_factor)

        HR0 = HR0[:, :, np.newaxis]
        HR1 = HR1[:, :, np.newaxis]
        HR2 = HR2[:, :, np.newaxis]
        LR0 = LR0[:, :, np.newaxis]
        LR1 = LR1[:, :, np.newaxis]
        LR2 = LR2[:, :, np.newaxis]

        HR = np.concatenate((HR0, HR1, HR2), axis=2)
        LR = np.concatenate((LR0, LR1, LR2), axis=2)
        # data augmentation
        LR, HR = augumentation()(LR, HR)
        return toTensor(LR), toTensor(HR)
    def __len__(self):
        return self.n_iters

 

使用的方式如下:

train_set = TrainsetLoader(opt.trainset_dir_hr,opt.trainset_dir_lr, opt.upscale_factor, opt.patch_size, opt.n_iters)
train_loader = DataLoader(train_set, num_workers=4, batch_size=opt.batch_Size, shuffle=True)

for iteration,data in enumerate(train_loader,1):
    print('..................................')
    if cuda:
        batch_lr,batch_hr = Variable(data[0]).cuda(gpus_list[1]),Variable(data[1]).cuda(gpus_list[1])
    ......

 

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值