pytorch中的DataLoader与DataSet

p y t o r c h 中 的 D a t a L o a d e r 与 D a t a S e t pytorch中的DataLoader与DataSet pytorchDataLoaderDataSet


class torch.utils.data.Dataset

决定数据从哪读取,如何读取,进行何种预处理

表示Dataset的抽象类。
所有子类应该override __len____getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。

LoadDataset(data_dir=major_config.train_image, transform=train_transform)
import os
import random
from PIL import Image
from torch.utils.data import Dataset
import major_config
random.seed(1)

# 类别对应表
dict_label = major_config.dict_label

# 返回所有图片路径和标签
def get_img_label(data_dir):
    img_label_list = list()
    for root, dirs, _ in os.walk(data_dir):
        # 遍历类别
        for sub_dir in dirs:
            img_names = os.listdir(os.path.join(root, sub_dir))
            # img_names = list(filter(lambda x: x.endswith('.png'), img_names))   # 如果改了图片格式,这里需要修改
            # 遍历图片
            for i in range(len(img_names)):
                img_name = img_names[i]
                path_img = os.path.join(root, sub_dir, img_name)
                label = dict_label[sub_dir]
                img_label_list.append((path_img, int(label)))
    return img_label_list

# 主要是用来接受索引返回样本用的
class LoadDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        # 获取所有图片的路径、label , 和 确定预处理操作
        self.img_label_list = get_img_label(data_dir)  # img_label_list,在DataLoader中通过index读取样本
        self.transform = transform

    #接受一个索引,返回一个样本 ---  img, label
    def __getitem__(self, index):
        path_img, label = self.img_label_list[index]
        img = Image.open(path_img).convert('RGB')     # 0~255
        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等
        return img, label

    def __len__(self):
        return len(self.img_label_list)




__getitem__的主要作用

主要是用来接受索引返回样本用的。(Sample :Index 生成索引)

__len__的主要作用

__getitem__接受索引的范围就是__len__里确定的范围


class torch.utils.data.DataLoader

class torch.utils.data.DataLoader(  dataset, 
									batch_size=1,
									shuffle=False,
									sampler=None, 
									num_workers=0, 
									collate_fn=<function default_collate>,
									pin_memory=False, 
									drop_last=False
)
  • dataset (Dataset) – 加载数据的数据集。
  • batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
  • shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
  • sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  • num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
  • drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)。即当样本数不能被batchsize整除时,是否舍弃最后一批数据

_SingleProcessDataLoaderIter

def _next_data(self):

_sampler_iter

sampler

在这里插入图片描述

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 如果你是在问 PyTorch 的数据集和数据加载器,那么我很愿意为您解答。 PyTorch 是一个开源深度学习框架,其包含了用于加载和预处理数据的工具。其最重要的两个组件是数据集 (Dataset) 和数据加载器 (DataLoader)。 数据集是一个 PyTorch 类,它定义了如何读取数据、如何访问数据以及如何将数据转换为张量。您可以使用内置的数据集类,例如 torchvision.datasets 的 ImageFolder,或者自定义数据集类。 数据加载器是一个 PyTorch 类,它可以从数据集读取数据并将其批量加载到内存。数据加载器可以进行并行加载,以提高加载速度,并且可以通过 shuffle 参数来随机打乱数据。 举个例子,如果您有一个图像数据集,您可以使用以下代码来创建数据集和数据加载器: ``` import torch import torchvision # 创建数据集 dataset = torchvision.datasets.ImageFolder(root='path/to/data', transform=transforms.ToTensor()) # 创建数据加载器 dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4) ``` 然后,您可以使用以下代码来读取数据: ``` for inputs, labels in dataloader: # 处理输入数据 ... ``` 希望对您有所帮助! ### 回答2: PyTorch是一种广泛使用的深度学习框架,具有易于使用的API和优秀的性能。其DatasetDataLoader是两个非常重要的类,它们可以帮助我们有效地加载和处理数据。 Dataset是一个抽象的概念,用于表示一组数据。我们可以继承它并重写其的方法,以实现对不同数据集的适配。在初始化时,我们需要传递一个数据集,比如说图片数据集,然后在DataLoader使用这个数据集,实现数据的准备和加载。在自定义Dataset时,我们需要定义__getitem__和__len__两个方法,分别用于返回数据集的某个数据和数据总数。 DataLoader是一个非常实用的工具,用于加载数据并把数据变成可迭代的对象,其包含了批量大小、数据是否随机等设置。我们可以设置num_workers参数,用多个进程来读取数据提高读取数据的速度。通过使用DataLoader,我们可以很方便地迭代整个数据集,可以按批次加载和处理数据。 当我们使用在线学习时,经常需要不断地读取数据并进行训练。在应用,我们会遇到许多不同的数据集,其可能包含不同的数据类型,比如图像、音频、文本等。使用DatasetDataLoader类,我们可以轻松处理这些数据,从而使我们的深度学习应用具有更广泛的适用性和扩展性。 总之,DatasetDataLoaderPyTorch非常重要的类,它们可以帮助我们非常方便地进行数据的处理和加载。无论你想要使用哪种数据集,它们都能够很好地适配。在实际应用,我们可以灵活地使用这两个类来加载和准备数据并进行训练,从而加快应用的速度并提高深度学习的精度。 ### 回答3: PyTorch是一个流行的深度学习框架,它提供了DatasetDataLoader这两个类来帮助我们更方便地处理数据。 Dataset可以看作是一个数据集,它定义了如何读取数据。官方提供了两种Dataset:TensorDataset和ImageFolder。TensorDataset是用来处理张量数据,而ImageFolder则是用来处理图像数据。如果我们需要使用其他类型的数据,我们可以通过重写Dataset的__getitem__和__len__方法来实现。 在实现Dataset之后,我们需要将数据读取到内存,在模型训练时提供给模型,这时我们就需要使用到DataLoader了。DataLoader可以看作是一个数据加载器,它会自动将Dataset的数据批量读取到内存,并且支持数据的分布式加载。 在使用DataLoader时我们可以设置很多参数,比如batch_size表示每个batch的大小,shuffle表示是否打乱数据顺序,num_workers表示使用多少线程读取数据等等。这些参数都可以帮助我们更好地利用硬件资源,提高训练速度和效率。 使用PyTorchDatasetDataLoader可以帮助我们更方便快捷地处理数据,并且让我们可以更专注于模型的设计和训练。但我们也要注意一些细节问题,比如数据读取是否正确、内存使用是否合理等等。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值