Official description: DataLoader combines a dataset and a sampler, and provides an iterable over the data.
Main arguments:
1. dataset: this must be torch.utils.data.Dataset itself or a class inheriting from it.
Its most important method is __getitem__(self, index), which fetches one sample by its index (see the sketch after this list).
2. batch_size: how many samples each batch returns.
3. shuffle: whether to shuffle the data; default False.
4. sampler: the sampling strategy. With a sampler there is no need for shuffle (and shuffle must stay False), because the sampler itself decides the order in which indices are drawn. It seems this sampler must likewise be torch.utils.data.sampler.Sampler itself or a class inheriting from it.
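As a concrete illustration of point 1, here is a minimal custom Dataset; this is a sketch, and SquaresDataset with its contents is made up for this example, not taken from any library:
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):          # hypothetical name, for illustration only
    """Maps index i to the pair (tensor([i]), i*i)."""
    def __init__(self, n=10):
        self.n = n

    def __getitem__(self, index):       # fetch one (sample, label) pair by index
        return torch.tensor([float(index)]), index * index

    def __len__(self):                  # DataLoader needs the dataset size
        return self.n

loader = DataLoader(SquaresDataset(), batch_size=4, shuffle=True)
for xs, ys in loader:
    print(xs.shape, ys)                 # e.g. torch.Size([4, 1]) tensor([ 9,  0,  4, 49])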
The following discussion and code are based on an old version of torch; although dated, the ideas remain instructive:
The __iter__(self) method (core)
In the newer version 1.10, this logic runs through the DataLoader class (its __iter__ method) -> the DataLoaderIter class (its __next__ method) -> the BatchSampler class (its __iter__ method, whose source is the same as below).
The key method is __iter__(self): each next() on the resulting iterator returns only batch_size indices, i.e., exactly one batch of data.
def __iter__(self):
    batch = []
    for idx in self.sampler:            # draw indices one by one from the sampler
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch                 # emit one full batch of indices
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch                     # emit the final, short batch unless drop_last
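The chunking is easy to verify with the public BatchSampler class; a small sketch, where the exact indices depend on the wrapped sampler:
from torch.utils.data.sampler import SequentialSampler, BatchSampler

batches = BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)
print(list(batches))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] -- the short tail survives because drop_last=False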
The sampler argument
Typical usage:
trainloader = DataLoader(
    ImageDataset(self.dataset.train, transform=self.transform_train),
    # pick config.k samples for every id in the input data;
    # the class label sits at position 1 of each tuple, hence class_position=1
    sampler=ClassUniformlySampler(self.dataset.train, class_position=1, k=config.k),
    batch_size=config.p * config.k, num_workers=config.workers,
    # shuffle=True,  # unnecessary once ClassUniformlySampler is used
    pin_memory=pin_memory, drop_last=False
)
for batch_idx, (imgs, pids, _) in enumerate(trainloader):
    print("batch_idx: ", batch_idx)
    for i in range(len(pids)):
        print(pids[i], imgs[i].shape)
- Why does simply running the enumerate loop keep returning the data we need?
When trainloader = DataLoader(...) executes, nothing special happens inside DataLoader, ImageDataset, or ClassUniformlySampler; each merely runs its __init__.
The ClassUniformlySampler used here is one kind of Sampler; its job is to keep only k samples for every id in the data. At construction time it therefore builds a dictionary whose keys are class labels and whose values are the indices of all samples belonging to that class.
A custom Sampler class:
import random
from torch.utils.data.sampler import Sampler

class ClassUniformlySampler(Sampler):
    '''
    Purpose: randomly sample by class label.
    Arguments:
        data_source (Dataset): data to sample from
        class_position (int): which tuple position holds the class label
        k (int): sample k images of each class
    '''
    def __init__(self, data_source, class_position, k):
        self.class_position = class_position
        self.k = k
        self.samples = data_source
        # dict whose keys are class labels and whose values are the indices
        # of all samples belonging to that class
        self.class_dict = self._tuple2dict(self.samples)

    def __iter__(self):
        self.sample_list = self._generate_list(self.class_dict)
        return iter(self.sample_list)

    def __len__(self):
        # every class contributes exactly k indices
        return len(self.class_dict) * self.k

    def _tuple2dict(self, inputs):
        '''
        :param inputs: list of tuples, [(image_path_1, class_index_1), (image_path_2, class_index_2), ...]
        :return: dict, {class_index_i: [sample_index_1, sample_index_2, ...]}
        '''
        class_dict = {}  # renamed to avoid shadowing the builtin dict
        for index, each_input in enumerate(inputs):
            class_index = each_input[self.class_position]
            if class_index not in class_dict:
                class_dict[class_index] = [index]
            else:
                class_dict[class_index].append(index)
        return class_dict

    def _generate_list(self, class_dict):
        '''
        :param class_dict: {class_index_i: [sample_index_1, sample_index_2, ...]}
        :return: [sample_index_1, sample_index_3, sample_index_2, ...]
        '''
        sample_list = []
        keys = list(class_dict.keys())   # [class_index_0, class_index_1, ...]
        random.shuffle(keys)             # visit the classes in random order
        for key in keys:
            value = class_dict[key][:]   # copy, so the stored index lists stay intact
            if len(value) < self.k:
                value = value * self.k   # oversample classes with fewer than k instances
            random.shuffle(value)
            sample_list.extend(value[0:self.k])
        return sample_list
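A quick usage sketch; the fake (path, label) tuples below are invented for illustration:
data = [('img_%d.jpg' % i, i % 3) for i in range(12)]    # 3 classes, 4 samples each
sampler = ClassUniformlySampler(data, class_position=1, k=2)
idx_list = list(iter(sampler))
print(len(idx_list), idx_list)   # 6 indices: k=2 per class, classes in shuffled order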
The first time for batch_idx, (imgs, pids, _) in enumerate(trainloader) runs, sampler.__iter__() is invoked first: it samples from all the data, builds a list of the chosen indices, and returns iter(sample_list). As mentioned at the start, each next() on the batch iterator then returns only batch_size of those indices.
Then the Dataset steps in: it simply fetches samples one by one according to the indices in sample_list, and once the batch_size-th sample has been fetched, the iterator stops it there.
After that, every further step of the for loop resumes fetching batch_size samples from wherever the iterator last stopped, until all the indices are consumed; iterating over trainloader again calls __iter__ afresh and starts a new epoch.
Note: because sampling already shuffled the original order, walking sample_list sequentially is not really "in order"; it also prevents drawing the same sample twice, and the epoch can end naturally once all the data has been taken.
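The resume-where-we-stopped behavior is just the Python iterator protocol; a toy sketch with a plain DataLoader (a Python list works as a map-style dataset here):
from torch.utils.data import DataLoader

loader = DataLoader(list(range(8)), batch_size=3)
it = iter(loader)       # what the for loop does first; consults the sampler
print(next(it))         # tensor([0, 1, 2])
print(next(it))         # tensor([3, 4, 5]) -- continues where the iterator stopped
print(next(it))         # tensor([6, 7])    -- short tail, since drop_last=False
# a fresh `for ... in loader` calls __iter__ again and starts a new epoch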
The default DataLoader class
class DataLoader(object):
"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.
Arguments:
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: 1).
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: False).
sampler (Sampler, optional): defines the strategy to draw samples from
the dataset. If specified, ``shuffle`` must be False.
batch_sampler (Sampler, optional): like sampler, but returns a batch of
indices at a time. Mutually exclusive with batch_size, shuffle,
sampler, and drop_last.
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means that the data will be loaded in the main process.
(default: 0)
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
into CUDA pinned memory before returning them.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: False)
timeout (numeric, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional): If not None, this will be called on each
worker subprocess with the worker id as input, after seeding and before data
loading. (default: None)
"""
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
self.dataset = dataset
self.batch_size = batch_size
self.num_workers = num_workers
self.collate_fn = collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
self.timeout = timeout
self.worker_init_fn = worker_init_fn
if timeout < 0:
raise ValueError('timeout option should be non-negative')
if batch_sampler is not None:
if batch_size > 1 or shuffle or sampler is not None or drop_last:
raise ValueError('batch_sampler is mutually exclusive with '
'batch_size, shuffle, sampler, and drop_last')
if sampler is not None and shuffle:
raise ValueError('sampler is mutually exclusive with shuffle')
if self.num_workers < 0:
raise ValueError('num_workers cannot be negative; '
'use num_workers=0 to disable multiprocessing.')
if batch_sampler is None:
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
def __iter__(self):
return DataLoaderIter(self)
def __len__(self):
return len(self.batch_sampler)
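Given the default wiring above, the following loaders end up equivalent; a sketch using the public API:
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, BatchSampler

dataset = list(range(10))
# shuffle=True is shorthand for sampler=RandomSampler(dataset)
a = DataLoader(dataset, batch_size=4, shuffle=True)
b = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset))
# and both reduce to a BatchSampler wrapping that sampler
c = DataLoader(dataset, batch_sampler=BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False))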
When the code starts pulling data from an object built from the torch.utils.data.DataLoader class, for example:
train_data = torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
    ...
the DataLoader's __iter__ method is invoked. __iter__ is a single line, return DataLoaderIter(self), whose argument carries exactly the DataLoader's attributes. Calling __iter__ therefore brings in another class, DataLoaderIter, introduced next.
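In other words, the for loop desugars to the iterator protocol; a sketch, with variable names following the snippet above:
it = iter(train_data)              # -> DataLoader.__iter__ -> DataLoaderIter(self)
while True:
    try:
        input, target = next(it)   # DataLoaderIter.__next__ returns one batch
    except StopIteration:          # raised once the sampler is exhausted
        break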
The DataLoaderIter class
Source (initialization part):
The early lines are plain attribute assignments. The notable one is self.sample_iter = iter(self.batch_sampler): calling next(self.sample_iter) on the result yields the indices for one batch (batch_size of them). self.rcvd_idx is the index of the next batch to be received, initialized to 0; it is used while iterating over the data.
The self.num_workers branch decides the reading mode: with multiple worker processes, data is read through queues; otherwise it is read directly in the main process.
class DataLoaderIter(object):
"Iterates once over the DataLoader's dataset, as specified by the sampler"
def __init__(self, loader):
self.dataset = loader.dataset
self.collate_fn = loader.collate_fn
self.batch_sampler = loader.batch_sampler
self.num_workers = loader.num_workers
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
self.timeout = loader.timeout
self.done_event = threading.Event()
self.sample_iter = iter(self.batch_sampler)
if self.num_workers > 0:
self.worker_init_fn = loader.worker_init_fn
self.index_queue = multiprocessing.SimpleQueue()
self.worker_result_queue = multiprocessing.SimpleQueue()
self.batches_outstanding = 0
self.worker_pids_set = False
self.shutdown = False
self.send_idx = 0
self.rcvd_idx = 0
self.reorder_dict = {}
base_seed = torch.LongTensor(1).random_()[0]
self.workers = [
multiprocessing.Process(
target=_worker_loop,
args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
base_seed + i, self.worker_init_fn, i))
for i in range(self.num_workers)]
if self.pin_memory or self.timeout > 0:
self.data_queue = queue.Queue()
self.worker_manager_thread = threading.Thread(
target=_worker_manager_loop,
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
torch.cuda.current_device()))
self.worker_manager_thread.daemon = True
self.worker_manager_thread.start()
else:
self.data_queue = self.worker_result_queue
for w in self.workers:
w.daemon = True # ensure that the worker exits on process exit
w.start()
_update_worker_pids(id(self), tuple(w.pid for w in self.workers))
_set_SIGCHLD_handler()
self.worker_pids_set = True
# prime the prefetch loop
for _ in range(2 * self.num_workers):
self._put_indices()
Once initialization is done, the __next__ method gets called, described next:
The __next__ method
DataLoaderIter's __next__ method, shown below, contains three if statements and one while loop.
- The first if handles self.num_workers == 0, i.e., no multiprocessing for data reading. Here indices = next(self.sample_iter) fetches a list of length batch_size whose entries are the indices of the samples in one batch; each next() reads one such list. Then self.collate_fn merges the batch_size tuples (each of length 2: the data as a Tensor, the label as an int) into one list of length 2, both entries Tensors: a FloatTensor stacking the batch_size data items and a LongTensor stacking the labels. Put simply, self.collate_fn packs batch_size scattered Tensors into one Tensor. In batch = pin_memory_batch(batch), the pin_memory_batch function copies every Tensor in the batch into pinned (page-locked) memory, which speeds up later transfers to CUDA; the function is detailed later.
- The second if checks whether the batch we currently want (index self.rcvd_idx) has already been read out earlier (already-read indices and batches are stored in the self.reorder_dict dictionary; read this together with the final while loop, since reorder_dict is updated there). If so, that batch is popped from reorder_dict by its index. The batch is returned via return self._process_next_batch(batch), a method detailed later; its main job is to enqueue the index information for the next batch.
- The third if: self.batches_outstanding was set during initialization by the calls to self._put_indices(), so with self.num_workers set to 3 worker processes it starts at 3 * 2 = 6; see the _put_indices() method for details. Once it drops to 0, no more batches are coming, so the workers are shut down and StopIteration is raised.
- The final while loop does the actual reading from the queue. The core call is idx, batch = self._get_batch(), described later; in short it calls the queue's get method to obtain the next batch, usually a list of length 2 whose entries are Tensors for the data (one whole batch) and the labels. Besides the batch itself, _get_batch() also returns idx, the index of that batch. The if idx != self.rcvd_idx test means: if the batch just read is not the one currently wanted (self.rcvd_idx), stash it in the dictionary via self.reorder_dict[idx] = batch and keep reading until the batch whose index equals self.rcvd_idx arrives.
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch)
if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
while True:
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch = self._get_batch()
self.batches_outstanding -= 1
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch
continue
return self._process_next_batch(batch)
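To make the reordering concrete, here is a toy simulation of the reorder_dict bookkeeping; pure Python with no workers, and the arrival order is invented:
arrivals = [(1, 'batch1'), (0, 'batch0'), (2, 'batch2')]   # (idx, batch) off the queue
reorder_dict, rcvd_idx = {}, 0
for idx, batch in arrivals:
    reorder_dict[idx] = batch          # stash whatever arrives
    while rcvd_idx in reorder_dict:    # deliver as soon as the wanted idx is present
        print('deliver', reorder_dict.pop(rcvd_idx))
        rcvd_idx += 1
# prints batch0, batch1, batch2 -- in order, despite out-of-order arrival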
The _get_batch method
DataLoaderIter's _get_batch method. Its behavior depends on whether a timeout was set: if the timeout elapses without reading anything from the queue, it raises an error; with no timeout set, a call that never receives data from the queue would block forever without raising, which was a bug PyTorch fixed later.
def _get_batch(self):
if self.timeout > 0:
try:
return self.data_queue.get(True, self.timeout)
except queue.Empty:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
else:
return self.data_queue.get()
The _process_next_batch method
DataLoaderIter's _process_next_batch method. It first increments self.rcvd_idx, i.e., updates the index of the next batch to read, then calls _put_indices() to enqueue the per-sample indices for a subsequent batch.
def _process_next_batch(self, batch):
self.rcvd_idx += 1
self._put_indices()
if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch
The _put_indices method
DataLoaderIter's _put_indices method. It reads the per-sample indices of the next batch from self.sample_iter via indices = next(self.sample_iter, None). Note that these indices differ from the earlier idx: indices identify each sample within a batch, whereas idx identifies the batch itself. The indices are then pushed onto the queue self.index_queue with the queue's put method: self.index_queue.put((self.send_idx, indices)).
def _put_indices(self):
assert self.batches_outstanding < 2 * self.num_workers
indices = next(self.sample_iter, None)
if indices is None:
return
self.index_queue.put((self.send_idx, indices))
self.batches_outstanding += 1
self.send_idx += 1
The pin_memory_batch function
pin_memory_batch is not defined inside the DataLoader or DataLoaderIter class. It applies batch.pin_memory() to every Tensor inside batch; the cascade of conditionals merely dispatches on the type of batch. For example, if batch is a list whose elements are Tensors, the elif isinstance(batch, collections.Sequence): branch fires, the list is traversed, and each Tensor ends up in the first branch: return batch.pin_memory().
def pin_memory_batch(batch):
if torch.is_tensor(batch):
return batch.pin_memory()
elif isinstance(batch, string_classes):
return batch
elif isinstance(batch, collections.Mapping):
return {k: pin_memory_batch(sample) for k, sample in batch.items()}
elif isinstance(batch, collections.Sequence):
return [pin_memory_batch(sample) for sample in batch]
else:
return batch
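A usage sketch; it assumes a CUDA-capable machine (pinning page-locks host memory for fast GPU transfer) and an environment where the old collections.Mapping/Sequence names above still resolve:
import torch

batch = [torch.randn(4, 3), torch.tensor([0, 1, 2, 3])]
batch = pin_memory_batch(batch)            # hits the Sequence branch, then pins each Tensor
print([t.is_pinned() for t in batch])      # [True, True]
# afterwards, t.cuda(non_blocking=True) can overlap the copy with computation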