Output:
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel).
18/04/25 18:27:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Original Data
+----+----------+--------+
|col1| date|quantity|
+----+----------+--------+
| b|2016-09-10| 1|
| a|2016-09-11| 2|
| b|2016-09-14| 6|
| a|2016-09-16| 1|
| b|2016-09-17| 4|
| a|2016-09-20| 2|
+----+----------+--------+
After mark interval date
+----+----------+--------+----+------------------------------------------------+
|col1|date |quantity|diff|next_dates |
+----+----------+--------+----+------------------------------------------------+
|a |2016-09-11|2 |5 |[2016-09-12, 2016-09-13, 2016-09-14, 2016-09-15]|
|a |2016-09-16|1 |4 |[2016-09-17, 2016-09-18, 2016-09-19] |
|b |2016-09-10|1 |4 |[2016-09-11, 2016-09-12, 2016-09-13] |
|b |2016-09-14|6 |3 |[2016-09-15, 2016-09-16] |
+----+----------+--------+----+------------------------------------------------+
convert every list to rows
+----+----------+--------+----+------------------------------------------------+
|col1|date |quantity|diff|next_dates |
+----+----------+--------+----+------------------------------------------------+
|a |2016-09-12|0 |5 |[2016-09-12, 2016-09-13, 2016-09-14, 2016-09-15]|
|a |2016-09-13|0 |5 |[2016-09-12, 2016-09-13, 2016-09-14, 2016-09-15]|
|a |2016-09-14|0 |5 |[2016-09-12, 2016-09-13, 2016-09-14, 2016-09-15]|
|a |2016-09-15|0 |5 |[2016-09-12, 2016-09-13, 2016-09-14, 2016-09-15]|
|a |2016-09-17|0 |4 |[2016-09-17, 2016-09-18, 2016-09-19] |
|a |2016-09-18|0 |4 |[2016-09-17, 2016-09-18, 2016-09-19] |
|a |2016-09-19|0 |4 |[2016-09-17, 2016-09-18, 2016-09-19] |
|b |2016-09-11|0 |4 |[2016-09-11, 2016-09-12, 2016-09-13] |
|b |2016-09-12|0 |4 |[2016-09-11, 2016-09-12, 2016-09-13] |
|b |2016-09-13|0 |4 |[2016-09-11, 2016-09-12, 2016-09-13] |
|b |2016-09-15|0 |3 |[2016-09-15, 2016-09-16] |
|b |2016-09-16|0 |3 |[2016-09-15, 2016-09-16] |
+----+----------+--------+----+------------------------------------------------+
union missing date into original data
+----+----------+--------+
|col1| date|quantity|
+----+----------+--------+
| a|2016-09-11| 2|
| a|2016-09-12| 0|
| a|2016-09-13| 0|
| a|2016-09-14| 0|
| a|2016-09-15| 0|
| a|2016-09-16| 1|
| a|2016-09-17| 0|
| a|2016-09-18| 0|
| a|2016-09-19| 0|
| a|2016-09-20| 2|
| b|2016-09-10| 1|
| b|2016-09-11| 0|
| b|2016-09-12| 0|
| b|2016-09-13| 0|
| b|2016-09-14| 6|
| b|2016-09-15| 0|
| b|2016-09-16| 0|
| b|2016-09-17| 4|
+----+----------+--------+
Compare to original data
+----+----------+--------+
|col1| date|quantity|
+----+----------+--------+
| a|2016-09-11| 2|
| a|2016-09-16| 1|
| a|2016-09-20| 2|
| b|2016-09-10| 1|
| b|2016-09-14| 6|
| b|2016-09-17| 4|
+----+----------+--------+
Process finished with exit code 0
Code:
import datetime

import pandas as pd
from pyspark import SparkConf
from pyspark.sql import SparkSession, Window, functions
from pyspark.sql.functions import datediff, lead, to_date, udf
from pyspark.sql.types import ArrayType, DateType

if __name__ == '__main__':
    conf = SparkConf()
    sparkSession = SparkSession.builder.appName("Test PredictionTool").config(conf=conf).getOrCreate()
    sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")
    dfp = pd.DataFrame({'date': ['2016-09-10 00:00:00',
                                 '2016-09-11 00:00:00',
                                 '2016-09-14 00:00:00',
                                 '2016-09-16 00:00:00',
                                 '2016-09-17 00:00:00',
                                 '2016-09-20 00:00:00'],
                        'quantity': [1, 2, 6, 1, 4, 2],
                        'col1': ['b', 'a', 'b', 'a', 'b', 'a']})
    df = sparkSession.createDataFrame(dfp)
    df = df.withColumn('date', to_date('date'))
    df = df.withColumn('quantity', df['quantity'].cast('int'))
    print("Original Data")
    df.show()

    def udf_s_e(start, excludedDiff):
        # type: (datetime.date, int) -> list
        # Enumerate the dates strictly between `start` and `start + excludedDiff` days.
        return [start + datetime.timedelta(days=i + 1) for i in range(excludedDiff - 1)]

    fill_dates = udf(udf_s_e, ArrayType(DateType()))
    # Partition by col1 so that lead() never pairs the last date of one group
    # with the first date of the next group.
    w = Window.partitionBy("col1").orderBy("date")
    tempDf = df.withColumn("diff", datediff(lead("date", 1).over(w), "date"))\
        .filter("diff > 1")
    tempDf = tempDf.withColumn("next_dates", fill_dates("date", "diff"))
    print("After mark interval date")
    tempDf.show(truncate=False)
    tempDf = tempDf\
        .withColumn("quantity", functions.lit(0))\
        .withColumn("date", functions.explode("next_dates"))
    print("convert every list to rows")
    tempDf.show(truncate=False)
    result = df.union(tempDf.select("col1", "date", "quantity")).orderBy("date")
    print("union missing date into original data")
    result.sort(['col1', 'date']).show()
    print("Compare to original data")
    df.sort(['col1', 'date']).show()
    sparkSession.stop()
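The date-enumeration logic inside the UDF is plain Python, so it can be sanity-checked without a Spark session. A minimal sketch mirroring `udf_s_e` (the name `fill_between` is just for this illustration):

```python
import datetime


def fill_between(start, excluded_diff):
    """Return the dates strictly between `start` and `start + excluded_diff` days."""
    return [start + datetime.timedelta(days=i + 1) for i in range(excluded_diff - 1)]


# The gap between 2016-09-11 and 2016-09-16 (diff = 5) yields four filler dates:
# 2016-09-12, 2016-09-13, 2016-09-14, 2016-09-15.
print(fill_between(datetime.date(2016, 9, 11), 5))
```

This matches the `next_dates` column in the "After mark interval date" table above.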
Other tips:
In a function definition, a parameter prefixed with `*` collects the extra positional arguments of a call into a tuple, while a parameter prefixed with `**` collects the extra keyword arguments into a dict.
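A quick illustration of both forms (the function name `collect` is arbitrary):

```python
def collect(*args, **kwargs):
    # *args gathers extra positional arguments into a tuple;
    # **kwargs gathers extra keyword arguments into a dict.
    return args, kwargs


positional, keyword = collect(1, 2, x=3)
print(positional)  # (1, 2)
print(keyword)     # {'x': 3}
```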