Dataset和Dataloader的使用

目录

需要导入

dataset模板

dataset重写例子

使用方法

dataloader使用


Dataset如果是一叠扑克牌的话,DataLoader就是一只手,参数就是告诉这只手怎么抓取扑克牌。

DataLoader常用参数:
dataset (Dataset) – 使用哪个数据集

batch_size (int, optional) – 一次抓取多少个数据,多少张牌

shuffle (bool, optional) – 是否重新洗牌

num_workers (int, optional) – 多线程,windows下要设成0,不然会出错,出现
 

BrokenPipeError: [Errno 32] Broken pipe

drop_last (bool, optional) – 当数据集大小不能被批大小整除时,设置为True则以删除最后一个不完整的批。False,则不删除,最后一批将变小。(默认值:False)

datasetloader加载批数据过程:

需要导入

Dataset(torch.utils.data.Dataset)

dataset模板

class mydataset(Dataset):
    def __init__(self, xxx):
        ...
    
    def __getitem__(self, item):
        ...

    def __len()__(self):
        ...

dataset重写例子



from torch.utils.data import Dataset, DataLoader
import torch

class MyDataset(Dataset):
    """
        下载数据、初始化数据,都可以在这里完成
    """

    def __init__(self,file):
        self.x = torch.linspace(11,20,10)
        self.y = torch.linspace(1,10,10)
        self.len = len(self.x)

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return self.len



​

使用方法

dataset=mydataset(文件)
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
    def __init__(self,root_dir,laber):
        self.root_dir=root_dir
        self.laber=laber
        self.path=os.path.join(self.root_dir,self.laber)#目录拼接在一起
        self.img_path=os.listdir(self.path)#获取文件目录列表
    def __getitem__(self,idx):
        name=self.img_path[idx]
        img_item_path=os.path.join(self.root_dir,self.laber ,name)#文件路径
        img=Image.open(img_item_path)#文件传递给img
        laber=self.laber
        return img,laber
    def __len__(self):
        return len(self.img_path)

r=r"C:\Users\23087\Desktop\hymenoptera_data\train"
r1="ants"
root_dir = r"C:\Users\23087\Desktop\hymenoptera_data\train"
root1="bees"
ant_dataset = MyData(r, r1)
print(ant_dataset[0])
bees=MyData(root_dir,root1)
print(bees[0])
train=ant_dataset+bees
print(train[123])

dataloader使用

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

train_data = torchvision.datasets.CIFAR10("dataset", train=True, transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor())

trian_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

img, target = train_data[0]
print(img.shape)
print(target)

writer = SummaryWriter("logs/dataloadlogs")
for epoh in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        #print(imgs.shape)
        #print(targets)
        writer.add_images("epoh {}".format(epoh), imgs, step)
        step = step+1
        print(step)

writer.close()

datasetdataloader是在深度学习中常用的数据处理工具。 Dataset是一个抽象类,用于表示数据集。在使用时,我们可以继承该类并实现自己的数据加载逻辑。通常情况下,我们需要重写`__len__`方法返回数据集大小,以及`__getitem__`方法根据索引返回对应的样本数据。 Dataloader是一个用于批量加载数据的迭代器。它接收一个Dataset对象作为输入,并提供一些参数用于配置数据加载的行为。通过调用dataloader的`__iter__`方法,我们可以得到一个可迭代的对象,每次迭代返回一个批次的数据。 下面是一个简单示例,展示了如何使用datasetdataloader加载数据: ```python import torch from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, index): return self.data[index] data = [1, 2, 3, 4, 5] dataset = MyDataset(data) dataloader = DataLoader(dataset, batch_size=2, shuffle=True) for batch in dataloader: # 在这里进行模型训练或推断 print(batch) ``` 在上面的示例中,我们首先定义了一个自定义的Dataset类`MyDataset`,并实现了必要的方法。然后我们创建了一个dataset对象并传入了我们的数据。接下来,我们创建了一个dataloader对象,并指定了一些参数,例如批大小和是否打乱数据等。最后,我们使用for循环迭代dataloader,每次迭代得到一个batch的数据,可以用于模型的训练或推断。 通过使用datasetdataloader,我们可以更方便地处理和加载数据,从而提高模型训练和推断的效率。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值