A Detailed Walkthrough of Building MyDataLoader in PyTorch

Introduction

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')

Full details: see the official DataLoader documentation.

This post re-implements each of these DataLoader components by hand.
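For orientation, here is a minimal sketch of the stock usage that the rest of this post re-implements piece by piece (TensorDataset is just a stand-in dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 20 scalar samples.
dataset = TensorDataset(torch.arange(20))
# shuffle=True makes DataLoader build a RandomSampler internally,
# and batch_size=4 makes it wrap that sampler in a BatchSampler.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for batch in loader:
    print(batch)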

Code Implementation

MyDataset subclasses torch.utils.data.Dataset and loads your own data, e.g., images and their labels.
SingleSampler subclasses torch.utils.data.Sampler and yields the indices that each batch is drawn from; e.g., with batch_size=4, four indices are picked per batch, and MyDataset.__getitem__ then loads the image for each of those indices.
MyBatchSampler subclasses torch.utils.data.BatchSampler and implements your own grouping of indices into batches. It draws its per-sample indices from SingleSampler, which makes it more flexible. When a batch_sampler is passed, it takes over batching entirely: batch_sampler is mutually exclusive with DataLoader's batch_size, shuffle, sampler, and drop_last arguments.
Note: supplying a custom sampler conflicts with shuffle (DataLoader raises a ValueError if both are given); shuffle merely decides which built-in sampler, sequential or random, DataLoader creates when none is supplied.
collate_fn packs the batch_size individual samples into a single batch, so each iteration of the loader yields matched-up images and labels; a sketch for (image, label) pairs follows this list.
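Because the toy dataset used below returns bare scalars, its collate_fn is a single torch.stack. For a dataset whose __getitem__ returns (image, label) tuples, a sketch of a typical collate_fn looks like this (pair_collate_fn and the tensor shapes are illustrative assumptions, not part of the original code):

import torch

def pair_collate_fn(batch):
    # batch is a list of (image, label) tuples, one per sampled index.
    images, labels = zip(*batch)
    # Stack C x H x W image tensors into a B x C x H x W batch tensor,
    # and collect the integer labels into a 1-D tensor.
    return torch.stack(images, 0), torch.tensor(labels)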

sampler
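In this first variant, the custom SingleSampler is passed through the sampler argument, and DataLoader itself groups the yielded indices into batches of batch_size=4.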

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler


class MyDataset(Dataset):
    def __init__(self) -> None:
        # Toy data: the integers 0..19 stand in for real samples (e.g., images).
        self.data = torch.arange(20)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]
    
    @staticmethod
    def collate_fn(batch):
        # Stack the individual samples into a single batch tensor.
        return torch.stack(batch, 0)

class MyBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler[int], batch_size: int) -> None:
        self._sampler = sampler
        self._batch_size = batch_size
    
    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self._sampler:
            batch.append(idx)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        # Only yield the final partial batch if it is non-empty; otherwise an
        # empty list would reach collate_fn and crash torch.stack.
        if batch:
            yield batch
    
    def __len__(self):
        # Count the partial final batch too, matching __iter__ (drop_last=False behavior).
        return (len(self._sampler) + self._batch_size - 1) // self._batch_size

class SingleSampler(Sampler):
    def __init__(self, data_source) -> None:
        self._data = data_source
        self.num_samples = len(self._data)
        
    def __iter__(self):
        # Sequential sampling:
        # indices = range(len(self._data))
        # Random sampling:
        indices = torch.randperm(self.num_samples).tolist()
        return iter(indices)
    
    def __len__(self):
        return self.num_samples
        

train_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)  # not used in this variant; see the batch_sampler section below
train_loader = DataLoader(train_set, batch_size=4, sampler=single_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:
    print(data)
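With 20 samples and batch_size=4, this prints five tensors of shape (4,) that together cover 0-19 in a random order. If that order needs to be reproducible across runs, one option is to thread an explicit torch.Generator through the sampler; the SeededSingleSampler below is a hypothetical variant sketched for that purpose, not part of the original code:

import torch
from torch.utils.data import Sampler

class SeededSingleSampler(Sampler):
    def __init__(self, data_source, seed: int = 0) -> None:
        self._data = data_source
        self.num_samples = len(self._data)
        # Dedicated generator so the permutation stream is reproducible.
        self._generator = torch.Generator()
        self._generator.manual_seed(seed)
        
    def __iter__(self):
        indices = torch.randperm(self.num_samples, generator=self._generator).tolist()
        return iter(indices)
    
    def __len__(self):
        return self.num_samples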

batch_sampler
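This variant hands batching over to MyBatchSampler via the batch_sampler argument. Note that batch_sampler is mutually exclusive with batch_size, shuffle, sampler, and drop_last, so none of those are passed to DataLoader below.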

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler

class MyDataset(Dataset):
    def __init__(self) -> None:
        # Toy data: the integers 0..19 stand in for real samples (e.g., images).
        self.data = torch.arange(20)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]
    
    @staticmethod
    def collate_fn(batch):
        # Stack the individual samples into a single batch tensor.
        return torch.stack(batch, 0)

class MyBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler[int], batch_size: int) -> None:
        self._sampler = sampler
        self._batch_size = batch_size
    
    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self._sampler:
            batch.append(idx)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        # Only yield the final partial batch if it is non-empty; otherwise an
        # empty list would reach collate_fn and crash torch.stack.
        if batch:
            yield batch
    
    def __len__(self):
        # Count the partial final batch too, matching __iter__ (drop_last=False behavior).
        return (len(self._sampler) + self._batch_size - 1) // self._batch_size

class SingleSampler(Sampler):
    def __init__(self, data_source) -> None:
        self._data = data_source
        self.num_samples = len(self._data)
        
    def __iter__(self):
        # Sequential sampling:
        # indices = range(len(self._data))
        # Random sampling:
        indices = torch.randperm(self.num_samples).tolist()
        return iter(indices)
    
    def __len__(self):
        return self.num_samples
        

train_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)
train_loader = DataLoader(train_set, batch_sampler=batch_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:
    print(data)
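Here the 20 shuffled indices are grouped as 8 + 8 + 4, so the loop prints two tensors of shape (8,) followed by one of shape (4,). If the partial final batch should be discarded instead, a drop_last flag can be added; the version below is a sketch of that extension, mirroring the stock BatchSampler's parameter:

from typing import Iterator, List
from torch.utils.data import BatchSampler, Sampler

class MyBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool = False) -> None:
        self._sampler = sampler
        self._batch_size = batch_size
        self._drop_last = drop_last
    
    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self._sampler:
            batch.append(idx)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        # Keep the partial final batch only when drop_last is False.
        if batch and not self._drop_last:
            yield batch
    
    def __len__(self):
        if self._drop_last:
            return len(self._sampler) // self._batch_size
        return (len(self._sampler) + self._batch_size - 1) // self._batch_size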

References

Sampler: https://blog.csdn.net/lidc1004/article/details/115005612
