DataLoader与Dataset

\quad \quad

\quad \quad 在处理任何机器学习问题之前都需要数据读取,并进行预处理。Pytorch提供了许多方法使得数据读取和预处理变得很容易。

1.以“人民币二分类”为例 \quad

\quad \quad 训练一套机器模型实现对第四套人民币中100元和1元的二分类。 \quad

先回顾机器学习5大步骤: \quad

其中数据部分又包括: \quad

∙ \bullet 数据收集:样本和标签。 \quad

∙ \bullet 数据划分:训练集、验证集和测试集. \quad

∙ \bullet 数据读取:对应于 PyTorch 的 DataLoader。其中 DataLoader 包括 Sampler 和 DataSet。Sampler 的功能是生成索引, DataSet 是根据生成的索引读取样本以及标签。 \quad

∙ \bullet 数据预处理:对应于 PyTorch 的 transforms。 \quad

接下来主要学习DataLoader and Dataset \quad

\quad \quad torch.utils.data.Dataset()是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len__()__getitem__()这两个方法就可以。
\quad \quad 通过继承torch.utils.data.Dataset()的这个抽象类,我们可以自定义好我们需要的数据类。当我们通过迭代的方式来取得每一个数据,但是这样很难实现取batch,shuffle或者多线程读取数据,所以Pytorch还提供了一个简单的方法来做这件事情,通过torch.utils.data.DataLoader()类来定义一个新的迭代器,用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
\quad \quad 总之,通过torch.utils.data.Dataset()torch.utils.data.DataLoader()这两个类,使数据的读取变得非常简单,快捷。

1.1 DataLoader(torch.utils.data.DataLoader) \quad

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)

功能:构建可迭代的数据装载器 \quad

dataset: Dataset类,决定数据从哪读取及如何读取 \quad

batchsize : 批大小 \quad

num_works: 是否多进程读取数据 \quad

shuffle : 每个epoch是否乱序 \quad

drop_last :当样本数不能被batchsize整除时,是否舍弃最后一批数据 \quad

理解drop_last 需要知道: \quad

Epoch : 所有训练样本都已输入到模型中,称为一个Epoch; \quad

Iteration :一批样本输入到模型中,称之为一个Iteration; \quad

Batchsize :批大小,决定一个Epoch有多少个Iteration; \quad

样本总数:80, Batchsize:8 \quad

1 Epoch = 10 Iteration \quad

样本总数:87, Batchsize:8 \quad

1 Epoch = 10 Iteration–drop_last = True(丢弃最后7个样本,只有10个Iteration) \quad

1 Epoch = 11 Iteration–drop_last = False(不丢弃最后7个样本,有11个Iteration) \quad

1.2Dataset(torch.utils.data.Dataset) \quad
class Dataset(object):
  def __getitem__(self, index):
     raise NotImplementedError
  def __add__(self, other):
     return ConcatDataset([self, other])

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem()__ 方法和__len__()方法 。 \quad

__getitem__() \quad
getitem :接收一个索引,返回索引对应的样本和标签,这是我们自己需要实现的逻辑。 \quad

__len__()
len:返回所有样本的数量。 \quad

数据读取包含3个方面: \quad

∙ \bullet 读取哪些数据 :每个 Iteration 读取一个 Batchsize 大小的数据,每个 Iteration 应该读取哪些数据。 \quad

∙ \bullet 从哪里读取数据 :如何找到硬盘中的数据,应该在哪里设置文件路径参数。 \quad

∙ \bullet 如何读取数据 :不同的文件需要使用不同的读取方法和库。 \quad

\quad \quad 组织文件路径结构如下,有两类人民币图片:1 元和 100 元,每一类各有 100 张图片。 \quad

在这里插入图片描述

\quad \quad 首先划分数据集为训练集、验证集和测试集,比例为 8:1:1。 \quad

\quad \quad 数据划分好后的路径构造如下: \quad
在这里插入图片描述

\quad \quad 实现读取数据的Dataset,编写一个get_img_info()方法,读取每一个图片的路径和对应的标签,组成一个元组,再把所有的元组作为list存放到self.data_info变量中,这里需要注意的是标签需要映射到0 开始的整数:rmb_label = {"1": 0, "100": 1}

   @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        # data_dir 是训练集、验证集或者测试集的路径
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            # dirs ['1', '100']
            for sub_dir in dirs:
                # 文件列表
                img_names = os.listdir(os.path.join(root, sub_dir))
                # 取出 jpg 结尾的文件
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    # 图片的绝对路径
                    path_img = os.path.join(root, sub_dir, img_name)
                    # 标签,这里需要映射为 0、1 两个类别
                    label = rmb_label[sub_dir]
                    # 保存在 data_info 变量中
                    data_info.append((path_img, int(label)))
        return data_info

\quad \quad 然后在Dataset 的初始化函数中调用get_img_info() 方法。

def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

\quad \quad 然后在__getitem__() 方法中根据index 读取self.data_info中路径对应的数据,并在这里做 transform 操作,返回的是样本和标签。

 def __getitem__(self, index):
        # 通过 index 读取样本
        path_img, label = self.data_info[index]
        # 注意这里需要 convert('RGB')
        img = Image.open(path_img).convert('RGB')     # 0~255
        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等
        # 返回是样本和标签
        return img, label

\quad \quad __len__() 方法中返回self.data_info的长度,即为所有样本的数量。

# 返回所有样本的数量
    def __len__(self):
        return len(self.data_info)

在train_lenet.py中,分 5 步构建模型

1.设置数据: \quad

\quad \quad 首先定义训练集、验证集、测试集的路径,定义训练集和测试集的transforms。然后构建训练集和验证集的RMBDataset对象,把对应的路径和transforms传进去。再构建DataLoder,设置 batch_size,其中训练集设置shuffle=True,表示每个 Epoch 都打乱样本。

# 构建MyDataset实例train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
                 #valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
# 构建DataLoder
# 其中训练集设置 shuffle=True,表示每个 Epoch 都打乱样本
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
2.构建模型

\quad \quad 这里采用经典的 **LeNet **图片分类网络。

net = LeNet(classes=2)  #二分类
net.initialize_weights()
3.设置损失函数

\quad \quad 这里使用交叉熵损失函数。

criterion = nn.CrossEntropyLoss() #交叉熵损失函数
4.设置优化器

\quad \quad 这里采用 SGD (随机梯度下降)优化器。

optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # 设置学习率下降策略
5.迭代训练模型 \quad

\quad \quad 在每一个 epoch 里面,需要遍历 train_loader 取出数据,每次取得数据是一个 batchsize 大小。这里又分为 4 步。第 1 步进行前向传播,第 2 步进行反向传播求导,第 3 步使用optimizer更新权重,第 4 步统计训练情况。每一个 epoch 完成时都需要使用scheduler更新学习率,和计算验证集的准确率、loss。

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()#训练前加net.train(),测试前加net.val(),一般出现在有BN和Dropout的网络中
    
    # 遍历 train_loader 取数据
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率
    # 每个 epoch 计算验证集得准确率和loss
    ...
    ...

\quad \quad 我们可以看到每个 iteration,我们是从train_loader中取出数据的。

def __iter__(self):
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)  #单任务处理数据加载
    else:
        return _MultiProcessingDataLoaderIter(self) #多任务并行处理数据加载

\quad \quad 这里我们没有设置多进程,会执行_SingleProcessDataLoaderIter()的方法。我们以_SingleProcessDataLoaderIter()为例。在_SingleProcessDataLoaderIter()里只有一个方法_next_data(),如下:

def _next_data(self):
 index = self._next_index()  # may raise StopIteration
 data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
 if self._pin_memory:
  data = _utils.pin_memory.pin_memory(data)
 return data

\quad \quad 在该方法中,self._next_index()是获取一个 batchsize 大小的 index 列表,代码如下:

def _next_index(self):
    return next(self._sampler_iter)  # may raise StopIteration

\quad \quad 其中调用的sampler类的__iter__()方法返回 batch_size 大小的随机 index 列表。

def __iter__(self):
 batch = []
 for idx in self.sampler:
  batch.append(idx)
  if len(batch) == self.batch_size:
   yield batch
   batch = []
 if len(batch) > 0 and not self.drop_last:
  yield batch

\quad \quad 然后再返回看 dataloader的_next_data()方法:

def _next_data(self):
 index = self._next_index()  # may raise StopIteration
 data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
 if self._pin_memory:
  data = _utils.pin_memory.pin_memory(data)
 return data

\quad \quad 在第二行中调用了self._dataset_fetcher.fetch(index)获取数据。这里会调用_MapDatasetFetcher中的fetch()函数:

def fetch(self, possibly_batched_index):
 if self.auto_collation:
  data = [self.dataset[idx] for idx in possibly_batched_index]
 else:
  data = self.dataset[possibly_batched_index]
 return self.collate_fn(data)

\quad \quad 这里调用了self.dataset[idx],这个函数会调用dataset.__getitem__()方法获取具体的数据,所以__getitem__()方法是我们必须实现的。我们拿到的data是一个 list,每个元素是一个 tunple,每个 tunple 包括样本和标签。所以最后要使用self.collate_fn(data)把 data 转换为两个 list,第一个 元素 是样本的 batch 形式,形状为 [16, 3, 32, 32] (16 是 batch size,[3, 32, 32] 是图片像素);第二个元素是标签的 batch 形式,形状为 [16]。 \quad

\quad \quad 所以在代码中,我们使用inputs, labels = data来接收数据。 \quad

PyTorch 数据读取流程图: \quad

\quad \quad 首先在 for 循环中遍历DataLoader,然后根据是否采用多进程,决定使用单进程或者多进程的DataLoaderIter。在DataLoaderIter里调用Sampler生成Index的 list,再调用DatasetFetcher根据index获取数据。在DatasetFetcher里会调用Dataset__getitem__()方法获取真正的数据。这里获取的数据是一个 list,其中每个元素是 (img, label) 的元组,再使用 collate_fn()函数整理成一个 list,里面包含两个元素,分别是 img 和 label 的tenser

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

bobodareng

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值