Method 1. Download the dataset over the network
# pip install pyarrow
from datasets import load_dataset
dataset = load_dataset(path='glue', name='sst2')
On success:
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})
On failure:
ReadTimeout: (ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: a16bb2d9-22f4-467e-9900-317e2368f09b)')
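A timeout like the one above is sometimes transient, so it can be worth retrying a few times before giving up. The following is only a minimal sketch: call_with_retry is a hypothetical helper (not part of the datasets library), it catches every Exception by default, and the retry count and delays are arbitrary assumptions.

```python
import time


def call_with_retry(fn, retries=3, base_delay=1.0,
                    exceptions=(Exception,), sleep=time.sleep):
    """Call fn() and retry on failure with exponential backoff.

    Hypothetical helper for illustration. `sleep` is injectable so the
    waiting can be replaced in tests.
    """
    if retries < 1:
        raise ValueError("retries must be at least 1")
    for attempt in range(retries):
        try:
            return fn()
        except exceptions:
            # Re-raise the original error once the last attempt has failed.
            if attempt == retries - 1:
                raise
            # Wait base_delay, then 2x, 4x, ... before the next attempt.
            sleep(base_delay * 2 ** attempt)


# Usage (needs network access, so it is left commented out):
# from datasets import load_dataset
# dataset = call_with_retry(lambda: load_dataset(path='glue', name='sst2'))
```

If the connection is blocked rather than merely slow, retrying will not help, and loading from local files is the more reliable route.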
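Method 2. Load the dataset from local files
When the download fails because of network problems, download the data files manually from the Hugging Face website and convert them into a DatasetDict. The resulting splits, columns, and row counts are the same.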
import pandas as pd
from datasets import Dataset, DatasetDict

# Read the Parquet files
train_df = pd.read_parquet('../data/sst2/train-00000-of-00001.parquet')
validation_df = pd.read_parquet('../data/sst2/validation-00000-of-00001.parquet')
test_df = pd.read_parquet('../data/sst2/test-00000-of-00001.parquet')

# Convert the pandas DataFrames into a DatasetDict
# preserve_index=False: do not store the DataFrame index in the dataset
dataset = DatasetDict({
    'train': Dataset.from_pandas(train_df, preserve_index=False),
    'validation': Dataset.from_pandas(validation_df, preserve_index=False),
    'test': Dataset.from_pandas(test_df, preserve_index=False)
})
Output:
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})
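As an alternative to going through pandas, load_dataset can read local Parquet files directly when given the 'parquet' builder and a data_files mapping. The sketch below assumes the files follow the "<split>-00000-of-00001.parquet" naming used above. find_split_files and load_local_parquet are hypothetical helper names introduced only for illustration.

```python
from pathlib import Path


def find_split_files(data_dir):
    """Map each split name to its Parquet files in data_dir.

    Assumes names like 'train-00000-of-00001.parquet', where the part
    before the first '-' is the split name (hypothetical helper).
    """
    files = {}
    for path in sorted(Path(data_dir).glob("*.parquet")):
        split = path.name.split("-", 1)[0]
        files.setdefault(split, []).append(str(path))
    return files


def load_local_parquet(data_dir):
    """Build a DatasetDict from the local Parquet files, without pandas."""
    # Imported lazily so find_split_files works without datasets installed.
    from datasets import load_dataset
    return load_dataset("parquet", data_files=find_split_files(data_dir))


# Usage, assuming the same directory layout as above:
# dataset = load_local_parquet('../data/sst2')
```

This skips the intermediate DataFrames, which can matter for larger datasets, but either route should produce the same three splits.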