from pyspark import SparkContext, SparkConf
# Create the SparkConf and SparkContext objects
conf = SparkConf().setAppName("MyApp")
sc = SparkContext(conf=conf)
# Load the data as an RDD
data = sc.textFile("file:///home/ubuntu/Desktop/products.txt")
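# Assumed layout of products.txt (hypothetical, inferred from the field
# indices used below): id,name,price,category — e.g. "1,apple,3.5,fruit"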
# Inspect the first 10 records in the RDD
for line in data.take(10):
    print(line)
header = data.first()  # Grab the header row
filtered_data = data.filter(lambda line: line != header)  # Drop the header, keep the remaining records
for line in filtered_data.take(10):
    print(line)
# Filter out the header and convert each record to a (category, price) pair.
# Reuse filtered_data so only the exact header row is dropped (filtering on
# "id" not in line would also drop any data row containing "id"), and split
# each line only once.
categories = filtered_data.map(lambda line: line.split(",")) \
    .map(lambda fields: (fields[3], float(fields[2])))
# Compute the average price for each fruit category
avg_prices = categories.combineByKey(
    lambda x: (x, 1),                              # createCombiner: seed (sum, count) from the first value
    lambda acc, x: (acc[0] + x, acc[1] + 1),       # mergeValue: fold a value into a partition's accumulator
    lambda acc1, acc2: (acc1[0] + acc2[0], acc1[1] + acc2[1])  # mergeCombiners: merge accumulators across partitions
).mapValues(lambda x: x[0] / x[1])                 # average = sum / count
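# For comparison only (a sketch of an alternative, not part of the original
# pipeline): the same averages can be computed with mapValues + reduceByKey,
# at the cost of building an intermediate (sum, count) pair per record.
avg_prices_alt = categories.mapValues(lambda p: (p, 1)) \
    .reduceByKey(lambda a, b: (a[0] + b[0], a[1] + b[1])) \
    .mapValues(lambda s: s[0] / s[1])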
# Print the results
for category, avg_price in avg_prices.collect():
    print("{}: {}".format(category, avg_price))