一、参考资料
Hugging Face快速入门(重点讲解模型(Transformers)和数据集部分(Datasets))
二、Hugging Face数据集
Datasets类库(github, 官方文档)可以让你非常方便的访问和分享数据集,也可以用来对NLP、CV、语音等任务进行评价(Evaluation metrics)。
1. 安装Datasets库
pip install datasets
# 如果使用语音(Audio)数据集
pip install datasets[audio]
# 如果使用图片(Image)数据集
pip install datasets[vision]
2. 查找数据集
3. (在线)加载数据集
3.1 函数原型
def load_dataset(
path: str,
name: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Union[Dict, List] = None,
split: Optional[Union[str, Split]] = None,
cache_dir: Optional[str] = None,
features: Optional[Features] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[GenerateMode] = None,
ignore_verifications: bool = False,
save_infos: bool = False,
script_version: Optional[Union[str, Version]] = None,
**config_kwargs,
) -> Union[DatasetDict, Dataset]:
参数解释
path
:参数path表示数据集的名字或者路径。可以是如下几种形式(每种形式的使用方式后面会详细说明)- 数据集的名字,比如
imdb
、glue
。 - 数据集文件格式,比如
json
、csv
、parquet
、txt
。 - 数据集目录中的处理数据集的脚本(
.py
)文件,比如glue/glue.py
。
- 数据集的名字,比如
name
:参数name表示数据集中的子数据集,当一个数据集包含多个数据集时,就需要这个参数,比如glue数据集下就包含sst2
、cola
、qqp
等多个子数据集,此时就需要指定name来表示加载哪一个子数据集。data_dir
:数据集所在的目录。data_files
:数据集文件。cache_dir
:构建的数据集缓存目录,方便下次快速加载。
3.2 加载数据集(方式一)
from datasets import load_dataset
dataset = load_dataset(path="glue", name='ax')
print(dataset)
输出结果
DatasetDict({
test: Dataset({
features: ['premise', 'hypothesis', 'label', 'idx'],
num_rows: 1104
})
})
3.2 3 加载数据集(方式二)
先下载到本地,然后再加载本地数据集。
下载默认路径
可以通过 cache_dir
属性指定下载路径。
Windows路径:
C:\Users\[用户名]\.cache\huggingface\datasets
Linux路径:
~/.cache/huggingface/datasets
from datasets import load_dataset, load_from_disk
# glue数据集的cola子集
dataset = load_dataset(path='glue', name='cola', cache_dir='./raw_datasets')
# 保存到本地
dataset.save_to_disk('/root/Downloads/datasets/glue/cola')
# 加载本地数据集
raw_dataset = load_from_disk("/root/Downloads/datasets/glue/cola")
print(raw_dataset )
输出结果
DatasetDict({
train: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 8551
})
validation: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1043
})
test: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1063
})
})
数据集目录
-- glue_cola
|-- dataset_dict.json
|-- test
| |-- data-00000-of-00001.arrow
| |-- dataset_info.json
| `-- state.json
|-- train
| |-- data-00000-of-00001.arrow
| |-- dataset_info.json
| `-- state.json
`-- validation
|-- data-00000-of-00001.arrow
|-- dataset_info.json
`-- state.json
4. (离线)加载数据集
如果 load_dataset
加载在线数据集失败,可下载数据集到本地,再加载本地数据集。
注意:手动离线下载的是原始数据,没有匹配datasets库的格式,导致 load_dataset()
加载失败。
关于如何本地下载Hugging Face原始数据,可参考博客:如何批量下载hugging face模型和数据集文件
4.1 GLUE数据集
数据集地址: nyu-mll/glue,有原始数据文件。
4.1.1 下载原始数据
在main
、script
两个不同分支上分别下载原始数据和 .py
脚本数据,下载之后合并这两个文件夹。
去掉后面的 /tree/main
,然后增添 .git
,即可使用git下载。
# 下载原始数据(main分支)
git lfs clone https://huggingface.co/datasets/nyu-mll/glue.git -b main
# 下载.py文件(script分支)
git lfs clone https://huggingface.co/datasets/nyu-mll/glue.git -b script
4.1.2 加载数据集
from datasets import load_dataset
dataset = load_dataset(path="/root/Downloads/glue", name='cola')
print(dataset)
输出结果
DatasetDict({
train: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 8551
})
validation: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1043
})
test: Dataset({
features: ['sentence', 'label', 'idx'],
num_rows: 1063
})
})
4.2 ChnSentiCorp数据集
用于中文情感分析,标记了每条评论的情感极性(0或1)。
数据集地址: seamew/ChnSentiCorp,有原始数据文件。
4.2.1 下载原始数据
git lfs clone https://huggingface.co/datasets/seamew/ChnSentiCorp.git
4.2.2 保存数据集
git下载的文件无法直接使用:
load_dataset
会执行.python
文件,通过https://drive.google.com
下载数据导致下载失败报错。load_from_disk
会执行失败,因为该文件夹非dist数据集格式。
解决办法:通过 save_to_disk()
保存为本地数据集。
# 设置data_files
data_files = {
'train': '/root/Downloads/ChnSentiCorp/chn_senti_corp-train.arrow',
'test': '/root/Downloads/ChnSentiCorp/chn_senti_corp-test.arrow',
'validation': '/root/Downloads/ChnSentiCorp/chn_senti_corp-validation.arrow'}
# 加载arrow数据集
dataset = load_dataset(path='arrow', data_files=data_files)
# 保存至本地
dataset.save_to_disk('/root/Downloads/datasets/chn_senti_corp')
保存本地数据集后,数据集目录如下所示:
`-- chn_senti_corp
|-- dataset_dict.json
|-- test
| |-- data-00000-of-00001.arrow
| |-- dataset_info.json
| `-- state.json
|-- train
| |-- data-00000-of-00001.arrow
| |-- dataset_info.json
| `-- state.json
`-- validation
|-- data-00000-of-00001.arrow
|-- dataset_info.json
`-- state.json
4.2.3 加载数据集
dataset = load_from_disk('/root/Downloads/datasets/chn_senti_corp')
print(dataset)
输出结果
DatasetDict({
train: Dataset({
features: ['label', 'text'],
num_rows: 9600
})
test: Dataset({
features: ['label', 'text'],
num_rows: 1200
})
validation: Dataset({
features: ['label', 'text'],
num_rows: 1200
})
})
4.3 peoples_daily_ner数据集
用于中文命名实体识别(NER),来自人民日报的文本数据,标记了人名、地名 、组织机构等。
数据集地址: peoples-daily-ner/peoples_daily_ner ,并无原始数据文件。
通过查看 .py
文件,可知数据集下载地址为:
_URL = "https://raw.githubusercontent.com/OYE93/Chinese-NLP-Corpus/master/NER/People's%20Daily/"
但是,https://raw.githubusercontent.com
无法访问,可通过 https://github.com
下载原始数据。
4.3.1 下载原始数据
分别下载原始数据以及数据集,并将原始数据拷贝到数据集文件夹中。
# 下载数据集
git lfs clone https://huggingface.co/datasets/peoples_daily_ner.git
# 下载原始数据
example.dev、example.train、example.test
# 拷贝原始数据
cp example.dev /root/Downloads/peoples_daily_ner
cp example.train /root/Downloads/peoples_daily_ner
cp example.test /root/Downloads/peoples_daily_ner
4.3.2 修改.py
文件
_URL = "https://raw.githubusercontent.com/OYE93/Chinese-NLP-Corpus/master/NER/People's%20Daily/"
改为
_URL = ""
4.3.3 保存数据集
_URL
修改为本地路径。
dataset = load_dataset('/root/Downloads/peoples_daily_ner')
dataset.save_to_disk('/root/Downloads/datasets/peoples_daily_ner')
保存本地数据集后,数据集目录如下所示:
`-- peoples_daily_ner
|-- dataset_dict.json
|-- test
| |-- data-00000-of-00001.arrow
| |-- dataset_info.json
| `-- state.json
|-- train
| |-- data-00000-of-00001.arrow
| |-- dataset_info.json
| `-- state.json
`-- validation
|-- data-00000-of-00001.arrow
|-- dataset_info.json
`-- state.json
4.3.4 加载数据集
dataset = load_from_disk('/root/Downloads/datasets/peoples_daily_ner')
print(dataset)
输出结果
DatasetDict({
train: Dataset({
features: ['id', 'tokens', 'ner_tags'],
num_rows: 20865
})
validation: Dataset({
features: ['id', 'tokens', 'ner_tags'],
num_rows: 2319
})
test: Dataset({
features: ['id', 'tokens', 'ner_tags'],
num_rows: 4637
})
})
5. 自定义数据集
自定义数据集,主要包括以下步骤:
- 首先,使用
Dataset.from_dict
方法定义了一个包含两个样本的数据集。 - 其次,将这个数据集添加到
DatasetDict
对象中,并使用键名my_dataset
进行标识。 - 然后,打印
DatasetDict
对象中的my_dataset
数据集。 - 最后,使用
save_to_disk
方法将数据集保存到指定位置。
5.1 保存数据集
from datasets import DatasetDict, Dataset
# 定义数据集
my_dataset = Dataset.from_dict({
"text": ["Hello, world!", "How are you?"],
"label": [1, 0]
})
# 将数据集添加到DatasetDict中
dataset_dict = DatasetDict({"my_dataset": my_dataset})
# 打印数据集
print(dataset_dict["my_dataset"])
# 将数据集保存到指定位置
my_dataset.save_to_disk("/root/Downloads/datasets/my_dataset")
保存本地数据集后,数据集目录如下所示:
|-- my_dataset
| |-- data-00000-of-00001.arrow
| |-- dataset_info.json
| `-- state.json
5.2 加载数据集
# 加载数据集
dataset = load_from_disk('/root/Downloads/datasets/peoples_daily_ner')
print(dataset)
输出结果
Dataset({
features: ['text', 'label'],
num_rows: 2
})
三、FAQ
Q:ConnectionError: Couldn't reach
加载(load_dataset)Huggingface/datasets下的数据集遇到网络连接问题解决办法
ConnectionError: Couldn't reach https://raw.githubusercontent.com/huggingface/datasets/2.4.0/datasets/glue/glue.py (ConnectionError(MaxRetryError("HTTPSConnectionPool(host='raw.githubusercontent.com', port=443): Max retries exceeded with url: /huggingface/datasets/2.4.0/datasets/glue/glue.py (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x000001D31F86E310>: Failed to establish a new connection: [Errno 11004] getaddrinfo failed'))")))
MaxRetryError: HTTPSConnectionPool(host='cdn-lfs.hf-mirror.com', port=443): Max retries exceeded with url: /datasets/glue/2e7538afa2000e63f5343f16a758d75c452661a384208399d2035cd2fce45c33?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27train-00000-of-00001.parquet%3B+filename%3D%22train-00000-of-00001.parquet%22%3B&Expires=1722497068&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjQ5NzA2OH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9kYXRhc2V0cy9nbHVlLzJlNzUzOGFmYTIwMDBlNjNmNTM0M2YxNmE3NThkNzVjNDUyNjYxYTM4NDIwODM5OWQyMDM1Y2QyZmNlNDVjMzM~cmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=YJ6O6~aOVDbT5hf4b1-ZP-4n1monO755EU9nnA35LSyikLAzMtDRP5ZpfJy1QynV9g3QKmo9ad9KRA9EWGu4qfc-TgYU~B4cctYea~TKfIG7Q7Z1ETL9~eWn~RadiCBcAEeSfdMKW0yvSHuoeyYYdJwk-tuov0PZI4pvX4n8ysnafdfalIjMiWqQSKVuwJ22HAEub6up9IAriUrid3AQDSuz7u49BT0MTYW1y2HdZQHGtIRWxsmjdHlW6Vyrfib-gEvIxI~6zxqkBoXKU3OROOSOnZUfoV0qC2i385q53Y37zZErX2CUDUgSyxa-dG7ctZlZSKzbWuHGHDBzfyStqQ__&Key-Pair-Id=K3ESJI6DHPFC7 (Caused by ProxyError('Cannot connect to proxy.', TimeoutError('_ssl.c:980: The handshake operation timed out')))
ProxyError: (MaxRetryError("HTTPSConnectionPool(host='cdn-lfs.hf-mirror.com', port=443): Max retries exceeded with url: /datasets/glue/2e7538afa2000e63f5343f16a758d75c452661a384208399d2035cd2fce45c33?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27train-00000-of-00001.parquet%3B+filename%3D%22train-00000-of-00001.parquet%22%3B&Expires=1722497068&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjQ5NzA2OH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9kYXRhc2V0cy9nbHVlLzJlNzUzOGFmYTIwMDBlNjNmNTM0M2YxNmE3NThkNzVjNDUyNjYxYTM4NDIwODM5OWQyMDM1Y2QyZmNlNDVjMzM~cmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=YJ6O6~aOVDbT5hf4b1-ZP-4n1monO755EU9nnA35LSyikLAzMtDRP5ZpfJy1QynV9g3QKmo9ad9KRA9EWGu4qfc-TgYU~B4cctYea~TKfIG7Q7Z1ETL9~eWn~RadiCBcAEeSfdMKW0yvSHuoeyYYdJwk-tuov0PZI4pvX4n8ysnafdfalIjMiWqQSKVuwJ22HAEub6up9IAriUrid3AQDSuz7u49BT0MTYW1y2HdZQHGtIRWxsmjdHlW6Vyrfib-gEvIxI~6zxqkBoXKU3OROOSOnZUfoV0qC2i385q53Y37zZErX2CUDUgSyxa-dG7ctZlZSKzbWuHGHDBzfyStqQ__&Key-Pair-Id=K3ESJI6DHPFC7 (Caused by ProxyError('Cannot connect to proxy.', TimeoutError('_ssl.c:980: The handshake operation timed out')))"), '(Request ID: 17a1f63a-c6f8-4a50-bd28-94480076328a)')
# 错误原因
无法访问外网,导致数据集加载失败
# 解决方法
下载数据集到本地,并加载本地数据集