torch 的dataloader 的核心函数

官方解释:Dataloader 组合了 dataset & sampler,提供在数据上的 iterable

主要参数:

1、dataset:这个dataset一定要是torch.utils.data.Dataset本身或继承自它的类

里面最主要的方法是 getitem(self, index) 用于根据index索引来取数据的

2、batch_size:每个batch批次要返回几条数据

3、shuffle:是否打乱数据,默认False

4、sampler:sample strategy,数据选取策略,有它就不用shuffle了,因为sample本身就是一种无序。这个sampler貌似也一定要是torch.utils.data.sampler.Sampler本身或继承自它的类。

以下内容和代码都是基于torch的老版本,虽然老,但是其思想具有参考意义:

__ iter __(self) 方法(核心)

新版本1.10中该方法在dataloader类(iter方法)——》dataloaderiter类(next方法)——》batchsampler类(iter方法,源码和下面这个一样)
里面最主要的方法是__iter__(self) 方法,每次调用 iter 只能获取 batchsize 个数据,也就是一个批次的数据。

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

sampler参数

sampler:sample strategy,数据选取策略,有它就不用shuffle了,因为sample本身就是一种无序。这个sampler貌似也一定要是torch.utils.data.sampler.Sampler本身或继承自它的类。
常用格式:

trainloader = DataLoader(
            ImageDataset(self.dataset.train, transform=self.transform_train),
            # 为传入的数据中的每个id选择config.k个样本
            sampler=ClassUniformlySampler(self.dataset.train, class_position=1, k=config.k),  # 传入的数据中第2维是类别,所以class_position=1
            batch_size=config.p * config.k, num_workers=config.workers,
            # shuffle=True, # 有了ClassUniformlySampler就不用shuffle了
            pin_memory=pin_memory, drop_last=False
        )
 
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
    print("batch_idx: ", batch_idx)
    for i in range(len(pids)):
        print(pids[i], imgs[i].shape)
  • 为什么执行 enumerate 代码,就可以源源不断地返回所需数据??

在执行 trainloader = DataLoader()语句的时候,DataLoder,ImageDataset,ClassUniformlySampler 并没有什么特殊的操作,都仅仅是init初始化了一下。

这里所使用的 ClassUniformlySampler ,是Sampler类的一种,作用是对数据中的所有id仅保留k条数据。因此它在初始化时,生成了一个字典,key为类别,value为属于该类别的所有数据的索引。

一个自定义sampler类:

class ClassUniformlySampler(Sampler):
    '''
    功能:按类别标签随机抽样
    Arguments:
        data_source (Dataset): data_loader to sample from
        class_position (int): which one is used as class
        k (int): sample k images of each class
    '''
 
    def __init__(self, data_source, class_position, k):
 
        self.class_position = class_position
        self.k = k
 
        self.samples = data_source
        self.class_dict = self._tuple2dict(self.samples)  # 返回一个字典,key为类别,value为属于该类别的所有数据的索引
 
    def __iter__(self):
        self.sample_list = self._generate_list(self.class_dict)
        return iter(self.sample_list)
 
    def __len__(self):
        return len(self.sample_list)
 
    def _tuple2dict(self, inputs):
        '''
        :param inputs: list with tuple elemnts, [(image_path1, class_index_1), (imagespath_2, class_index_2), ...]
        :return: dict, {class_index_i: [samples_index1, samples_index2, ...]}
        '''
        dict = {}
        for index, each_input in enumerate(inputs):
            class_index = each_input[self.class_position]
            if class_index not in list(dict.keys()):
                dict[class_index] = [index]
            else:
                dict[class_index].append(index)
        return dict
 
    def _generate_list(self, dict):
        '''
        :param dict:  {class_index_i: [samples_index1, samples_index2, ...]}
        :return:[samples_index1, samples_index3, samples_index2, ...]
        '''
 
        sample_list = []
 
        dict_copy = dict.copy()
        keys = list(dict_copy.keys()) #[class_index_0,class_index_1,...]
        random.shuffle(keys)
        for key in keys:
            value = dict_copy[key] #[samples_index1, samples_index2, ...]
            if len(value) >= self.k:
                random.shuffle(value)
                sample_list.extend(value[0: self.k])
            else:
                value = value * self.k
                random.shuffle(value)
                sample_list.extend(value[0: self.k])
 
        return sample_list

在第一次执行 for batch_idx, (imgs, pids, _ ) in enumerate(trainloader) 时,首先调用的是sampler.__ iter __() 方法,对所有数据进行采样后返回一个存储了所采样的数据的索引列表,并用iter(sampler_list) 作为返回。iter方法在一开始已经提及,每次调用只能返回 batchsize 条数据。

随后,Dataset就上场了,它只需根据 sampler_list 中的索引挨个取数据即可,取到第 batchsize 条数据的时候,iter 就不会再让它取了。

这之后,每一次执行 for batch_idx, (imgs, pids, _) in enumerate(trainloader) 时,Dataset 都会从上一次iter中断的数据索引处继续取 batchsize 个数据,直到取完所有数据。

注:因为在采样时,已经打乱了原有的数据顺序,对于采样后返回的sample_list,即使按顺序取,也不是真的有序,而且这样还可以防止重复抽取到相同数据,数据取完就可以结束一个epoch。

默认的dataloader类

class DataLoader(object):
"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id as input, after seeding and before data
            loading. (default: None)
"""

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)

当代码运行到要从torch.utils.data.DataLoader类生成的对象中取数据的时候,比如:
train_data=torch.utils.data.DataLoader(…)
for i, (input, target) in enumerate(train_data):

就会调用DataLoader类的__iter__方法,__iter__方法就一行代码:return DataLoaderIter(self),输入正是DataLoader类的属性。因此当调用__iter__方法的时候就牵扯到另外一个类:DataLoaderIter,接下来介绍。

dataloaderiter类

源码(初始化部分):

前面部分都是一些赋值操作,比较特殊的是self.sample_iter = iter(self.batch_sampler),得到的self.sample_iter可以通过next(self.sample_iter)来获取batch size个数据的index。self.rcvd_idx表示读取到的一个batch数据的index,初始化为0,该值在迭代读取数据的时候会用到。
self.num_workers语句,如果设置为多进程读取数据,那么就会采用队列的方式来读,如果不是采用多进程来读取数据,那就采用普通方式来读。

class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queue = multiprocessing.SimpleQueue()
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
                          base_seed + i, self.worker_init_fn, i))
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          torch.cuda.current_device()))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

初始化结束后,就会调用__next__方法,接下来介绍:

next方法

DataLoaderIter类的__next__方法如下,包含3个if语句和1个while语句。

  • 第一个if语句是用来处理self.num_workers等于0的情况,也就是不采用多进程进行数据读取,可以看出在这个if语句中先通过indices = next(self.sample_iter)获取长度为batch size的列表:indices,这个列表的每个值表示一个batch中每个数据的index,每执行一次next操作都会读取一批长度为batch size的indices列表。然后通过self.collate_fn函数将batch size个tuple(每个tuple长度为2,其中第一个值是数据,Tensor类型,第二个值是标签,int类型)封装成一个list,这个list长度为2,两个值都是Tensor,一个是batch size个数据组成的FloatTensor,另一个是batch size个标签组成的LongTensor。所以简单讲self.collate_fn函数就是将batch size个分散的Tensor封装成一个Tensor。batch = pin_memory_batch(batch)中pin_memory_batch函数的作用就是将输入batch的每个Tensor都拷贝到CUDA中,该函数后面会详细介绍。
  • 第二个if语句是判断当前想要读取的batch的index(self.rcvd_idx)是否之前已经读出来过(已读出来的index和batch数据保存在self.reorder_dict字典中,可以结合最后的while语句一起看,因为self.reorder_dict字典的更新是在最后的while语句中),如果之前已经读取过了,就根据这个index从reorder_dict字典中弹出对应的数据。最后返回batch数据的时候是 return self._process_next_batch(batch),该方法后面会详细介绍。主要做是获取下一个batch的数据index信息。
  • 第三个if语句,self.batches_outstanding的值在前面初始中调用self._put_indices()方法时修改了,所以假设你的进程数self.num_workers设置为3,那么这里self.batches_outstanding就是3*2=6,可具体看self._put_indices()方法。
  • 最后的while循环就是真正用来从队列中读取数据的操作,最主要的就是idx, batch = self._get_batch(),通过调用_get_batch()方法来读取,后面有介绍,简单讲就是调用了队列的get方法得到下一个batch的数据,得到的batch一般是长度为2的列表,列表的两个值都是Tensor,分别表示数据(是一个batch的)和标签。_get_batch()方法除了返回batch数据外,还得到另一个输出:idx,这个输出表示batch的index,这个if idx != self.rcvd_idx条件语句表示如果你读取到的batch的index不等于当前想要的index:selg,rcvd_idx,那么就将读取到的数据保存在字典self.reorder_dict中:self.reorder_dict[idx] = batch,然后继续读取数据,直到读取到的数据的index等于self.rcvd_idx。
    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

_get_batch方法

DataloaderIter类的_get_batch方法。主要根据是否设置了超时时间来操作,如果超过指定的超时时间后没有从队列中读到数据就报错,如果不设置超时时间且一致没有从队列中读到数据,那么就会一直卡着且不报错,这部分是PyTorch后来修的一个bug。

    def _get_batch(self):
        if self.timeout > 0:
            try:
                return self.data_queue.get(True, self.timeout)
            except queue.Empty:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        else:
            return self.data_queue.get()

_process_next_batch方法

DataLoaderIter类的_process_next_batch方法。首先对self.rcvd_idx进行加一,也就是更新下下一个要读取的batch数据的index。然后调用_put_indices()方法获取下一个batch的每个数据的index。

   def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

_put_indices方法

DataLoaderIter类的_put_indices方法。该方法主要实现从self.sample_iter中读取下一个batch数据中每个数据的index:indices = next(self.sample_iter, None),注意这里的index和前面idx是不一样的,这里的index是一个batch中每个数据的index,idx是一个batch的index;然后将读取到的index通过调用queue对象的put方法压到队列self.index_queue中:self.index_queue.put((self.send_idx, indices))

def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

pin_memory_batch方法

pin_memory_batch函数不是定义在DataLoader类或DataLoaderIter类中。该函数主要是对batch中的Tensor执行batch.pin_memory()操作,这里的很多条件语句只是用来判断batch的类型,假如batch是一个列表,列表中的每个值是Tensor,那么就会执行 elif isinstance(batch, collections.Sequence):这个条件,从而遍历该列表中的每个Tensor,然后执行第一个条件语句的内容: return batch.pin_memory()

def pin_memory_batch(batch):
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, collections.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, collections.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch

参考:
https://blog.csdn.net/m0_37738114/article/details/120780544
https://blog.csdn.net/u014380165/article/details/79058479

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Shashank497

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值