In natural language processing (NLP), Hugging Face is an indispensable library, and Spark is the standard tool for big data processing. Combining their strengths enables efficient NLP over large datasets. Below are two ways to combine Hugging Face and Spark, explored against Spark & PySpark 3.3.1.
Method 1: Upgrade Spark to 3.4 or Later
If you are willing to upgrade Spark to 3.4 or later, combining Hugging Face and Spark becomes very convenient: Spark 3.4 and above support loading a model and running prediction out of the box.
Key steps:
- Model loading strategy: the model must be loaded separately on each worker so that it is available throughout the distributed environment (see the sketch after this list).
- Folder management: before loading a Hugging Face pretrained model, be sure to delete any previous model folder to prevent load failures.
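On Spark 3.4+, this native support is exposed as pyspark.ml.functions.predict_batch_udf. The sketch below shows the idea with a Hugging Face transformers pipeline; the model task, column name, and batch size are illustrative assumptions, not from the original text:

import numpy as np
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import StringType

def make_predict_fn():
    # Called once per Python worker, so each worker loads its own copy of the model.
    from transformers import pipeline
    classifier = pipeline("sentiment-analysis")  # model choice is illustrative

    def predict(inputs: np.ndarray) -> np.ndarray:
        results = classifier(inputs.tolist())
        return np.array([r["label"] for r in results])

    return predict

classify = predict_batch_udf(make_predict_fn, return_type=StringType(), batch_size=32)
df = df.withColumn("label", classify("text"))  # assumes df has a string column "text"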
Method 2: A Manually Wrapped Interface on Spark 3.3.1
If you prefer to stay on Spark 3.3.1, you can combine Hugging Face and Spark by wrapping the interface yourself. The code and key notes follow.
Wrapping a distributed model cache
To manage model loading and caching efficiently, we extract a distributed model-cache mechanism from the Spark 3.4 source code:
from collections import OrderedDict
from threading import Lock
from typing import Callable, Optional
from uuid import UUID
class ModelCache:
    """Cache for model prediction functions on executors.

    This requires the `spark.python.worker.reuse` configuration to be set to `true`, otherwise a
    new python worker (with an empty cache) will be started for every task.

    If a python worker is idle for more than one minute (per the IDLE_WORKER_TIMEOUT_NS setting in
    PythonWorkerFactory.scala), it will be killed, effectively clearing the cache until a new
    python worker is started.

    Caching large models can lead to out-of-memory conditions, which may require adjusting spark
    memory configurations, e.g. `spark.executor.memoryOverhead`.
    """

    _models: OrderedDict = OrderedDict()
    _capacity: int = 3  # "reasonable" default size for now, make configurable later, if needed
    _lock: Lock = Lock()

    @staticmethod
    def add(uuid: UUID, predict_fn: Callable) -> None:
        # Insert (or refresh) an entry; evict the least-recently-used one when over capacity.
        with ModelCache._lock:
            ModelCache._models[uuid] = predict_fn
            ModelCache._models.move_to_end(uuid)
            if len(ModelCache._models) > ModelCache._capacity:
                ModelCache._models.popitem(last=False)

    @staticmethod
    def get(uuid: UUID) -> Optional[Callable]:
        # Look up an entry and mark it as most-recently-used.
        with ModelCache._lock:
            predict_fn = ModelCache._models.get(uuid)
            if predict_fn:
                ModelCache._models.move_to_end(uuid)
            return predict_fn
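A hypothetical usage sketch for this cache: the UUID is generated once on the driver and captured in the UDF closure, so every task that runs on a reused Python worker finds the same cache entry and skips reloading the model. The pipeline call and helper name are illustrative assumptions:

import uuid

model_uuid = uuid.uuid4()  # created on the driver, captured by the executor-side closure

def get_predict_fn():
    # Executor side: reuse the cached predict function if this worker already built one.
    predict_fn = ModelCache.get(model_uuid)
    if predict_fn is None:
        from transformers import pipeline
        classifier = pipeline("sentiment-analysis")

        def predict_fn(texts):
            return [r["label"] for r in classifier(list(texts))]

        ModelCache.add(model_uuid, predict_fn)
    return predict_fn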
Wrapping the processing logic
from __future__ import annotations

import argparse
import inspect
import logging
import os
import random
import uuid
from datetime import datetime, timedelta
from io import BytesIO
from typing import Any, Callable, Iterator, List, Mapping, Optional, TYPE_CHECKING, Tuple, Union

import numpy as np
import pandas as pd
import requests as req
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, column, encode, pandas_udf
from pyspark.sql.types import (
    ArrayType,
    ByteType,
    DataType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    ShortType,
    StringType,
    StructType,
)
supported_scalar_types = (
    ByteType,
    ShortType,
    IntegerType,
    LongType,
    FloatType,
    DoubleType,
    StringType,
)

PredictBatchFunction = Callable[
    [np.ndarray], Union[np.ndarray, Mapping[str, np.ndarray], List[Mapping[str, np.dtype]]]
]
# Path to the hadoop CLI, resolved from HADOOP_COMMON_HOME, for shell-level HDFS operations.
hadoop = os.path.join(os.environ['HADOOP_COMMON_HOME'], 'bin/hadoop')
def init_spark():
    """Create a SparkSession with Hive support, Kryo serialization, and shuffle spilling enabled."""
    spark = SparkSession.builder \
        .config("spark.sql.caseSensitive", "false") \
        .config("spark.shuffle.spill", "true") \
        .config("spark.shuffle.spill.compress", "true") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("metastore.catalog.default", "hive") \
        .config("spark.sql.hive.convertMetastoreOrc", "true") \
        .config("spark.kryoserializer.buffer.max", "1024m") \
        .config("spark.kryoserializer.buffer", "64m") \
        .config("spark.driver.maxResultSize", "4g") \
        .config("spark.sql.broadcastTimeout", "36000") \
        .enableHiveSupport() \
        .getOrCreate()
    return spark
def system_command(command):
    """Run a shell command via os.system and log success or failure."""
    code = os.system(command)
    if code != 0:
        logging.error(f"Command ({command}) failed to execute.")
    else:
        logging.info(f"Command ({command}) executed successfully.")
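# Hypothetical usage of system_command together with the hadoop binary above:
# clear a stale model directory before re-downloading, per the folder-management
# note in method one (the HDFS path is illustrative):
#   system_command(f"{hadoop} fs -rm -r -f /tmp/hf_models")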
def parse_args(