文章目录
0、介绍
简单易用的数据集加载库，可以方便快捷地从本地或者 Hugging Face Hub 加载数据集
-
加载在线数据集 load_dataset
-
加载数据集某一项任务
-
按照数据集划分进行加载
-
查看数据集
-
数据集划分
-
数据集选取与过滤
-
数据映射
-
保存与加载
from datasets import load_dataset
from datasets import *
1、加载在线数据集
# Load an online dataset from the HuggingFace Hub
datasets = load_dataset('madao33/new-title-chinese')
datasets
# Load one specific task from a dataset collection (here: super_glue/boolq)
boolq_dataset = load_dataset('super_glue','boolq')
boolq_dataset
# Load only one split of the dataset
datasets = load_dataset('madao33/new-title-chinese',split='train')
datasets
# Load part of the data: the first 100 examples of the train split
datasets = load_dataset('madao33/new-title-chinese',split='train[:100]')
datasets
# The first 50% of the train split
datasets = load_dataset('madao33/new-title-chinese',split='train[:50%]')
datasets
# Several split slices at once — returns a list of datasets, one per slice
datasets = load_dataset('madao33/new-title-chinese',split=['train[:50%]',"validation[:10%]"])
datasets
2、查看数据集
# Reload the full online dataset (all splits)
datasets = load_dataset('madao33/new-title-chinese')
datasets
# First two examples of the train split (returned as a dict of columns)
datasets['train'][:2]
# First ten values of the 'title' column
datasets['train']['title'][:10]
# Column names and feature schema of the train split
datasets['train'].column_names
datasets['train'].features
3、数据集划分
# Split the train set into train/test subsets (10% held out for test)
dataset = datasets['train']
dataset.train_test_split(test_size=0.1)
# Stratify by the 'label' column so class proportions stay balanced in both subsets
dataset = boolq_dataset['train']
dataset.train_test_split(test_size=0.1,stratify_by_column='label')
4、数据选取与过滤
# Select: pick examples by index
datasets['train'].select([0,1])
# Filter: keep only examples whose title contains "中国"
a = datasets['train'].filter(lambda x: "中国" in x['title'])
a['title'][:5]
5、数据映射
def add_prefix(example):
    """Prepend the literal marker 'Prefix: ' to the example's title.

    Mutates the example dict in place and returns it, so it can be used
    directly with Dataset.map.
    """
    original_title = example["title"]
    example["title"] = "Prefix: " + original_title
    return example
# map applies add_prefix to every example in every split
prefix_dataset = datasets.map(add_prefix)
prefix_dataset['train'][:10]['title']
from transformers import AutoTokenizer
# Tokenizer reused by the preprocessing functions below
tokenizer = AutoTokenizer.from_pretrained('uer/roberta-base-finetuned-dianping-chinese')
def preprocess_function(example,tokenizer=tokenizer):
    """Tokenize the content as model input and the title as labels.

    Content is truncated to 512 tokens, the title to 32; the title's
    input_ids are attached to the model input under the 'labels' key.
    """
    inputs = tokenizer(example['content'], max_length=512, truncation=True)
    tokenized_titles = tokenizer(example['title'], max_length=32, truncation=True)
    inputs['labels'] = tokenized_titles['input_ids']
    return inputs
process_dataset = datasets.map(preprocess_function,batched=True)
process_dataset
# num_proc: number of worker processes (multiprocessing, not threads)
# batched: feed examples to the function in batches
process_dataset = datasets.map(preprocess_function,batched=True,num_proc=4)
process_dataset
# Drop the original columns, keeping only the tokenized fields
process_dataset = datasets.map(preprocess_function,batched=True,num_proc=4,remove_columns=datasets['train'].column_names)
process_dataset
6、保存与加载
# Save the processed dataset to disk
process_dataset.save_to_disk("./processed_data")
# load_from_disk returns the loaded dataset — the original called it on the
# instance and discarded the return value, so nothing was actually reloaded.
# Assign the result (load_from_disk comes from `from datasets import *`).
process_dataset = load_from_disk('./processed_data')
process_dataset
7、加载本地数据集
7.1 、直接加载文件作为数据集
# Load a local CSV file directly as a dataset
dataset = load_dataset('csv',data_files='./dataset/ChnSentiCorp_htl_all.csv',split='train')
dataset
dataset['review'][:2]
# Equivalent: build a Dataset straight from the CSV file
dataset = Dataset.from_csv("./dataset/ChnSentiCorp_htl_all.csv")
dataset
7.2、加载文件夹内所有文件作为数据集
# Load every file inside a folder as one dataset
dataset = load_dataset('csv',data_dir='./dataset',split='train')
dataset
7.3、通过预先加载其他格式转换为加载数据集
import pandas as pd
# Read with pandas first, then convert the DataFrame into a Dataset
data = pd.read_csv('./dataset/ChnSentiCorp_htl_all.csv')
data.head()
dataset = Dataset.from_pandas(data)
dataset
# A list of dicts can also be turned into a Dataset
data = [{'text':'aca'},{'text':'1321'},{'text':'123'}]
a = Dataset.from_list(data)
a[2]
7.4、自定义数据加载方式
# Load via a custom loading script (the builder class is defined below)
dataset = load_dataset("./load_data.py",split='train')
dataset
自定义加载脚本 load_data.py 的内容如下（即上面 load_dataset 所加载的脚本）
import json
import datasets
from datasets import DownloadManager,DatasetInfo
class CMRC2018(datasets.GeneratorBasedBuilder):
    """Custom loading script for the CMRC2018 reading-comprehension dataset."""

    def _info(self) -> DatasetInfo:
        """Define dataset metadata and the feature schema.

        Note: the hook GeneratorBasedBuilder invokes is `_info` (single
        leading underscore); the original `__info` was never called, so the
        builder had no feature schema at all.
        """
        return DatasetInfo(
            description="CMRC2018",
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "context": datasets.Value("string"),
                    "question": datasets.Value("string"),
                    "answers": datasets.features.Sequence(
                        {
                            "text": datasets.Value("string"),
                            "answer_start": datasets.Value("int32"),
                        }
                    ),
                }
            ),
            # supervised_keys=None,
            # homepage="https://github.com/ymcui/cmrc2018",
        )

    def _split_generators(self, dl_manager: DownloadManager):
        """Return the list of datasets.SplitGenerator.

        Each generator takes two arguments:
          name: which dataset split it produces
          gen_kwargs: keyword args (here the file path) forwarded verbatim
            to _generate_examples
        """
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"filepath": "./datasets/cmrc2018_trial.json"},
            )
        ]

    def _generate_examples(self, filepath):
        """Yield (key, example) pairs read from the JSON file.

        The keys of the yielded "answers" dict must match the feature schema
        declared in _info ("text" / "answer_start"); the original yielded
        "answer_starts", which fails schema validation at load time.
        """
        with open(filepath, encoding="utf-8") as f:
            data = json.load(f)
        for example in data:
            for paragraph in example["paragraphs"]:
                context = paragraph["context"].strip()
                for qa in paragraph["qas"]:
                    question = qa["question"].strip()
                    id_ = qa["id"]
                    answer_starts = [answer["answer_start"] for answer in qa["answers"]]
                    answers = [answer["text"] for answer in qa["answers"]]
                    yield id_, {
                        "context": context,
                        "question": question,
                        "id": id_,
                        "answers": {
                            "answer_start": answer_starts,
                            "text": answers,
                        },
                    }
8、Dataset with DataCollator
from transformers import DataCollatorWithPadding
# Load the local CSV and drop rows whose review text is missing
dataset = load_dataset('csv',data_files='./datasets/ChnSentiCorp_htl_all.csv',split='train')
dataset = dataset.filter(lambda x : x['review'] is not None)
dataset
def process_function(examples):
    """Tokenize the review texts (truncated to 128 tokens) and copy labels over."""
    encoded = tokenizer(examples['review'], max_length=128, truncation=True)
    encoded['labels'] = examples["label"]
    return encoded
# Map over the dataset, dropping the raw columns
tokenized_dataset = dataset.map(process_function,batched=True,remove_columns=dataset.column_names)
tokenized_dataset
print(tokenized_dataset[:2])
# The collator pads each batch dynamically to its longest sequence
collator = DataCollatorWithPadding(tokenizer=tokenizer)
from torch.utils.data import DataLoader
dl = DataLoader(tokenized_dataset,batch_size=4,collate_fn=collator,shuffle=True)
# Peek at a single batch
next(enumerate(dl))