PyTorch中Dataset和DataLoader的使用介绍

本文总结一下PyTorch中Dataset和DataLoader的使用流程:

首先假设存在一批数据格式如下:

[{ "answer": "hat", 
  "question": "What is the man wearing on his head?"}, 
 {"answerl": "yes", 
  "question": "Are they all happy?"}
...]

希望将它批次加载给模型输入,如何操作呢?

1. 我们首先利用Dataset,Dataset是一个抽象类,用于表示数据集。可以创建自定义的Dataset类,也可以使用PyTorch提供的预定义的Dataset类,如torchvision.datasets.ImageFolder(用于处理图像数据)或torchtext.data.Dataset(用于文本数据)。

我们这里重点介绍自定义的Dataset,自定义Dataset类必须实现两个方法:

  • __len__:返回数据集大小,即样本数量。
  • __getitem__:根据给定的索引返回数据集中的样本。这个方法将用于数据加载和批处理,(在后续枚举Dataloader时会自动调用)

以下是一个简单的示例:

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5TokenizerFast

data = [
    {
        "answer": "hat",
        "question": "What is the man wearing on his head?"
    },
    {
        "answer": "yes",
        "question": "Are they all happy?"
    },
    # 添加更多数据
]

tokenizer = T5TokenizerFast #这里需要加载tokenizer,就不仔细写了

class CustomQADataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item["question"]
        answer = item["answer"]
        input_ids = self.tokenizer.encode(question, max_length=20, truncation=True)
        target_ids = self.tokenizer.encode(answer, max_length=20, truncation=True)

        out_dict = {}
        out_dict['question'] = question
        out_dict['answer'] = answer
        out_dict['input_length'] = len(input_ids)
        out_dict['input_ids'] = input_ids
        out_dict['target_ids'] = target_ids
        out_dict['target_length'] = len(target_ids)
        
        return out_dict
    

dataset = CustomQADataset(data)

这样就创建了一个自定义的Dataset类实例

2. 接着创建DataLoader实例,不过在创建之前,我们一般可以定义一个collate_fn函数,用于组合样本成批次,如下示例:

def custom_collate_fn(batch):
    questions = [item['question'] for item in batch]
    answers = [item['answer'] for item in batch]
    input_lengths = [item['input_length'] for item in batch]
    input_ids = [item['input_ids'] for item in batch]
    target_ids = [item['target_ids'] for item in batch]
    target_lengths = [item['target_length'] for item in batch]

    # 在这里可以对数据进行进一步处理,如填充、对齐或转换成张量

    return {
        'questions': questions,
        'answers': answers,
        'input_lengths': input_lengths,
        'input_ids': input_ids,
        'target_ids': target_ids,
        'target_lengths': target_lengths
    }

然后创建DataLoader实例:

# 创建DataLoader实例,并传递自定义的collate_fn
batch_size = 32
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=custom_collate_fn)

现在可以迭代data_loader来获取批次数据:

for batch in data_loader:
    questions = batch['questions']
    answers = batch['answers']
    input_ids = batch['input_ids']
    target_ids = batch['target_ids']
    # 在这里执行模型训练或其他操作

以上就是对于Dataset和DataLoader使用的一个简单介绍,具体细节和更高级的用法留待以后(挖个坑~)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值