高效数据加载
在处理庞大的文本数据集时,高效的读取和预处理变得至关重要。基于 PyTorch 的 torch.utils.data.Dataset
类实现的JsonLDataset,用来高效地处理存储为 JSONL (JSON Lines) 格式的大型数据集。
功能概述
实现的JsonlDataset
类提供了以下关键功能:
- 内存映射文件读取:使用
mmap
模块,该类映射大文件到内存,使得访问任何部分的成本接近恒定,而不需要将整个文件加载到内存。 - 元数据支持:该类读取一个与数据集同名的
.meta
文件,该文件包含了每个数据项的起始偏移量和长度,进一步提高访问特定数据项的效率。 - 数据项筛选:通过
min_length
参数,类在初始化时可以过滤掉长度低于指定阈值的数据项。 - 线程安全:通过使用
threading.local
,类保证了在多线程环境下的线程安全,使得每个线程都有自己的文件句柄和内存映射。
类实现细节
构造函数 __init__
在 JsonlDataset
类的构造函数中,主要进行以下操作:
def __init__(self, path: str, dataset_type_id: int = 0, min_length=50):
self.path = path
self.threadlocal = threading.local()
resolved_path = Path(path).resolve()
self.resolved_path = resolved_path
self.meta = Path(f"{resolved_path}.meta")
self.type_id = dataset_type_id
assert os.path.exists(self.meta), f"The cache file:{self.meta} is not found for file:{self.path}"
try:
with open(self.meta, "rb") as f:
meta = np.load(f)
except Exception as e:
print(f"Cannot load file {self.meta}...")
raise e
self.offsets = meta[:, 0]
self.lengths = meta[:, -1]
if min_length > 0:
mask = self.lengths >= min_length
self.offsets = self.offsets[mask]
self.lengths = self.lengths[mask]
此段代码负责解析 .meta
文件并根据 min_length
参数筛选数据项。
方法 __getitem__
__getitem__
方法是实现 torch.utils.data.Dataset
接口的关键部分,用于获取数据集中的单个项:
def __getitem__(self, idx):
f = self._get_mmap()
position = self.offsets[idx]
f.seek(position)
item = f.readline().decode("utf-8")
try:
item = json.loads(item)
item["length"] = len(item["tokens"])
item["type_id"] = self.type_id
except Exception as err:
raise json.decoder.JSONDecodeError(
doc=self.path,
pos=position,
msg=(
f"Error while loading JSONL line in file {self.path} at byte "
f"{position}. Contents of line:\n{item}\n{err}"
),
)
return item
该方法使用内存映射文件快速定位和读取所需的数据行,并处理异常情况。
线程安全和资源管理
通过 _get_mmap
方法确保每个线程有自己的文件句柄和内存映射对象,同时在 __del__
方法中关闭并清理资源。