pytorch DataLoader 自定义 sampler

本文介绍了如何在PyTorch中创建自定义数据加载器和Sampler。通过一个基础代码示例,展示了如何定义一个名为`IndependentHalvesSampler`的Sampler子类,该Sampler将数据集分为两半并随机打乱,然后以交错方式遍历。使用这个Sampler创建DataLoader,可以观察到数据以特定顺序被加载。这为理解和定制数据采样策略提供了基础。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文归纳自:

https://www.scottcondron.com/jupyter/visualisation/audio/2020/12/02/dataloaders-samplers-collate.html#SequentialSampler 

基础代码

from torch.utils.data import DataLoader
#构造数据
xs = list(range(11))
ys = list(range(10,21))
print('xs values: ', xs)
print('ys values: ', ys)
#定义Dataset子类 
class MyDataset:
    def __init__(self, xs, ys):
        self.xs = xs
        self.ys = ys
    
    def __getitem__(self, i):
        return self.xs[i], self.ys[i]
    
    def __len__(self):
        return len(self.xs)
#实例化一个Dataset
dataset = MyDataset(xs, ys)
dataset[2] # returns the tuple (x[2], y[2])

自定义 Sampler

每个自定义的Sampler子类需要一个 __iter__ 方法, 以保证遍历到dataset的每一个元素,

还需要一个 __len__ 方法 ,来获取数据集的大小.

#collapse-hide
import random
from torch.utils.data.sampler import Sampler

class IndependentHalvesSampler(Sampler):
    def __init__(self, dataset):
        halfway_point = int(len(dataset)/2)
        self.first_half_indices = list(range(halfway_point))
        self.second_half_indices = list(range(halfway_point, len(dataset)))
        
    def __iter__(self):
        random.shuffle(self.first_half_indices)
        random.shuffle(self.second_half_indices)
        return iter(self.first_half_indices + self.second_half_indices)
    
    def __len__(self):
        return len(self.first_half_indices) + len(self.second_half_indices)


our_sampler = IndependentHalvesSampler(dataset)
print('First half indices: ', our_sampler.first_half_indices)
print('Second half indices:', our_sampler.second_half_indices)

dl = DataLoader(dataset, sampler=our_sampler, batch_size=1)

for i, data in enumerate(dl):
  print(data)

[tensor([1]), tensor([11])]
[tensor([4]), tensor([14])]
[tensor([3]), tensor([13])]
[tensor([0]), tensor([10])]
[tensor([2]), tensor([12])]
[tensor([7]), tensor([17])]
[tensor([8]), tensor([18])]
[tensor([9]), tensor([19])]
[tensor([5]), tensor([15])]
[tensor([6]), tensor([16])]
[tensor([10]), tensor([20])]

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

子燕若水

吹个大气球

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值