【Foundation】(4) transformers: Dataset

0. Introduction

datasets is an easy-to-use dataset-loading library that makes it quick and convenient to load datasets from local files or the Hugging Face Hub. This post covers:

  • Loading online datasets with load_dataset

  • Loading a single task from a dataset collection

  • Loading by dataset split

  • Inspecting a dataset

  • Splitting a dataset

  • Selecting and filtering

  • Data mapping

  • Saving and loading

from datasets import load_dataset
from datasets import Dataset, load_from_disk
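If the library is not installed yet, it can be installed from PyPI first:

pip install datasets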

1. Loading Online Datasets

# Load an online dataset
datasets = load_dataset('madao33/new-title-chinese')
datasets
# Load a single task from a dataset collection
boolq_dataset = load_dataset('super_glue','boolq')
boolq_dataset
# Load a specific split
datasets = load_dataset('madao33/new-title-chinese',split='train')
datasets
# Load only part of the data
datasets = load_dataset('madao33/new-title-chinese',split='train[:100]')
datasets
datasets = load_dataset('madao33/new-title-chinese',split='train[:50%]')
datasets
datasets = load_dataset('madao33/new-title-chinese',split=['train[:50%]',"validation[:10%]"])
datasets
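For reference, when no split argument is given, load_dataset returns a DatasetDict keyed by split. For this corpus the printed structure looks roughly like this (row counts elided):

DatasetDict({
    train: Dataset({
        features: ['title', 'content'],
        num_rows: ...
    })
    validation: Dataset({
        features: ['title', 'content'],
        num_rows: ...
    })
})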

2. Inspecting a Dataset

# Load the online dataset
datasets = load_dataset('madao33/new-title-chinese')
datasets
datasets['train'][:2]
datasets['train']['title'][:10]
datasets['train'].column_names
datasets['train'].features

3. Splitting a Dataset

dataset = datasets['train']
dataset.train_test_split(test_size=0.1)
# Stratified split: keeps the 'label' distribution balanced across train/test
dataset = boolq_dataset['train']
dataset.train_test_split(test_size=0.1, stratify_by_column='label')
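Note that stratify_by_column only works when the target column is a ClassLabel feature (as it is for super_glue/boolq). If the label is stored as a plain value, one way to make the stratified split work is to encode the column first; a small sketch, assuming a column named 'label':

# Cast a plain 'label' column to ClassLabel so stratification is allowed
dataset = dataset.class_encode_column('label')
dataset.train_test_split(test_size=0.1, stratify_by_column='label')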

4. Selecting and Filtering Data

# Select specific rows by index
datasets['train'].select([0, 1])
# Filter rows with a predicate
a = datasets['train'].filter(lambda x: "中国" in x['title'])
a['title'][:5]
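select also combines naturally with shuffle for random sampling; a small sketch:

# Take a reproducible random sample of 10 rows
sampled = datasets['train'].shuffle(seed=42).select(range(10))
sampled['title']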

5. Data Mapping

def add_prefix(example):
    example['title'] = 'Prefix: ' + example["title"]
    return example 
prefix_dataset = datasets.map(add_prefix)
prefix_dataset['train'][:10]['title']
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('uer/roberta-base-finetuned-dianping-chinese')

def preprocess_function(example,tokenizer=tokenizer):
    model_input = tokenizer(example['content'],max_length=512,truncation=True)
    labels = tokenizer(example['title'],max_length=32,truncation=True)
    model_input['labels'] = labels['input_ids']
    return model_input
process_dataset = datasets.map(preprocess_function,batched=True)
process_dataset
# batched: process examples in batches
# num_proc: number of worker processes for parallel mapping
process_dataset = datasets.map(preprocess_function,batched=True,num_proc=4)
process_dataset
# Remove columns that are no longer needed
process_dataset = datasets.map(preprocess_function,batched=True,num_proc=4,remove_columns=datasets['train'].column_names)
process_dataset
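Instead of capturing the tokenizer through a default argument as above, map can also forward extra keyword arguments via fn_kwargs; a sketch of the same preprocessing:

def preprocess_with_kwargs(example, tokenizer):
    model_input = tokenizer(example['content'], max_length=512, truncation=True)
    labels = tokenizer(example['title'], max_length=32, truncation=True)
    model_input['labels'] = labels['input_ids']
    return model_input

# fn_kwargs is passed through to every invocation of the mapped function
process_dataset = datasets.map(preprocess_with_kwargs, batched=True,
                               fn_kwargs={'tokenizer': tokenizer})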

6. Saving and Loading

process_dataset.save_to_disk("./processed_data")
process_dataset = load_from_disk('./processed_data')
process_dataset
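Besides the Arrow directory written by save_to_disk, an individual split can also be exported to plain file formats; a sketch (output paths are illustrative):

# Export a single split to CSV or JSON lines
process_dataset['train'].to_csv('./processed_train.csv')
process_dataset['train'].to_json('./processed_train.json')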

7. Loading Local Datasets

7.1 Loading a File Directly as a Dataset

dataset = load_dataset('csv',data_files='./dataset/ChnSentiCorp_htl_all.csv',split='train')
dataset
dataset['review'][:2]
dataset = Dataset.from_csv("./dataset/ChnSentiCorp_htl_all.csv")
dataset
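data_files also accepts a dict mapping split names to files, which yields a DatasetDict directly; a sketch with hypothetical file names:

# 'train.csv' and 'test.csv' are placeholders for your own files
dataset = load_dataset('csv', data_files={'train': './dataset/train.csv',
                                          'test': './dataset/test.csv'})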

7.2 Loading All Files in a Folder as a Dataset

dataset = load_dataset('csv',data_dir='./dataset',split='train')
dataset
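When the folder contains files of mixed types, a glob pattern in data_files restricts what gets loaded; a sketch:

# Load only the CSV files under ./dataset
dataset = load_dataset('csv', data_files='./dataset/*.csv', split='train')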

7.3 Building a Dataset from Other In-Memory Formats

import pandas as pd
data = pd.read_csv('./dataset/ChnSentiCorp_htl_all.csv')
data.head()
dataset = Dataset.from_pandas(data)
dataset
data = [{'text':'aca'},{'text':'1321'},{'text':'123'}]
a = Dataset.from_list(data)
a[2]
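A column-oriented dict works as well, via Dataset.from_dict; a minimal sketch:

# One list per column
d = Dataset.from_dict({'text': ['aca', '1321', '123']})
d[2]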

7.4 Custom Loading Scripts


dataset = load_dataset("./load_data.py",split='train')
dataset

The loading script (load_data.py above) is defined as follows:

import json

import datasets
from datasets import DownloadManager, DatasetInfo


class CMRC2018(datasets.GeneratorBasedBuilder):
    def _info(self) -> DatasetInfo:
        """Defines the dataset metadata, including the feature schema."""
        return DatasetInfo(
            description="CMRC2018",
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "context": datasets.Value("string"),
                    "question": datasets.Value("string"),
                    "answers": datasets.features.Sequence(
                        {
                            "text": datasets.Value("string"),
                            "answer_start": datasets.Value("int32"),
                        }
                    ),
                }
            ),
            # supervised_keys=None,
            # homepage="https://github.com/ymcui/cmrc2018",
        )

    def _split_generators(self, dl_manager: DownloadManager):
        """Returns a list of datasets.SplitGenerator.
        Each generator takes two arguments, name and gen_kwargs:
        name: the dataset split to generate
        gen_kwargs: keyword arguments (here, the file path to read) passed on to _generate_examples
        """
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN,
                                        gen_kwargs={'filepath': "./datasets/cmrc2018_trial.json"})]

    def _generate_examples(self, filepath):
        """Yields concrete (key, example) pairs."""
        with open(filepath, encoding='utf-8') as f:
            data = json.load(f)
            for example in data:
                for paragraph in example['paragraphs']:
                    context = paragraph['context'].strip()
                    for qa in paragraph['qas']:
                        question = qa['question'].strip()
                        id_ = qa['id']

                        answer_starts = [answer['answer_start'] for answer in qa['answers']]
                        answers = [answer['text'] for answer in qa['answers']]

                        yield id_, {
                            "context": context,
                            "question": question,
                            "id": id_,
                            "answers": {
                                # keys must match the feature schema defined in _info
                                "answer_start": answer_starts,
                                "text": answers,
                            },
                        }
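For simple cases a full builder script can be avoided: Dataset.from_generator wraps a plain generator function. A sketch over the same JSON structure, assuming the file path used above:

from datasets import Dataset
import json

def gen(filepath="./datasets/cmrc2018_trial.json"):
    with open(filepath, encoding='utf-8') as f:
        for example in json.load(f):
            for paragraph in example['paragraphs']:
                for qa in paragraph['qas']:
                    yield {
                        'id': qa['id'],
                        'context': paragraph['context'].strip(),
                        'question': qa['question'].strip(),
                    }

dataset = Dataset.from_generator(gen)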

8. Dataset with DataCollator

from transformers import DataCollatorWithPadding
dataset = load_dataset('csv',data_files='./datasets/ChnSentiCorp_htl_all.csv',split='train')
dataset = dataset.filter(lambda x : x['review'] is not None)
dataset
def process_function(examples):
    tokenized_examples = tokenizer(examples['review'],max_length=128,truncation=True)
    tokenized_examples['labels'] = examples["label"]
    return tokenized_examples
tokenized_dataset = dataset.map(process_function,batched=True,remove_columns=dataset.column_names)
tokenized_dataset
print(tokenized_dataset[:2])
collator = DataCollatorWithPadding(tokenizer=tokenizer)
from torch.utils.data import DataLoader
dl = DataLoader(tokenized_dataset,batch_size=4,collate_fn=collator,shuffle=True)

next(enumerate(dl))
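DataCollatorWithPadding pads each batch to that batch's longest sequence, so the DataLoader yields ready-to-use tensors; a small check:

# Inspect the first padded batch: every tensor shares the batch's max length
for batch in dl:
    print({k: v.shape for k, v in batch.items()})
    break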