This post summarizes the workflow for using Dataset and DataLoader in PyTorch.
Suppose we start with a batch of data in the following format:
[{"answer": "hat",
  "question": "What is the man wearing on his head?"},
 {"answer": "yes",
  "question": "Are they all happy?"},
 ...]
How do we feed it to the model in batches?
1. We start with Dataset. Dataset is an abstract class that represents a dataset. You can create a custom Dataset class, or use one of the predefined Dataset classes PyTorch provides, such as torchvision.datasets.ImageFolder (for image data) or torchtext.data.Dataset (for text data).
Here we focus on custom Datasets. A custom Dataset class must implement two methods:
- __len__: returns the size of the dataset, i.e. the number of samples.
- __getitem__: returns the sample at a given index. This method does the actual per-sample loading for batching (it is called automatically when you later iterate over the DataLoader).
Here is a simple example:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5TokenizerFast

data = [
    {
        "answer": "hat",
        "question": "What is the man wearing on his head?"
    },
    {
        "answer": "yes",
        "question": "Are they all happy?"
    },
    # add more data here
]

# A tokenizer needs to be loaded here; details omitted
# (any pretrained T5 checkpoint works, e.g. "t5-base")
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

class CustomQADataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        question = item["question"]
        answer = item["answer"]
        input_ids = self.tokenizer.encode(question, max_length=20, truncation=True)
        target_ids = self.tokenizer.encode(answer, max_length=20, truncation=True)
        out_dict = {}
        out_dict['question'] = question
        out_dict['answer'] = answer
        out_dict['input_length'] = len(input_ids)
        out_dict['input_ids'] = input_ids
        out_dict['target_ids'] = target_ids
        out_dict['target_length'] = len(target_ids)
        return out_dict

dataset = CustomQADataset(data, tokenizer)
That creates an instance of our custom Dataset class.
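To see __len__ and __getitem__ in action without downloading a real checkpoint, here is a self-contained sketch that swaps T5TokenizerFast for a toy whitespace tokenizer (ToyTokenizer is purely illustrative and not part of any library):

```python
from torch.utils.data import Dataset

# Toy stand-in for the T5 tokenizer: maps each whitespace-separated
# token to an integer id, growing the vocabulary on the fly.
class ToyTokenizer:
    def __init__(self):
        self.vocab = {}

    def encode(self, text, max_length=None, truncation=False):
        ids = []
        for tok in text.lower().split():
            if tok not in self.vocab:
                self.vocab[tok] = len(self.vocab) + 1
            ids.append(self.vocab[tok])
        if truncation and max_length is not None:
            ids = ids[:max_length]
        return ids

# Same class as above, trimmed to the essentials.
class CustomQADataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        input_ids = self.tokenizer.encode(item["question"], max_length=20, truncation=True)
        target_ids = self.tokenizer.encode(item["answer"], max_length=20, truncation=True)
        return {
            "question": item["question"],
            "answer": item["answer"],
            "input_ids": input_ids,
            "input_length": len(input_ids),
            "target_ids": target_ids,
            "target_length": len(target_ids),
        }

data = [
    {"answer": "hat", "question": "What is the man wearing on his head?"},
    {"answer": "yes", "question": "Are they all happy?"},
]
dataset = CustomQADataset(data, ToyTokenizer())
sample = dataset[0]              # indexing calls __getitem__
print(len(dataset))              # 2
print(sample["input_length"])    # 8 whitespace tokens in the first question
```

Plain indexing like `dataset[0]` is a handy way to sanity-check a Dataset before wiring up a DataLoader.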
2. Next, create a DataLoader instance. Before doing so, we usually define a collate_fn that combines individual samples into a batch, for example:
def custom_collate_fn(batch):
    questions = [item['question'] for item in batch]
    answers = [item['answer'] for item in batch]
    input_lengths = [item['input_length'] for item in batch]
    input_ids = [item['input_ids'] for item in batch]
    target_ids = [item['target_ids'] for item in batch]
    target_lengths = [item['target_length'] for item in batch]
    # Further processing can happen here, e.g. padding,
    # alignment, or converting lists to tensors
    return {
        'questions': questions,
        'answers': answers,
        'input_lengths': input_lengths,
        'input_ids': input_ids,
        'target_ids': target_ids,
        'target_lengths': target_lengths
    }
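The comment in custom_collate_fn leaves padding and tensor conversion as an exercise; one possible sketch, assuming pad id 0 (T5's convention) and right-padding via torch.nn.utils.rnn.pad_sequence, looks like this:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def padded_collate_fn(batch):
    # Right-pad every sequence to the longest one in the batch,
    # then stack into a single (batch, seq_len) tensor.
    input_ids = pad_sequence(
        [torch.tensor(item["input_ids"]) for item in batch],
        batch_first=True, padding_value=0)
    target_ids = pad_sequence(
        [torch.tensor(item["target_ids"]) for item in batch],
        batch_first=True, padding_value=0)
    # attention_mask marks real tokens (1) vs padding (0);
    # this assumes id 0 is reserved for padding.
    attention_mask = (input_ids != 0).long()
    return {
        "questions": [item["question"] for item in batch],
        "answers": [item["answer"] for item in batch],
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "target_ids": target_ids,
        "input_lengths": torch.tensor([item["input_length"] for item in batch]),
        "target_lengths": torch.tensor([item["target_length"] for item in batch]),
    }

# Quick check on two hand-made samples of unequal length:
batch = padded_collate_fn([
    {"question": "q1", "answer": "a1", "input_ids": [5, 6, 7],
     "target_ids": [8], "input_length": 3, "target_length": 1},
    {"question": "q2", "answer": "a2", "input_ids": [9],
     "target_ids": [10, 11], "input_length": 1, "target_length": 2},
])
print(batch["input_ids"].tolist())       # [[5, 6, 7], [9, 0, 0]]
print(batch["attention_mask"].tolist())  # [[1, 1, 1], [1, 0, 0]]
```

Padding inside collate_fn (rather than in __getitem__) means each batch is only padded to its own longest sequence, not to a global maximum.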
Then create the DataLoader instance:
# Create the DataLoader, passing in the custom collate_fn
batch_size = 32
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=custom_collate_fn)
Now we can iterate over data_loader to fetch batches:
for batch in data_loader:
    questions = batch['questions']
    answers = batch['answers']
    input_ids = batch['input_ids']
    target_ids = batch['target_ids']
    # run a training step or other processing here
That wraps up this brief introduction to Dataset and DataLoader; the finer details and more advanced usage are left for a future post (consider this a placeholder~).