Pytorch中的sampler

Pytorch的DataLoader使用dataset和sampler参数处理数据。如果不指定sampler,DataLoader默认按顺序取样。SequentialSampler保持顺序取样,RandomSampler随机取样,SubsetRandomSampler用于子集随机取样,WeightedRandomSampler基于权重的随机取样。BatchSampler则控制batch_size的划分方式。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Pytorch中常见sampler的使用

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device='')

Pytorch中在使用DataLoader函数时要传入三个重要参数dataset, batch_sizesampler

  • dataset: 是数据集
  • batch_size: 是一次要喂入的参数数量
  • sampler:是从dataset中取数据的策略

Pytorch给出了集中常见的sampler:SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler
如果DataLoader不指定sampler的话,它就会按顺序依次喂入数据,例如:

import  torch
import  numpy  as  np
a = torch.from_numpy(np.arange(10,20))

dataloader = torch.utils.data.DataLoader(a, batch_size=3, shuffle=False)
for i, x in enumerate(dataloader):
	print(i, x)

输出的数据为:

0 tensor([10, 11, 12])
1 tensor([13, 14, 15])
2 tensor([16, 17, 18])
3 tensor([19])

更多详细信息参考官方文档:https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler

1. SequentialSampler

SequentialSampler和不指定sampler一样,都是按顺序产生batch data:

sampler = torch.utils.data.SequentialSampler(a)
dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=sampler,shuffle=False)
for i, x in enumerate(dataloader):
	print(i, x)

输出为:

0 tensor([10, 11, 12])
1 tensor([13, 14, 15])
2 tensor([16, 17, 18])
3 tensor([19])

我们可以将sampler打印出来看看里面输出的是什么:

for i in sampler:
	print(i)

可以看到输出的是数据集a的索引:

0
1
2
3
...
9

2. RandomSampler

RandomSampler会打乱daat的顺序随即输出batch

sampler = torch.utils.data.RandomSampler(a)
dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=sampler,shuffle=False)
for i, x in enumerate(dataloader):
	print(i, x)

输出结果为:

0 tensor([14, 17, 11])
1 tensor([18, 15, 12])
2 tensor([19, 13, 16])
3 tensor([10])

3. SubsetRandomSampler

SubsetRandomSampler会用来产生数据的子集,需要自己生成随即的indices。

import numpy.random as random
n_a = len(a)
indices = random.permutation(list(range(n_a)))
sampler = torch.utils.data.SubsetRandomSampler(indices[:8])
dataloader = torch.utils.data.DataLoader(a, batch_size=3, sampler=sampler,shuffle=False)
for i, x in enumerate(dataloader):
	print(i, x)

输出结果为:

0 tensor([16, 10, 14])
1 tensor([11, 17, 15])
2 tensor([13, 19])

4. WeightedRandomSampler

WeightedRandomSampler给每个样本分配不同的权重:

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]

5. BatchSampler

除了在DataLoader中定义batch_size以外,还可以使用BatchSampler来确定batch_size。例如:

>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
from torch.utils.data import SequentialSampler, BatchSampler
n_a = len(a)
sampler = BatchSampler(SequentialSampler(range(n_a)), batch_size=3, drop_last=False)
dataloader = torch.utils.data.DataLoader(a, sampler=sampler,shuffle=False)
for i, x in enumerate(dataloader):
print(i, x)

结果为(为啥是二维的?):

0 tensor([[10, 11, 12]])
1 tensor([[13, 14, 15]])
2 tensor([[16, 17, 18]])
3 tensor([[19]])

打印出sampler我可以看到每个输出是一个batch_size的索引:

for i in sampler:
print(i)

输出:

[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
[9]
### PyTorch DataLoader 的功能与使用方法 #### 1. 基础概念 `DataLoader` 是 PyTorch 中用于批量加载数据的核心工具之一。它通过封装 `Dataset` 对象,提供了一种高效的方式来处理大规模数据集并支持多线程读取[^1]。 #### 2. 参数详解 以下是 `DataLoader` 的主要参数及其作用: - **dataset**: 这是一个实现了 `__getitem__()` 和 `__len__()` 方法的对象,表示要加载的数据集合。 - **batch_size**: 定义每次迭代返回的样本数量,默认值为 1。 - **shuffle**: 如果设置为 True,则会在每个 epoch 开始前打乱数据顺序(仅当未指定 sampler 时有效)。默认值为 False。 - **sampler**: 自定义采样器对象,用于控制数据加载的顺序。如果指定了 sampler,则 shuffle 应该设为 None 或者不指定[^2]。 - **num_workers**: 表示用于数据加载的子进程数。增加此数值可以加速数据预处理过程,尤其是在 GPU 训练场景下推荐大于零的值。 - **collate_fn**: 用户自定义函数,用来合并一批次的数据样本到张量或其他结构化形式中去。如果没有特别需求的话会采用默认实现方式。 #### 3. 使用实例 下面展示如何创建一个简单的 `DataLoader` 并结合自定义 `Sampler` 来完成特定任务: ```python from torch.utils.data import Dataset, DataLoader, Sampler class MyCustomDataset(Dataset): def __init__(self, data_list): self.data = data_list def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] class CustomSampler(Sampler): def __init__(self, data_source): super().__init__(data_source) self.indices = list(range(len(data_source))) def __iter__(self): random.shuffle(self.indices) # 随机排列索引 return iter(self.indices) def __len__(self): return len(self.indices) # 创建数据集和采样器 my_dataset = MyCustomDataset([i for i in range(10)]) custom_sampler_instance = CustomSampler(my_dataset) # 初始化 Data Loader dataloader = DataLoader( my_dataset, batch_size=2, sampler=custom_sampler_instance, num_workers=0 ) for batch_data in dataloader: print(batch_data) ``` 上述代码片段展示了如何构建一个带有随机抽样的 `DataLoader` 实例[^3]。 #### 4. 数据增强 虽然 `DataLoader` 主要是负责数据分发的工作流管理,但它也可以配合其他库或者模块来进行图像变换等操作以达到数据扩增的目的。例如 torchvision.transforms 提供了一系列丰富的转换手段可以帮助我们轻松实现这一点。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值