Dataset和DataLoader用法

Dataset和DataLoader用法

在d2l中有简洁的加载固定数据的方式,如下

d2l.load_data_fashion_mnist()
# 源码
Signature: d2l.load_data_fashion_mnist(batch_size, resize=None)
Source:   
def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and then load it into memory.

    Defined in :numref:`sec_fashion_mnist`"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
File:      ~/anaconda3/envs/d2l/lib/python3.9/site-packages/d2l/torch.py
Type:      function

如果我们要自定义需要加载的数据集

数据集:一个图片文件夹,用csv文件来表示训练数据和标签

# 定义Dataset
import pandas as pd
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torchvision.transforms as transforms

class CustomDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.data = pd.read_csv(csv_file) 
        self.root_dir = root_dir
        self.transform = transform
        label_encoder = LabelEncoder()
        self.labels = label_encoder.fit_transform(self.data['label'])
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.data.iloc[idx, 0])
        # 读取图片并做增广
        image = Image.open(img_name)
        if self.transform is not None:
            image = self.transform(image)
        # 将数字转换成独热编码的张量(记得转换成float)
        label = F.one_hot(torch.tensor(self.labels[idx]), 		
        					num_classes=self.data['label'].nunique()).float()
        return image, label

# 定义参数和超参数训练
batch_size = 256
lr = num_epoch = 0.9, 10

# 加载数据
sample = '/kaggle/input/classify-leaves/sample_submission.csv'
ts_path = "/kaggle/input/classify-leaves/test.csv"
tr_path = "/kaggle/input/classify-leaves/train.csv"
image_path = '/kaggle/input/classify-leaves'

dataset = CustomDataset(csv_file = sample, root_dir = image_path, transform=transform_train)
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
tr_dataset, te_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size])

tr_dataloader = DataLoader(tr_dataset, batch_size, shuffle=True)
ts_dataloader = DataLoader(te_dataset, batch_size, shuffle=False)

总结

需要将__init__,len,__getitem__按照数据集和模型的要求,对应的编写好代码。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
datasetdataloader是在深度学习中常用的数据处理工具。 Dataset是一个抽象类,用于表示数据集。在使用时,我们可以继承该类并实现自己的数据加载逻辑。通常情况下,我们需要重写`__len__`方法返回数据集大小,以及`__getitem__`方法根据索引返回对应的样本数据。 Dataloader是一个用于批量加载数据的迭代器。它接收一个Dataset对象作为输入,并提供一些参数用于配置数据加载的行为。通过调用dataloader的`__iter__`方法,我们可以得到一个可迭代的对象,每次迭代返回一个批次的数据。 下面是一个简单示例,展示了如何使用datasetdataloader加载数据: ```python import torch from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, index): return self.data[index] data = [1, 2, 3, 4, 5] dataset = MyDataset(data) dataloader = DataLoader(dataset, batch_size=2, shuffle=True) for batch in dataloader: # 在这里进行模型训练或推断 print(batch) ``` 在上面的示例中,我们首先定义了一个自定义的Dataset类`MyDataset`,并实现了必要的方法。然后我们创建了一个dataset对象并传入了我们的数据。接下来,我们创建了一个dataloader对象,并指定了一些参数,例如批大小和是否打乱数据等。最后,我们使用for循环迭代dataloader,每次迭代得到一个batch的数据,可以用于模型的训练或推断。 通过使用datasetdataloader,我们可以更方便地处理和加载数据,从而提高模型训练和推断的效率。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值