Hugging Face datasets
库的常用方法和属性
datasets
是 Hugging Face 提供的 高效数据处理库,用于 加载、转换、处理、筛选、格式化和保存数据集,特别适用于 文本分类、机器翻译、摘要、问答等 NLP 任务。
1. datasets
的常见属性
在加载数据集后,可以访问以下常见属性:
from datasets import load_dataset
dataset = load_dataset("imdb")
属性 | 作用 | 示例 |
---|---|---|
dataset | 数据集字典,包含 train/test/validation | print(dataset) |
dataset["train"] | 访问训练集 | dataset["train"] |
dataset.column_names | 获取数据集列名 | dataset["train"].column_names |
dataset.features | 获取数据集的特征信息 | dataset["train"].features |
dataset.num_rows | 获取数据集样本数量 | dataset["train"].num_rows |
dataset.shape | 获取数据集形状 | dataset["train"].shape |
2. datasets
的常用方法
方法 | 作用 |
---|---|
load_dataset(name, split) | 加载数据集 |
dataset.select(indices) | 选择部分数据 |
dataset.shuffle(seed) | 随机打乱数据集 |
dataset.train_test_split(test_size) | 拆分训练集和测试集 |
dataset.filter(function) | 过滤数据 |
dataset.map(function) | 处理数据 |
dataset.to_pandas() | 转换为 Pandas DataFrame |
dataset.to_csv("file.csv") | 保存为 CSV 文件 |
dataset.save_to_disk("path") | 保存数据集到本地 |
load_from_disk("path") | 从本地加载数据集 |
3. datasets
详细用法
3.1. 加载数据集
加载 Hugging Face Hub 上的公开数据集
from datasets import load_dataset
dataset = load_dataset("imdb")
print(dataset)
输出
DatasetDict({
train: Dataset({
features: ['text', 'label'],
num_rows: 25000
})
test: Dataset({
features: ['text', 'label'],
num_rows: 25000
})
})
加载本地 CSV/JSON/TXT 文件
dataset = load_dataset("csv", data_files="data.csv")
dataset = load_dataset("json", data_files="data.json")
dataset = load_dataset("text", data_files="data.txt")
3.2. 访问数据
print(dataset["train"][0])
示例输出
{'text': 'This is a great movie!', 'label': 1}
获取某列数据
texts = dataset["train"]["text"]
labels = dataset["train"]["label"]
3.3. 训练集和测试集拆分
train_test_split = dataset["train"].train_test_split(test_size=0.2)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]
3.4. map()
- 对整个数据集进行转换
- 将文本转换为小写
def lowercase(example):
example["text"] = example["text"].lower()
return example
dataset = dataset.map(lowercase)
- 添加新字段
def add_length(example):
example["text_length"] = len(example["text"])
return example
dataset = dataset.map(add_length)
- 批量处理(提高速度)
def batch_process(batch):
batch["text_length"] = [len(text) for text in batch["text"]]
return batch
dataset = dataset.map(batch_process, batched=True)
- 移除不需要的列
dataset = dataset.map(lowercase, remove_columns=["text_length"])
3.5. filter()
- 按条件筛选数据
- 筛选文本长度大于 100 的数据
dataset = dataset.filter(lambda example: len(example["text"]) > 100)
- 筛选正面情感的评论
dataset = dataset.filter(lambda example: example["label"] == 1)
- 使用
num_proc
进行多进程加速
dataset = dataset.filter(lambda example: len(example["text"]) > 100, num_proc=4)
3.6. 转换数据格式
转换为 PyTorch/TensorFlow 格式
dataset.set_format(type="torch", columns=["text", "label"])
dataset.set_format(type="tensorflow", columns=["text", "label"])
3.7. 保存数据集
dataset.save_to_disk("./my_dataset")
重新加载保存的数据
from datasets import load_from_disk
dataset = load_from_disk("./my_dataset")
3.8. 处理超大规模数据
流式加载超大数据集
dataset = load_dataset("c4", split="train", streaming=True)
for sample in dataset:
print(sample)
break # 仅打印第一条数据
适用于 大规模 NLP 数据,如 C4
、Common Crawl
。
3.9. 数据集拆分与合并
按比例拆分数据
train_test_split = dataset["train"].train_test_split(test_size=0.2)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]
组合多个数据集
combined_dataset = dataset1["train"].concatenate(dataset2["train"])
4. datasets
主要方法对比
方法 | 作用 |
---|---|
load_dataset | 加载 Hugging Face Hub 或本地数据集 |
map | 转换数据(分词、文本清理、特征添加等) |
filter | 筛选数据(去除短文本、保留特定类别等) |
set_format | 转换数据格式(PyTorch/TensorFlow) |
train_test_split | 拆分数据集(训练集/测试集) |
concatenate_datasets | 合并多个数据集 |
save_to_disk | 保存数据集到本地 |
load_from_disk | 加载本地数据集 |
5. 总结
datasets
提供 高效的数据加载、预处理、筛选、存储 方法,适用于 文本分类、机器翻译、摘要、问答等 NLP 任务。
常见方法:
load_dataset(name)
加载数据集map()
转换数据filter()
筛选数据set_format()
转换数据格式save_to_disk()
保存数据集load_from_disk()
加载数据集train_test_split()
拆分数据concatenate_datasets()
合并数据集