pytorch重写DataLoader加载本地数据

pytorch重写DataLoader加载本地数据

前两天学习了HuggingFace Datasets来写一个数据加载脚本,但是,在实验中发现,使用dataloader加载数据的便捷性,这两天查资料勉强重写DataLoader加载本地数据,在这里记录下,如果有错误,欢迎指正。


前言

pytorch官网搜索Dataloader,返回的一篇教程是
writing custom Datasets, Dataloaders and Transforms讲了编写自定义数据集,数据加载器和转换,但是讲的是图像数据,在这里我还使用苏剑林老师在《基于CNN的阅读理解式问答模型:DGCNN》里提供的WebQA和SogouQA数据集来重写Dataloader,因为苏剑林老师是用bert4keras库 来写数据集加载,我看的是task_reading_comprehension_by_mlm.py的数据加载格式,有部分代码我没有看懂,我按照自己的理解来重写的。

一、Dataset class

torch.utils.data.Dataset是表示数据集的抽象类。自定义数据集应继承数据集并覆盖以下方法:

  • __len__:返回len(dataset)数据集的大小
  • __getitem__: 支持索引,如dataset[i]可用于获取第i个样本

对自己的数据集首先要创建一个dataset 类。在__len__中读取文件,要将读取的文件加载到__getitem__中,这样的话所有数据不会立即存储到内存中,二十根据需要读取,因此可以提高内存效率。可视化演示可以参考这里

torch.utils.data.Dataset 是一个抽象类,它只能被继承。在B站上讲解的有两个我参考的比较好的教程,饭客帆刘二大人的Pytorch教程都很好。

二、使用步骤

1.重写Dataset

首先是导包。

import re
import torch 
import tokenizers
from torch.utils.data import Dataset, DataLoader
import numpy as np

这里有几个要预定义的超参数:

max_length = 384
max_p_len = 256
max_q_len = 64
max_a_len = 32

在苏神的示例中
输入: [CLS][MASK][MASK][SEP]问题[SEP]篇章[SEP]
输出: 答案

先要说明的是,苏神的这个数据加载的代码我没有完全看懂,全部的代码我暂时还没看完,我先按照自己的理解来写自己的,等我看完所有的代码会回来重新修改这篇文章。苏神的源码等我看懂了会回来修改的。

class MyDataset(Dataset):    # 继承Dataset模块的Dataset类
    # 初始化定义,得到数据内容
    def __init__(self, data_set):
        super(MyDataset, self).__init__()
        self.data_set = data_set    # 加载数据集
        self.length = len(data_set)    # 数据集长度
     
    # 返回数据集大小
    def __len__(self):
        return self.length
        
    # 数据预处理,这部分根据自己的数据集进行处理
    def __getitem__(self, index):    # index(或item)不能少,这个参数是来挑选某条数据的
        # 
        D = self.data_set[index]    # 从data_set中采样一个数据
        token_ids, segment_ids, a_token_ids = [], [], []
        question = D['question']
        answers = [p['answer'] for p in D['passages'] if p['answer']]
        # 苏神的代码,我没有看懂是否挑选了无答案的,这里是自己改的。挑选带答案的文章
        passage = ""    # 先声明再使用是个好习惯,不然会报错
        for pre_passage in D['passages']:
            if pre_passage['answer']:
                passage = pre_passage['passage']
                break
        passage = re.sub(u' |、|;|,', ',', passage)    # 清洗数据
        final_answer = ''
        for answer in answers:    # 选择答案
            if all([a in passage[:max_p_len - 2] for a in answer.split(' ')]):
                final_answer = answer.replace(' ', ',')
                break
        # print(question)       
        a_token_ids = tokenizer.encode(final_answer, max_length=max_a_len + 1, padding="max_length", truncation=True)    # 答案编码
        q_token_ids = tokenizer.encode(question, max_length=max_q_len + 1, truncation=True)    # 对问题进行截断
        p_token_ids = tokenizer.encode(passage, max_length=max_p_len + 1, truncation=True)

        token_ids += [tokenizer.mask_token_id] * max_a_len
        token_ids += [tokenizer.sep_token_id]
        token_ids += (q_token_ids[1:] + p_token_ids[1:-1])    # [MASK][MASK][SEP]问题[SEP]篇章
        
        token_ids = tokenizer.encode(tokenizer.convert_ids_to_tokens(token_ids),
                                        max_length=max_length ,
                                        padding="max_length", 
                                        truncation=True)    # [CLS][MASK][MASK][SEP]问题[SEP]篇章[SEP]
        
        segment_ids = [0] * len(token_ids)    
        token_ids = torch.as_tensor(token_ids)
        segment_ids = torch.as_tensor(segment_ids)
        a_token_ids = torch.as_tensor(a_token_ids)
        
        return [token_ids, segment_ids], a_token_ids

2.Dataloader加载

代码如下(示例):

首先,实例化

data = MyDataset(train_data)

输出一下结果

dataloader = DataLoader(data, batch_size=8, shuffle = True, 
                        drop_last=True)

for q_data, a_data in dataloader:
    print("q_data", tokenizer.decode(q_data[0][5]))
    print("a_data", tokenizer.decode(a_data[5]))
    break

在这里插入图片描述
在这里插入图片描述

总结

这里自己重写了DataLoader,有需要学习DataLoader加载本地数据的,可以仿照写就可以了。

  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值