torch.utils.data中Dataset, DataLoader

本文介绍了PyTorch中的torch.utils.data模块,包括Dataset和DataLoader的使用,以及如何自定义数据集、采样器和变换。Dataset是数据集的抽象类,需实现__len__和__getitem__方法。DataLoader则负责数据的批量加载和预处理,支持多进程加载和内存固定。文章还提到了num_workers、pin_memory等参数以及Sampler和Transform的自定义方法。
摘要由CSDN通过智能技术生成

torch.utils.dataPyTorch中用于数据加载和预处理的模块。其中包括DatasetDataLoader两个类,它们通常结合使用来加载和处理数据。

Dataset

torch.utils.data.Dataset是一个抽象类,用于表示数据集。它需要用户自己实现两个方法:__len____getitem__。其中,__len__方法返回数据集的大小,__getitem__方法用于根据给定的索引返回一个数据样本。

以下是一个简单的示例,展示了如何定义一个数据集:

import torch.utils.data as data

class MyDataset(data.Dataset):
def __init__(self, data_list):
self.data_list = data_list

def __len__(self):
return len(self.data_list)

def __getitem__(self, index):
return self.data_list[index]

在这个示例中,MyDataset继承了torch.utils.data.Dataset类,并实现了__len____getitem__方法。__len__方法返回数据集的大小,这里使用了Python内置函数len__getitem__方法根据给定的索引返回一个数据样本,这里返回的是数据列表中对应的元素。

DataLoader

torch.utils.data.DataLoader是用于加载数据的类,它可以自动对数据进行批量处理和随机化。以下是一个简单的示例:

import torch.utils.data as data

my_dataset = MyDataset([1, 2, 3, 4, 5])
my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True)

for batch in my_dataloader:
print(batch)

在这个示例中,我们首先创建了一个MyDataset实例my_dataset,它包含了一个整数列表。然后,我们使用DataLoader类创建了一个数据加载器my_dataloader,它将my_dataset作为输入,并将数据分成大小为2的批次,并对数据进行随机化。最后,我们使用一个循环来遍历my_dataloader,并打印出每个批次的数据。

总结一下,torch.utils.data.Dataset用于表示数据集,torch.utils.data.DataLoader用于加载数据,并对数据进行批量处理和随机化。下面是一个完整的示例,展示了如何使用这两个类来加载和处理数据:

import torch.utils.data as data

class MyDataset(data.Dataset):
def __init__(self, data_list):
self.data_list = data_list

def __len__(self):
return len(self.data_list)

def __getitem__(self, index):
return self.data_list[index]

my_dataset = MyDataset([1, 2, 3, 4, 5])
my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True)

for batch in my_dataloader:
print(batch)

除了上述介绍的基本用法,torch.utils.data模块还有许多其他的功能和选项。下面介绍一些常用的选项和功能。

num_workers

num_workers参数用于指定使用多少个进程来加载数据。默认值为0,表示使用主进程加载数据。如果设置为大于0的值,将使用多个进程来加载数据,可以提高数据加载的效率。

以下是一个示例:

my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True, num_workers=4)

在这个示例中,num_workers被设置为4,表示将使用4个进程来加载数据。

pin_memory

pin_memory参数用于指定是否将数据加载到CUDA主机内存中的固定位置(pinned memory),以提高数据传输效率。默认值为False

以下是一个示例:

my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True, pin_memory=True)

在这个示例中,pin_memory被设置为True,表示将数据加载到CUDA主机内存中的固定位置。

collate_fn

collate_fn参数用于指定如何将样本组合成一个批次。默认情况下,DataLoader将每个样本作为一个单独的元素传递给模型,但在某些情况下,需要将样本组合成一个批次,以便一次性对整个批次进行处理。

以下是一个示例:

def my_collate_fn(batch):
# 将样本组合成一个批次
data = [item[0] for item in batch]
target = [item[1] for item in batch]
return [data, target]

my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)

在这个示例中,my_collate_fn是一个自定义的函数,用于将样本组合成一个批次。DataLoader将每个样本作为一个元素传递给my_collate_fn函数,函数将样本组合成一个批次,并返回一个包含数据和目标的列表。

Sampler

Sampler是一个用于指定数据集采样方式的类,它控制DataLoader如何从数据集中选取样本。PyTorch提供了多种Sampler类,例如RandomSamplerSequentialSampler,分别用于随机采样和顺序采样。

以下是一个示例:

from torch.utils.data.sampler import RandomSampler

my_sampler = RandomSampler(my_dataset)
my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=False, sampler=my_sampler)

在这个示例中,我们使用RandomSampler类来指定随机采样方式,然后将其传递给DataLoadersampler参数。这将覆盖默认的shuffle参数,使数据集按照sampler指定的采样方式进行

自定义Dataset

除了使用torchvision.datasets中提供的数据集,我们还可以使用torch.utils.data.Dataset类来自定义自己的数据集。自定义数据集需要实现__len____getitem__方法。

__len__方法返回数据集中样本的数量,__getitem__方法根据给定的索引返回一个样本。样本可以是一个张量或者一个元组,其中第一个元素是数据,第二个元素是目标。

以下是一个示例:

class MyDataset(data.Dataset):
def __init__(self, data_path):
self.data = torch.load(data_path)

def __len__(self):
return len(self.data)

def __getitem__(self, index):
x = self.data[index][0]
y = self.data[index][1]
return x, y

在这个示例中,MyDataset类继承自torch.utils.data.Dataset类,实现了__len____getitem__方法。MyDataset类的构造函数接受一个数据路径作为参数,数据集被保存为一个由数据-目标对组成的列表。__len__方法返回数据集中样本的数量,__getitem__方法根据给定的索引返回一个数据-目标对。

自定义Sampler

除了使用torch.utils.data.sampler中提供的采样器,我们还可以使用Sampler类来自定义自己的采样器。自定义采样器需要实现__iter____len__方法。

__iter__方法返回一个迭代器,用于遍历数据集中的样本索引。__len__方法返回数据集中样本的数量。

以下是一个示例:

class MySampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source

def __iter__(self):
return iter(range(len(self.data_source)))

def __len__(self):
return len(self.data_source)

在这个示例中,MySampler类继承自torch.utils.data.sampler.Sampler类,实现了__iter____len__方法。MySampler类的构造函数接受一个数据集作为参数,__iter__方法返回一个迭代器,用于遍历数据集中的样本索引,__len__方法返回数据集中样本的数量。

自定义Transform

除了使用torchvision.transforms中提供的变换,我们还可以使用transforms模块中的Compose类来自定义自己的变换。Compose类将多个变换组合在一起,并按照顺序应用它们。

以下是一个示例:

class MyTransform(object):
def __call__(self, x):
x = self.crop(x)
x = self.to_tensor(x)
return x

def crop(self, x):
# 实现裁剪变换
return x

def to_tensor(self, x):
# 实现张量化变换
return x

my_transform = transforms.Compose

my_transform = transforms.Compose([
MyTransform()
])

# 创建数据集和数据加载器
my_dataset = MyDataset(data_path)
my_dataloader = DataLoader(my_dataset, batch_size=32, shuffle=True, num_workers=4)

# 遍历数据集
for batch in my_dataloader:
# 在这里处理数据批次
pass

在这个示例中,MyTransform类实现了一个自定义的变换,它将裁剪和张量化两个变换组合在一起。transforms.Compose将这个自定义变换组合成一个变换序列,并在数据集中的每个样本上应用这个序列。

最后,我们创建了一个数据集和数据加载器,并用它们来遍历数据集。在数据加载器返回的每个批次中,数据已经通过了我们自定义的变换序列。

总结

在这篇文章中,我们介绍了torch.utils.data模块中的DatasetDataLoader类,并给出了详细的代码示例。我们还讨论了如何自定义数据集、采样器和变换,并给出了相应的代码示例。使用DatasetDataLoader类,我们可以轻松地加载和处理大规模数据集,为模型训练提供了强大的支持。

本文由 mdnice 多平台发布

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

高山莫衣

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值