Intro
When PySpark writes to a database in batches, the data should be written partition by partition so that each partition opens only one connection; this speeds up the write significantly. foreachPartition is the obvious choice for per-partition writes, but PySpark cannot use it the way Scala does:
df.rdd.foreachPartition(x=>{
...
})
It only supports passing a function by name:
df.rdd.foreachPartition(your_function)
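For the database scenario described above, the function you pass in typically opens one connection at the start of the partition, writes every row in the partition through it, and closes it at the end. A minimal sketch, assuming a hypothetical get_connection() helper that returns a DB-API connection (the table name and INSERT statement are placeholders too):
def write_partition(rows):
    conn = get_connection()  # hypothetical helper; one connection per partition
    cursor = conn.cursor()
    # rows is an iterator over the Row objects of this partition
    cursor.executemany("INSERT INTO target_table (x) VALUES (%s)",
                       [(row["x"],) for row in rows])
    conn.commit()
    conn.close()

df.rdd.foreachPartition(write_partition)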
Let's look at the source code:
def foreachPartition(self, f):
    """
    Applies a function to each partition of this RDD.

    >>> def f(iterator):
    ...     for x in iterator:
    ...         print(x)
    >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f)
    """
    def func(it):
        r = f(it)
        try:
            return iter(r)
        except TypeError:
            return iter([])
    self.mapPartitions(func).count()  # Force evaluation
If your_function needs additional arguments, they have to be passed in through a partial function. A simple (if not entirely rigorous) way to understand it: the partial function binds the extra arguments and produces a new function for foreachPartition to call.
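To see the binding in isolation, here is a minimal sketch outside of Spark (greet is a throwaway name used only for illustration):
import functools

def greet(part, id):
    for x in part:
        print(f"id={id}, x={x}")

bound = functools.partial(greet, id=7)  # new callable with id already bound
bound([1, 2, 3])                        # same as calling greet([1, 2, 3], id=7)
foreachPartition only ever sees the bound single-argument callable. With that in mind, let's look at the code.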
Code
import pandas as pd
import functools
from pyspark.sql import SparkSession

df = pd.DataFrame({"x": list(range(10))})
spark = SparkSession.builder.appName("pyspark").getOrCreate()
spark_df = spark.createDataFrame(df)
spark_df.show()
+---+
| x|
+---+
| 0|
| 1|
| 2|
| 3|
| 4|
| 5|
| 6|
| 7|
| 8|
| 9|
+---+
def test_f(part, id):
    for row in part:
        print(f"id={id},x={row['x']}")

spark_df.repartition(2).rdd.foreachPartition(functools.partial(test_f, id=0))
This way the id argument is passed through to every partition.
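Putting the two pieces together for the original database use case, the connection settings can be bound the same way, so each partition still opens exactly one connection. This is a sketch under the same assumptions as before: get_connection, db_conf, and the target table are placeholders rather than a specific driver's API.
def write_partition(rows, conf):
    conn = get_connection(**conf)  # hypothetical helper; one connection per partition
    cursor = conn.cursor()
    cursor.executemany("INSERT INTO target_table (x) VALUES (%s)",
                       [(row["x"],) for row in rows])
    conn.commit()
    conn.close()

db_conf = {"host": "localhost", "db": "test"}  # placeholder settings

spark_df.rdd.foreachPartition(functools.partial(write_partition, conf=db_conf))
Repartitioning first (as in the example above) controls how many connections are opened in total, since there is one per partition.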
2022-04-24, Jiulonghu, Jiangning District, Nanjing