Pytorch基本操作(2)——Dataset以及Dataloader

1 简介

在学习李沐在B站发布的《动手学深度学习》PyTorch版本教学视频中发现在操作使用PyTorch方面有许多地方看不懂,往往只是“动手”了,没有动脑。所以打算趁着寒假的时间好好恶补、整理一下PyTorch的操作,以便跟上课程。

学习资源:

2 Dataset以及Dataloader

Dataset 以及 Dataloader 是Pytorch中读取数据需要用到的两个重要的类

  • Dataset :提供一种方式去获取数据及其lable,需要我们自己去写。
  • Dataloader :为后面的网络提供不同的数据形式

常见的图片数据集有两种形式:

  • label直接标在文件夹上
  • label另外放在另一个文件夹对应的txt文件中(OCR)
  • label写在图片的名称上

2.1 Dataset

功能:

  1. 如何获取每一个数据及其label。
  2. 告诉我们总共有多少的数据。

我们可以运行 from torch.utils.data import Dataset (其中 utils 有实用工具的意思,理解为工具区)来导入Dataset这个类,同时也可以使用 Dataset??help(Dataset) 来看如何使用Dataset

import torch
from torch.utils.data import Dataset
Dataset??
help(Dataset)
Help on class Dataset in module torch.utils.data.dataset:

class Dataset(typing.Generic)
 |  An abstract class representing a :class:`Dataset`.
 |  
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:`__len__`, which is expected to return the size of the dataset by many
 |  :class:`~torch.utils.data.Sampler` implementations and the default options
 |  of :class:`~torch.utils.data.DataLoader`.
 |  
 |  .. note::
 |    :class:`~torch.utils.data.DataLoader` by default constructs a index
 |    sampler that yields integral indices.  To make it work with a map-style
 |    dataset with non-integral indices/keys, a custom sampler must be provided.
 |  
 |  Method resolution order:
 |      Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'
 |  
 |  __getitem__(self, index) -> +T_co
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __orig_bases__ = (typing.Generic[+T_co],)
 |  
 |  __parameters__ = (+T_co,)
 |  
 |  ----------------------------------------------------------------------
 |  Class methods inherited from typing.Generic:
 |  
 |  __class_getitem__(params) from builtins.type
 |  
 |  __init_subclass__(*args, **kwargs) from builtins.type
 |      This method is called when a class is subclassed.
 |      
 |      The default implementation does nothing. It may be
 |      overridden to extend subclasses.

简言之:Dataset是个抽象类,所有的子类都要重写__getitem__方法获取label,也可以重写__len__方法获取长度

2.2 Dataset类代码实战

其中用到的两个库:

  1. PIL中的Image
    PIL(Python Imaging Library):是Python的图像处理库

    • img = Image.open(image_path) 读取对应路径的图片为一个变量
    • img.show() 使用系统默认图片打开方式打开此图片
  2. os:operating system

    • os.path.join(root_dir, label_dir):将两个路径连起来,这个函数会根据操作系统自动调整路径的语法;其中dir = directory目录
    • os.listdir(dir_path):顾名思义,就是把括号里面的目录路径中的“所有文件的路径”生成一个列表,可以用类似a[0]的语句取出 对应图片的路径
from torch.utils.data import Dataset
from PIL import Image
import os

2.2.1 创建类

图片数据集格式为:label直接标在文件夹上

class MyData(Dataset):
    
    def __init__(self, root_dir, label_dir):
        # 初始化函数,为后面的getitem和next方法提供所需要的量
        self.root_dir = root_dir # 根目录的路径;self可以理解为当前类内部的一个全局变量
        self.label_dir = label_dir # label目录的路径,因为下一行要合起来,并且label名是文件夹名,所以这里的label_dir可以直接取对应的label,如:"ants"
        self.path = os.path.join(root_dir, label_dir) # 将两个路径连起来
        self.img_path = os.listdir(self.path) # 获取该路径下所有文件的路径列表
        
    def __getitem__(self, idx):
        """获取数据集中的每一个图片,输入索引,得到对应的图片"""
        # idx是index索引的缩写
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        image = Image.open(img_item_path)
        label = self.label_dir
        return image, label
    
    def __len__(self):
        return len(self.img_path)

2.2.2 创建个实例看看

对应的蜜蜂蚂蚁图片识别数据集见小土堆B站视频简介

root_dir = r"F:\Data and code\data\蚂蚁蜜蜂数据\hymenoptera_data\hymenoptera_data\train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

取一个试试看到底能不能取出来单独数据

img, label = ants_dataset[0]
# img.show()
img

在这里插入图片描述

2.2.3 将数据集格式转化为txt存放label格式

"""下面的代码是将本节中所提到的格式转化成使用txt文档存放label的格式"""
"""这种格式就是新建一个ant_label文件夹,其中放的都是.txt文件。每一个文件的名字都是对应图片的名字,文件的内容则是对应的label"""
'''说实话还没看'''

root_dir = r"F:\Data and code\data\蚂蚁蜜蜂数据\hymenoptera_data\hymenoptera_data\train"
target_dir = "ants_image"
img_path_list = os.listdir(os.path.join(root_dir, target_dir))
label = target_dir.split('_')[0]
out_dir = "ants_label"
for i in img_path_list:
    file_name = i.split('.jpg')[0]
    with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
        f.write(label)
PyTorch中,数据读取是构建深度学习模型的重要一环。为了高效处理大规模数据集,PyTorch提供了三个主要的工具:DatasetDataLoader和TensorDatasetDataset是一个抽象类,用于自定义数据集。我们可以继承Dataset类,并重写其中的__len__和__getitem__方法来实现自己的数据加载逻辑。__len__方法返回数据集的大小,而__getitem__方法根据给定的索引返回样本和对应的标签。通过自定义Dataset类,我们可以灵活地处理各种类型的数据集。 DataLoader是数据加载器,用于对数据集进行批量加载。它接收一个Dataset对象作为输入,并可以定义一些参数例如批量大小、是否乱序等。DataLoader能够自动将数据集划分为小批次,将数据转换为Tensor形式,然后通过迭代器的方式供模型训练使用。DataLoader在数据准备和模型训练的过程中起到了桥梁作用。 TensorDataset是一个继承自Dataset的类,在构造时将输入数据和目标数据封装成Tensor。通过TensorDataset,我们可以方便地处理Tensor格式的数据集。TensorDataset可以将多个Tensor按行对齐,即将第i个样本从各个Tensor中取出,构成一个新的Tensor作为数据集的一部分。这对于处理多输入或者多标签的情况非常有用。 总结来说,Dataset提供了自定义数据集的接口,DataLoader提供了批量加载数据集的能力,而TensorDataset则使得我们可以方便地处理Tensor格式的数据集。这三个工具的配合使用可以使得数据处理变得更加方便和高效。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值