🤗关注公众号 funNLPer 白嫖畅读全文🤗
文章目录
datasets是抱抱脸开发的一个数据集python库,可以很方便的从Hugging Face Hub里下载数据,也可很方便的从本地加载数据集,本文主要对load_dataset方法的使用进行详细说明
1. load_dataset参数
load_dataset有以下参数,具体可参考 源码
def load_dataset(
path: str,
name: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Union[Dict, List] = None,
split: Optional[Union[str, Split]] = None,
cache_dir: Optional[str] = None,
features: Optional[Features] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[GenerateMode] = None,
ignore_verifications: bool = False,
save_infos: bool = False,
script_version: Optional[Union[str, Version]] = None,
**config_kwargs,
) -> Union[DatasetDict, Dataset]:
path
:参数path表示数据集的名字或者路径。可以是如下几种形式(每种形式的使用方式后面会详细说明)- 数据集的名字,比如imdb、glue
- 数据集文件格式,比如json、csv、parquet、txt
- 数据集目录中的处理数据集的脚本(.py)文件,比如“glue/glue.py”
name
:参数name表示数据集中的子数据集,当一个数据集包含多个数据集时,就需要这个参数,比如glue数据集下就包含"sst2"、“cola”、"qqp"等多个子数据集,此时就需要指定name来表示加载哪一个子数据集data_dir
:数据集所在的目录data_files
:数据集文件cache_dir
:构建的数据集缓存目录,方便下次快速加载
以上为一些常用且比较重要的参数,其他参数很少用到因此在此处不再详细说明,下面会通过一些case更加具体的说明各种用法
2. 详细用法
2.1 从HuggingFace Hub上加载数据
首先我们可以通过如下方式查看Hubs上有哪些数据集
from datasets import list_datasets
datasets_list = list_datasets()
print( len(datasets_list))
print(datasets_list[:10])
输出如下
47660
['acronym_identification', 'ade_corpus_v2', 'adversarial_qa', 'aeslc', 'afrikaans_ner_corpus', 'ag_news', 'ai2_arc', 'air_dialogue', 'ajgt_twitter_ar', 'allegro_reviews']
后面通过直接指定path
等于相关数据集的名字就能下载并加载相关数据集
from datasets import load_dataset
dataset = load_dataset(path='squad', split='train')
2.2 从本地加载数据集
2.2.1 加载指定格式的文件
用path
参数指定数据集格式
- json格式,
path="json"
- csv格式,
path="csv"
- 纯文本格式,
path="text"
- dataframe格式,
path="panda"
- 图片,
path="imagefolder"
然后用data_files
指定文件名称,data_files可以是字符串,列表或者字典,data_dir
指定数据集目录。如下case
from datasets import load_dataset
dataset = load_dataset('csv', data_files='my_file.csv')
dataset = load_dataset('csv', data_files=['my_file_1.csv', 'my_file_2.csv', 'my_file_3.csv'])
dataset = load_dataset('csv', data_files={'train':['my_train_file_1.csv','my_train_file_2.csv'],'test': 'my_test_file.csv'})
2.2.2 加载图片
如下我们通过打开指定图片目录进行加载图片数据集
dataset = load_dataset(path="imagefolder",
data_dir="D:\Desktop\workspace\code\loaddataset\data\images")
print(dataset)
print(dataset["train"][0])
输出
DatasetDict({
train: Dataset({
features: ['image'],
num_rows: 2
})
})
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=800x320 at 0x1E6A636B520>}
图片文本对应
很多情况下加载图片并非只要图片,还会有对应的文本,比如在图片分类的时候,每张图片都对应一个类别。这种情况我们需要在图片所在文件夹中加入一个metadata.jsonl
的文件,来指定每个图片对应的类别,格式如下,注意file_name
字段必须要有,其他字段可自行命名
{
"file_name": "1.jpg",
"class": 1
}
{
"file_name": "2.png",
"class": 0
}
然后我们再来运行
dataset = load_dataset(path="imagefolder",
data_dir="D:\Desktop\workspace\code\loaddataset\data\images")
print(dataset)
print(dataset["train"][0])
输出如下
DatasetDict({
train: Dataset({
features: ['image', 'class'],
num_rows: 2
})
})
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=800x320 at 0x2912172B520>, 'class': 1}
2.2.3 自定义数据集加载脚本
一些情况下加载数据集的逻辑较为复杂,需要自定义加载方式。比如训练ControlNet时,输入有原始图片,边缘图,以及prompt,这时候我们就需要通过在图片所在的目录下写一个python脚本来处理数据加载方式。
如下所示,我们数据处理需要是,每条数据包括两张图片,一个文本。
- step1: 首先我们先创建一个json文件
train.jsonl
,把图片和文本对应起来,json文件的格式如下所示
{"text": "pale golden rod circle with old lace background", "image": "images/0.png", "conditioning_image": "conditioning_images/0.png"}
{"text": "light coral circle with white background", "image": "images/1.png", "conditioning_image": "conditioning_images/1.png"}
{"text": "aqua circle with light pink background", "image": "images/2.png", "conditioning_image": "conditioning_images/2.png"}
- step2:创建一个python脚本
fill50k.py
根据json文件中的对应关系加载图片,python脚本如下所示,这个脚本中定义一个Fill50k
类,并继承datasets.GeneratorBasedBuilder
,在类中重写_info(self):
_split_generators(self, dl_manager)
和_split_generators(self, dl_manager)
这三个方法
import pandas as pd
import datasets
import os
import logging
# 数据集路径设置
META_DATA_PATH = "D:\Desktop\workspace\code\loaddataset\\fill50k\\train.jsonl"
IMAGE_DIR = "D:\Desktop\workspace\code\loaddataset\\fill50k"
CONDITION_IMAGE_DIR = "D:\Desktop\workspace\code\loaddataset\\fill50k"
# 定义数据集中有哪些特征,及其类型
_FEATURES = datasets.Features(
{
"image": datasets.Image(),
"conditioning_image": datasets.Image(),
"text": datasets.Value("string"),
},
)
# 定义数据集
class Fill50k(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [datasets.BuilderConfig(name="default", version=datasets.Version("0.0.2"))]
DEFAULT_CONFIG_NAME = "default"
def _info(self):
return datasets.DatasetInfo(
description="None",
features=_FEATURES,
supervised_keys=None,
homepage="None",
license="None",
citation="None",
)
def _split_generators(self, dl_manager):
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"metadata_path": META_DATA_PATH,
"images_dir": IMAGE_DIR,
"conditioning_images_dir": CONDITION_IMAGE_DIR,
},
),
]
def _generate_examples(self, metadata_path, images_dir, conditioning_images_dir):
metadata = pd.read_json(metadata_path, lines=True)
for _, row in metadata.iterrows():
text = row["text"]
image_path = row["image"]
image_path = os.path.join(images_dir, image_path)
# 打开文件错误时直接跳过
try:
image = open(image_path, "rb").read()
except Exception as e:
logging.error(e)
continue
conditioning_image_path = os.path.join(
conditioning_images_dir, row["conditioning_image"]
)
# 打开文件错误直接跳过
try:
conditioning_image = open(conditioning_image_path, "rb").read()
except Exception as e:
logging.error(e)
continue
yield row["image"], {
"text": text,
"image": {
"path": image_path,
"bytes": image,
},
"conditioning_image": {
"path": conditioning_image_path,
"bytes": conditioning_image,
},
}
- step3: 通过
load_dataset
加载数据集
dataset = load_dataset(path="D:\Desktop\workspace\code\loaddataset\\fill50k\\fill50k.py",
cache_dir="D:\Desktop\workspace\code\loaddataset\\fill50k\cache")
print(dataset)
print(dataset["train"][0])
输出结果如下
DatasetDict({
train: Dataset({
features: ['image', 'conditioning_image', 'text'],
num_rows: 50000
})
})
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x1AEA2FF9040>, 'conditioning_image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x1AEA2FE2640>, 'text': 'pale golden rod circle with old lace background'}
更多AI算法,请关注微信公众号 funNLPer