一、DataLoader()类
torch.utils.data.DataLoader():构建可迭代的数据装载器.
class DataLoader(object):
def __init__(self,
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0, collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None)
常用:
DataLoader( dataset, => Dataset类,决定数据从哪儿读取及如何读取.
batch_size=1, => 批大小.
shuffle=False, => 每个epoch是否乱序.
num_workers=0, => 是否多进程读取数据.
drop_last=False, => 当Epoch/BatchSize不为整数时,是否丢弃最后一批Iteration数据.
补充:
Epoch:全部训练样本。
Iteration:一批样本。
BatchSize:一批样本的大小。
有:Iteration = Epoch/BatchSize + 0/1(取决于drop_last的设定)
二、Dataset()类
torch.utils.data.Dataset():Dataset抽象类,子类必须复写 getitem()函数.
class DataLoader(object):
def __getitem__(self, index): # 接受索引,返回样本
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
使用举例:
class RMBDataset(Dataset): # 子类RMBDataset继承父类Dataset
def __init__(self, data_dir, transform=None): # 复写构造函数,此处传入由transforms.Compose()函数返回的,包含图片变换操作的列表list()。
self.label_name = {'1': 0, '100': 1}
self.data_info = self.get_img_info(data_dir) # data_info为自定义的函数
self.transform = transform
def __getitem__(self, index):
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img, label