Unless spark.io.encryption.enabled=true is set, both the parallelize() and broadcast() methods exchange data with the JVM through local files. Sending a large dataset to the JVM over py4j is slow, so PySpark uses a local file by default. If security is a concern (the dataset is, after all, pickled and written to local disk), then spark.io.encryption.enabled=true is worth considering, although it will certainly add some CPU overhead.
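A minimal sketch of what enabling the option can look like when the SparkContext is created (the app name and sample data here are illustrative; only the spark.io.encryption.enabled key comes from the discussion above, and depending on the deployment Spark may also expect authentication/RPC encryption to be configured):

from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setAppName("io-encryption-demo")          # illustrative app name
        # Encrypt the data exchanged between the Python driver and the JVM,
        # so no plain pickled data ends up on local disk (at extra CPU cost).
        .set("spark.io.encryption.enabled", "true"))
sc = SparkContext(conf=conf)

# With encryption enabled, parallelize() streams the data to a JVM-side
# socket server instead of writing a temporary file (see the code below).
rdd = sc.parallelize(range(1000), 4)
print(rdd.sum())
sc.stop()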
Note that these temporary files live under spark.local.dir, inside the corresponding spark-* directory, in a subdirectory whose name starts with pyspark-. That subdirectory is a temporary directory created by calling a Java-side method.
A global search of the PySpark code shows that only parallelize() and broadcast() write to this directory.
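A small inspection sketch, assuming a live SparkContext named sc and no encryption (note that _temp_dir is a private attribute of PySpark's SparkContext, so this is for verification only, not production use):

import os

# The pyspark-<uuid> temp directory created through the JVM at context start-up.
print(sc._temp_dir)                     # e.g. .../spark-<uuid>/pyspark-<uuid>

# parallelize() writes its pickled data to a temp file here, but unlinks it
# right after the JVM has read it, so the directory is normally empty again.
sc.parallelize(range(10), 2).count()
print(os.listdir(sc._temp_dir))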
In practice, users have reported that the local file still cannot be removed after destroy() is called on a broadcast variable. This has yet to be confirmed; judging from the PySpark source, destroy() does call Python's os.unlink(). In short, PySpark needs to be used with care here.
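A quick way to check that behaviour is sketched below; it relies on the private _path attribute visible in the broadcast.py excerpt further down and assumes a live SparkContext named sc:

import os

bc = sc.broadcast({"answer": 42})
path = bc._path                  # pickled value written under sc._temp_dir
print(os.path.exists(path))      # True while the broadcast variable is alive

bc.destroy()                     # per the source, this also calls os.unlink(self._path)
print(os.path.exists(path))      # expected False if the unlink succeeded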
An excerpt from context.py:
    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD. Using xrange
        is recommended if the input represents a range for performance.

        >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect()
        [[0], [2], [3], [4], [6]]
        >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect()
        [[], [0], [], [2], [4]]
        """
        numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism
        if isinstance(c, xrange):
            size = len(c)
            if size == 0:
                return self.parallelize([], numSlices)
            step = c[1] - c[0] if size > 1 else 1
            start0 = c[0]

            def getStart(split):
                return start0 + int((split * size / numSlices)) * step

            def f(split, iterator):
                # it's an empty iterator here but we need this line for triggering the
                # logic of signal handling in FramedSerializer.load_stream, for instance,
                # SpecialLengths.END_OF_DATA_SECTION in _read_with_length. Since
                # FramedSerializer.load_stream produces a generator, the control should
                # at least be in that function once. Here we do it by explicitly converting
                # the empty iterator to a list, thus make sure worker reuse takes effect.
                # See more details in SPARK-26549.
                assert len(list(iterator)) == 0
                return xrange(getStart(split), getStart(split + 1), step)

            return self.parallelize([], numSlices).mapPartitionsWithIndex(f)

        # Make sure we distribute data evenly if it's smaller than self.batchSize
        if "__len__" not in dir(c):
            c = list(c)    # Make it a list so we can compute its length
        batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
        serializer = BatchedSerializer(self._unbatched_serializer, batchSize)

        def reader_func(temp_filename):
            return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)

        def createRDDServer():
            return self._jvm.PythonParallelizeServer(self._jsc.sc(), numSlices)

        jrdd = self._serialize_to_jvm(c, serializer, reader_func, createRDDServer)
        return RDD(jrdd, self, serializer)

    def _serialize_to_jvm(self, data, serializer, reader_func, createRDDServer):
        """
        Using py4j to send a large dataset to the jvm is really slow, so we use either a file
        or a socket if we have encryption enabled.

        :param data:
        :param serializer:
        :param reader_func: A function which takes a filename and reads in the data in the jvm and
            returns a JavaRDD. Only used when encryption is disabled.
        :param createRDDServer: A function which creates a PythonRDDServer in the jvm to
            accept the serialized data, for use when encryption is enabled.
        :return:
        """
        if self._encryption_enabled:
            # with encryption, we open a server in java and send the data directly
            server = createRDDServer()
            (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
            chunked_out = ChunkedStream(sock_file, 8192)
            serializer.dump_stream(data, chunked_out)
            chunked_out.close()
            # this call will block until the server has read all the data and processed it (or
            # throws an exception)
            r = server.getResult()
            return r
        else:
            # without encryption, we serialize to a file, and we read the file in java and
            # parallelize from there.
            tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
            try:
                try:
                    serializer.dump_stream(data, tempFile)
                finally:
                    tempFile.close()
                return reader_func(tempFile.name)
            finally:
                # we eagerily reads the file so we can delete right after.
                os.unlink(tempFile.name)
An excerpt from broadcast.py:
class Broadcast(object):

    def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
                 sock_file=None):
        """
        Should not be called directly by users -- use :meth:`SparkContext.broadcast`
        instead.
        """
        if sc is not None:
            # we're on the driver. We want the pickled data to end up in a file (maybe encrypted)
            f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
            self._path = f.name
            self._sc = sc
            self._python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
            if sc._encryption_enabled:
                # with encryption, we ask the jvm to do the encryption for us, we send it data
                # over a socket
                port, auth_secret = self._python_broadcast.setupEncryptionServer()
                (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
                broadcast_out = ChunkedStream(encryption_sock_file, 8192)
            else:
                # no encryption, we can just write pickled data directly to the file from python
                broadcast_out = f
            self.dump(value, broadcast_out)
            if sc._encryption_enabled:
                self._python_broadcast.waitTillDataReceived()
            self._jbroadcast = sc._jsc.broadcast(self._python_broadcast)
            self._pickle_registry = pickle_registry
        else:
            # we're on an executor
            self._jbroadcast = None
            self._sc = None
            self._python_broadcast = None
            if sock_file is not None:
                # the jvm is doing decryption for us. Read the value
                # immediately from the sock_file
                self._value = self.load(sock_file)
            else:
                # the jvm just dumps the pickled data in path -- we'll unpickle lazily when
                # the value is requested
                assert(path is not None)
                self._path = path