from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col
from pyspark.sql.window import Window
# Create (or reuse) the SparkSession for this job.
spark = SparkSession.builder.getOrCreate()

# Example DataFrame: (value, group, sample_count).
# `sample_count` is the number of rows to randomly keep from each group
# (all rows of a group carry the same count).
data = [("1", "A", 1),
        ("2", "A", 1),
        ("3", "B", 2),
        ("4", "B", 2),
        ("5", "B", 2)]
df = spark.createDataFrame(data, ["value", "group", "sample_count"])
df.show()

# Window per group, ordered by rand(): row_number() over this window assigns
# each row a position in a random permutation of its group.
window_spec = Window.partitionBy("group").orderBy(F.rand())

# Keep the first `sample_count` rows of each group's random ordering,
# then drop the helper column.
sampled_df = (
    df.withColumn("row_num", F.row_number().over(window_spec))
      .filter(col("row_num") <= col("sample_count"))
      .drop("row_num")
)
# Cache so later actions reuse the same random sample instead of
# re-evaluating rand() and producing a different subset each time.
sampled_df.cache()
sampled_df.show()
# PySpark DataFrame: group-wise sampling where a column specifies the per-group sample size.
# (Adapted from a blog post; latest recommended article published 2024-07-12 09:04:25.)