Pytorch 中 torch.utils.data 下的 Dataset 与 DataLoader

1. Dataset

     在PyTorch中,‌Dataset是一个基础类,‌用于定义数据集的结构和数据的加载方式。‌通过继承Dataset类并实现特定的方法,‌可以创建自定义的数据集,‌这些方法通常包括获取数据集的大小、‌获取单个样本等。‌

      对于任何数据集,如果它提供了从键到数据样本的映射,在对数据集进行加载时,都应该对该类进行子类化

 Dataset 作为一个抽象类,在继承使用该类时,应该实现两个方法,定义自己的数据集。分别是:

  • __getitem__(self, index):  根据的索引 index,返回对应的数据;索引可以是一个整数,表示按顺序获取样本,也可以是其他方式,如通过文件名获取样本等。
  • __len__:   返回数据集中样本的总数。

2. 自定义Dataset

import torch
from torch.utils.data import Dataset

# 创建自己的 数据集 Dataset
class MyDataset(Dataset):
    # 对数据集 进行初始化 ,包括 数据与标签
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    # 返回数据集的长度
    def __len__(self):
        return len(self.data)
    # 根据索引返回对应的 value 值
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


# 假设data和labels是已经加载或生成的数据和标签
data = [1, 2, 3, 4, 5]  # 示例数据
labels = ['春', '夏', '春', '夏','夏']  # 示例标签

# 初始化数据
dataset = MyDataset(data, labels)

#   根据 索引,返回对应的数据与标签
print(dataset[1])           # (2, '夏')
#   返回数据集的长度
print(len(dataset))         # 5
#   查看数据集中的所有数据
print(dataset.__dict__)     # {'data': [1, 2, 3, 4, 5], 'labels': ['春', '夏', '春', '夏', '夏']}

此外,‌PyTorch还提供了多种内置的数据集类,‌如MNIST、‌CIFAR等,‌这些内置数据集可以直接使用,‌无需自定义。‌

读取 数据集练习:

项目的目录格式:

import os
from PIL import Image
class MyDataset1(Dataset):
    def __init__(self, root_dir, img_dir, label_dir):
        self.root_dir = root_dir
        self.img_dir = img_dir
        self.label_dir = label_dir
        # 由于 linux 与 Windows 的斜杠不同,因此采用该方法 os
        self.img_path = os.path.join(self.root_dir,self.img_dir)
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        # 列出所有的图片名称
        self.img_name = os.listdir(self.img_path)
        # 列出所有的标签名称
        self.label_name = os.listdir(self.label_path)

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, idx):
        # 获取 图片与标签的相对路径
        img_relative_path = os.path.join(self.img_path,self.img_name[idx])
        lab_relative_path = os.path.join(self.label_path,self.label_name[idx])

        with open(lab_relative_path,"r") as f:
            labels = f.read()
        img  = Image.open(img_relative_path)

        return img, labels


root_dir = "hymenoptera_data/train"
ants_img_dir = "ants_img"
bees_img_dir = "bees_img"
ants_lab_dir = "ants_lab"
bees_lab_dir = "bees_lab"
ants_dataset = MyDataset1(root_dir,ants_img_dir,ants_lab_dir)
bees_dataset = MyDataset1(root_dir,bees_img_dir,bees_lab_dir)

# 获取所有的数据集
train_dataset  = ants_dataset + bees_dataset
# 读取数据集 
print(ants_dataset[0])

print(len(ants_dataset))

方式:

        1. 首先 通过 init() 函数, 传入对应的数据集的各个文件夹名称;在 init() 函数中 通过 os.path.join() 函数,对文件夹名称进行拼接,获取所需要的 相对地址;最后可以通过 os.listdir() 函数, 获取文件夹中的文件名称列表。

        2. 在  getitem() 函数中,通过 index 索引, 找到对应的 文件名称,拼接成该文件的 相对地址,并通过对应的函数进行打开;最后将打开的结果 通过 return 返回。

3. torchversion

介绍

torchvision是pytorch的一个图形库,专注于计算机视觉任务,包括预训练模型数据加载预处理模型构建和评估等。

torchvision的设计初衷是为了简化使用PyTorch进行计算机视觉任务的开发和研究,减少重复性代码的编写,提高开发效率。

 使用

import torchvision
from torch.utils.tensorboard import SummaryWriter


# Compose 对图像进行处理
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform,download=True)
print(train_set[0])  #(<PIL.Image.Image image mode=RGB size=32x32 at 0x2B0BD213170>, 6)
print(train_set.classes)  # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

img, target = train_set[0]
print(img)          # <PIL.Image.Image image mode=RGB size=32x32 at 0x18BB86AF200>
print(target)       # 6 查看 该图片的 标签 是frog

print(train_set.classes[target])  # frog


write = SummaryWriter("log_1")
for i in range(10):
    img, target = train_set[i]
    write.add_image("train_set",img,i)

write.close()

torchvision.datasets.CIFAR10是PyTorch中提供的一个数据集类,用于加载CIFAR-10数据集。CIFAR-10数据集是一个常用的图像分类数据集,包含10个类别的60000张32x32的彩色图像,其中50000张用于训练,10000张用于测试。这个数据集常用于训练图像分类模型。
参数:

  •   root:数据集的存储路径。如果设置为下载模式(download=True),则会在该路径下保存数据集;如果数据集已经存在,则可以直接从该路径加载。
  •     train:一个布尔值,表示是否加载训练数据集。如果为True,则加载训练数据;如果为False,则加载测试数据。
  •     transform:对数据集中的图像进行的一些操作,如归一化、随机裁剪、数据增强等。这是一个可选参数,可以根据需要进行设置。
  •     target_transform:对数据集中的标签进行的一些操作。这也是一个可选参数,用于对标签进行预处理。
  •     download:一个布尔值,表示是否需要从Internet下载数据集到指定的root目录。默认为False,即如果数据集已经存在,则不会重新下载。

 class CIFAR10(VisionDataset) --> class VisionDataset(data.Dataset) :

可见 CIFAR10 最终继承自 Datase;

其内部也实现了 dataset 的 方法:

  •         __getitem__ 方法
  •         __len__ 方法

 

3. DataLoader

主要用于简化深度学习中的数据加载预处理过程。‌它是PyTorch框架中的一个类,‌能够将数据集(‌通常是一个Dataset对象)‌打包成一个可迭代的对象,‌方便在训练过程中逐批次读取数据。‌DataLoader提供了多种功能,‌包括数据的随机打乱、‌并行加载、‌多线程加载等,‌以优化数据处理的效率和灵活性。‌此外,‌DataLoader还支持数据清洗和灵活的数据转换,‌如字段的映射、‌计算、‌替换等操作,‌进一步降低了数据处理的难度。‌处理完数据后,‌DataLoader还支持将数据导出为多种格式,‌如Excel、‌CSV、‌数据库等,‌满足不同用户的需求。‌

  1. dataset:‌这是加载的数据集对象,‌数据类型为dataset。‌它指定了DataLoader将要处理的数据来源。‌
  2. batch_size:‌定义了每个batch的大小,‌即数据被切分的份数。‌数据类型为int。‌这个参数决定了每次从数据集中取出的数据量。‌
  3. shuffle:‌决定是否打乱数据的顺序,‌默认设置为False。‌数据类型为bool。‌启用shuffle可以在每个epoch开始时打乱数据的顺序,‌有助于模型的泛化能力。‌
  4. drop_last:‌当数据集的大小不能被batch_size整除时,‌决定是否丢弃最后一个不完整的batch。‌数据类型为bool。‌如果设置为True,‌最后一个不完整的batch将被丢弃;‌如果设置为False,‌则会保留所有数据,‌包括不完整的batch。‌
  5. num_workers:‌表示在数据加载时使用的子进程数。‌在Windows上设置为0,‌而在Linux上可以设置大于0的值。‌数据类型为int。‌这个参数影响了数据加载的速度,‌通过增加工作进程数可以并行加载数据,‌提高效率。‌
  6. collate_fn:‌这是一个callable对象,‌用于将batch_size样本整理成一个batch样本,‌便于批量处理训练。‌它定义了如何将多个样本组合成一个batch的规则。‌
  7. sampler:‌这是一个“采样器”,‌表示从样本中如何取样,‌提供整个数据集的随机访问的索引列表。‌数据类型为Sampler。‌它允许更灵活地控制数据的取样方式。‌
  8. Pin_memory:‌内存寄存标志,‌默认为False。‌在数据返回前,‌决定是否将数据复制到内存中(‌CPU/GPU)‌。‌数据类型为bool。‌这可以影响数据加载的速度和效率。‌
  9. Timeout:‌设置数据读取的超时时间,‌超过该值还未读取到数据就会报错。‌数据类型为numeric。‌这保证了DataLoader在等待数据时的超时机制。‌
  10. Worker_init_fn:‌这是一个callable对象,‌用于初始化工作进程。‌它允许在每个工作进程开始时执行特定的初始化操作。‌

     示例代码:

import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

# 测试集
train_data =  torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)

train_loader = DataLoader(dataset=train_data,batch_size=4,shuffle=True,num_workers=0 ,drop_last=False)

# 测试数据集
img , target = train_data[0]
print(img.shape)
print(target)

write = SummaryWriter("log_2")
write.add_image("Image2",img, dataformats="CHW")
step=0


for data in train_loader:
    imgs, target = data
    print(img.shape)
    # print(target)
    write.add_image("Image4",imgs,step, dataformats="NCHW")
    step = step +1
write.close()



dataformats 中的 NCHW的数据排列方式:

https://blog.csdn.net/chengyq116/article/details/112759824

  • 22
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值