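Creating the SparkSession object
In Spark 2.0 and later, spark-shell (and the pyspark shell) automatically creates a SparkSession object named spark at startup.
When you need to create one yourself, build it with SparkSession.builder; getOrCreate() returns the already-running session if there is one.
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").appName("spark").getOrCreate()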
Building a DataFrame by hand with Spark (Optional)
When the data set is small, you can build the DataFrame manually like this.
Build the data rows as Row objects (first 3 rows as an example):
from pyspark.sql import Row
Row(Date="2015-12-31", Code="'000422", Open="7.93", High="7.95", Low="7.76", Close="7.77", Pre_Close="7.93", Change="-0.020177", Turnover_Rate="0.015498", Volume="13915200", MA5="7.86", MA10="7.85")
Row(Date="2015-12-30", Code="'000422", Open="7.86", High="7.93", Low="7.75", Close="7.93", Pre_Close="7.84", Change="0.011480", Turnover_Rate="0.018662", Volume="16755900", MA5="7.90", MA10="7.85")
Row(Date="2015-12-29", Code="'000422", Open="7.72", High="7.85", Low="7.69", Close="7.84", Pre_Close="7.71", Change="0.016861", Turnover_Rate="0.015886", Volume="14263800", MA5="7.90", MA10="7.81")
Put the Row objects into a list (first 3 rows as an example):
Data_Rows = [
    Row(Date="2015-12-31", Code="'000422", Open="7.93", High="7.95", Low="7.76", Close="7.77", Pre_Close="7.93", Change="-0.020177", Turnover_Rate="0.015498", Volume="13915200", MA5="7.86", MA10="7.85"),
    Row(Date="2015-12-30", Code="'000422", Open="7.86", High="7.93", Low="7.75", Close="7.93", Pre_Close="7.84", Change="0.011480", Turnover_Rate="0.018662", Volume="16755900", MA5="7.90", MA10="7.85"),
    Row(Date="2015-12-29", Code="'000422", Open="7.72", High="7.85", Low="7.69", Close="7.84", Pre_Close="7.71", Change="0.016861", Turnover_Rate="0.015886", Volume="14263800", MA5="7.90", MA10="7.81")
]
Create the DataFrame (first 3 rows as an example):
SDF = spark.createDataFrame(Data_Rows)
Display the DataFrame (first 3 rows as an example):
print("[Message] Built Spark DataFrame:")
SDF.show()
Output:
+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+
| Date| Code|Open|High| Low|Close|Pre_Close| Change|Turnover_Rate| Volume| MA5|MA10|
+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+
|2015-12-31|'000422|7.93|7.95|7.76| 7.77| 7.93|-0.020177| 0.015498| 1.39152E7|7.86|7.85|
|2015-12-30|'000422|7.86|7.93|7.75| 7.93| 7.84| 0.01148| 0.018662| 1.67559E7|7.90|7.85|
|2015-12-29|'000422|7.72|7.85|7.69| 7.84| 7.71| 0.016861| 0.015886| 1.42638E7|7.90|7.81|
+----------+-------+----+----+----+-----+---------+---------+-------------+----------+----+----+
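Reading CSV data with Spark
Use the read attribute of the SparkSession (a DataFrameReader) to load the CSV data.
Each .option call sets one reader option, with the key on the left and the value on the right; for example, .option("header", "true") is analogous to {header = "true"} and tells Spark to take the column names from the first line of the file.
SDF = spark.read.option("header", "true").option("encoding", "utf-8").csv("file:///D:\\HBYH_000422_20150806_20151231.csv")
Display the DataFrame: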
print("[Message] Read CSV File: D:\\HBYH_000422_20150806_20151231.csv")
SDF.show()
Output:
[Message] Read CSV File: D:\HBYH_000422_20150806_20151231.csv
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
| Date| Code|Open|High| Low|Close|Pre_Close| Change|Turnover_Rate| Volume| MA5|MA10|
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
|2015-12-31|'000422|7.93|7.95|7.76| 7.77| 7.93|-0.020177| 0.015498|13915200|7.86|7.85|
|2015-12-30|'000422|7.86|7.93|7.75| 7.93| 7.84| 0.011480| 0.018662|16755900|7.90|7.85|
|2015-12-29|'000422|7.72|7.85|7.69| 7.84| 7.71| 0.016861| 0.015886|14263800|7.90|7.81|
|2015-12-28|'000422|8.03|8.08|7.70| 7.71| 8.03|-0.039851| 0.030821|27672800|7.91|7.78|
|2015-12-25|'000422|8.03|8.05|7.93| 8.03| 7.99| 0.005006| 0.021132|18974000|7.93|7.78|
|2015-12-24|'000422|7.93|8.16|7.87| 7.99| 7.92| 0.008838| 0.026487|23781900|7.85|7.72|
|2015-12-23|'000422|7.97|8.11|7.88| 7.92| 7.89| 0.003802| 0.042360|38033600|7.80|7.69|
|2015-12-22|'000422|7.86|7.93|7.76| 7.89| 7.83| 0.007663| 0.026929|24178700|7.73|7.68|
|2015-12-21|'000422|7.59|7.89|7.56| 7.83| 7.63| 0.026212| 0.030777|27633600|7.66|7.67|
|2015-12-18|'000422|7.71|7.74|7.57| 7.63| 7.74|-0.014212| 0.024764|22234900|7.62|7.71|
|2015-12-17|'000422|7.58|7.75|7.57| 7.74| 7.55| 0.025166| 0.028054|25188400|7.59|7.77|
|2015-12-16|'000422|7.57|7.62|7.53| 7.55| 7.55| 0.000000| 0.020718|18601600|7.58|7.79|
|2015-12-15|'000422|7.63|7.66|7.52| 7.55| 7.62|-0.009186| 0.025902|23256600|7.64|7.78|
|2015-12-14|'000422|7.40|7.64|7.36| 7.62| 7.51| 0.014647| 0.021005|18860100|7.68|7.76|
|2015-12-11|'000422|7.65|7.70|7.41| 7.51| 7.67|-0.020860| 0.020477|18385900|7.80|7.73|
|2015-12-10|'000422|7.78|7.87|7.65| 7.67| 7.83|-0.020434| 0.019972|17931900|7.95|7.69|
|2015-12-09|'000422|7.76|8.00|7.75| 7.83| 7.77| 0.007722| 0.025137|22569700|8.00|7.68|
|2015-12-08|'000422|8.08|8.18|7.76| 7.77| 8.24|-0.057039| 0.036696|32948200|7.92|7.66|
|2015-12-07|'000422|8.12|8.39|7.94| 8.24| 8.23| 0.001215| 0.064590|57993100|7.84|7.64|
|2015-12-04|'000422|7.85|8.48|7.80| 8.23| 7.92| 0.039141| 0.100106|89881900|7.65|7.58|
+----------+-------+----+----+----+-----+---------+---------+-------------+--------+----+----+
only showing top 20 rows
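The reader above was called without the inferSchema option, so Spark's CSV reader (whose default is inferSchema=false) returns every column as a string. The plain-Python sketch below uses the standard csv module, not Spark, to mimic what header=true means and why the values stay strings. SAMPLE is an assumption: a copy of the first lines of the file reconstructed from the output above, not the real file. Passing .option("inferSchema", "true") would instead ask Spark to scan the data and guess the column types.

```python
import csv
import io

# ASSUMPTION: the first lines of the CSV file, reconstructed from the output above.
SAMPLE = (
    "Date,Code,Open,High,Low,Close,Pre_Close,Change,Turnover_Rate,Volume,MA5,MA10\n"
    "2015-12-31,'000422,7.93,7.95,7.76,7.77,7.93,-0.020177,0.015498,13915200,7.86,7.85\n"
    "2015-12-30,'000422,7.86,7.93,7.75,7.93,7.84,0.011480,0.018662,16755900,7.90,7.85\n"
)

# Like header=true: the first line supplies the column names, the rest are data rows.
rows = list(csv.DictReader(io.StringIO(SAMPLE)))

# Without any schema inference, every value is still a plain string.
for row in rows:
    print(row["Date"], row["Open"], type(row["Open"]).__name__)
```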
Converting the data type of each DataFrame column in Spark
Because the CSV reader above was not asked to infer a schema, every column arrives as a string. Cast the columns to suitable types first to avoid type errors in later calculations.
from pyspark.sql.functions import col
from pyspark.sql.types import DateType, DoubleType, IntegerType
# Cast the Spark DataFrame columns to their target data types.
SDF = SDF.withColumn("Date", col("Date").cast(DateType()))
SDF = SDF.withColumn("Open", col("Open").cast(DoubleType()))
SDF = SDF.withColumn("High", col("High").cast(DoubleType()))
SDF = SDF.withColumn("Low", col("Low").cast(DoubleType()))
SDF = SDF.withColumn("Close", col("Close").cast(DoubleType()))
SDF = SDF.withColumn("Pre_Close", col("Pre_Close").cast(DoubleType()))
SDF = SDF.withColumn("Change", col("Change").cast(DoubleType()))
SDF = SDF.withColumn("Turnover_Rate", col("Turnover_Rate").cast(DoubleType()))
SDF = SDF.withColumn("Volume", col("Volume").cast(IntegerType()))
SDF = SDF.withColumn("MA5", col("MA5").cast(DoubleType()))
SDF = SDF.withColumn("MA10", col("MA10").cast(DoubleType()))
# Print the fields and data types of the Spark DataFrame.
print("[Message] Changed Spark DataFrame Data Type:")
SDF.printSchema()
Output:
[Message] Changed Spark DataFrame Data Type:
root
|-- Date: date (nullable = true)
|-- Code: string (nullable = true)
|-- Open: double (nullable = true)
|-- High: double (nullable = true)
|-- Low: double (nullable = true)
|-- Close: double (nullable = true)
|-- Pre_Close: double (nullable = true)
|-- Change: double (nullable = true)
|-- Turnover_Rate: double (nullable = true)
|-- Volume: integer (nullable = true)
|-- MA5: double (nullable = true)
|-- MA10: double (nullable = true)
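The casts above follow a fixed column-to-type plan. As a rough plain-Python analogy (not how Spark implements cast), the sketch below applies the same plan to one string row and returns None where a value cannot be parsed. This mirrors Spark returning null for an uncastable string when ANSI mode (spark.sql.ansi.enabled) is off. CAST_PLAN, cast_value and cast_row are hypothetical names, and the "n/a" value is made up to show the failure case.

```python
import datetime

# Hypothetical plan mirroring the withColumn/cast calls above; Code stays a string.
CAST_PLAN = {
    "Date": datetime.date.fromisoformat,
    "Open": float, "High": float, "Low": float, "Close": float,
    "Pre_Close": float, "Change": float, "Turnover_Rate": float,
    "Volume": int, "MA5": float, "MA10": float,
}

def cast_value(conv, raw):
    """Convert one string; return None (like a Spark null) if it cannot be parsed."""
    try:
        return conv(raw)
    except (TypeError, ValueError):
        return None

def cast_row(row: dict) -> dict:
    """Return a new dict with every planned column converted; other columns are kept as-is."""
    return {k: cast_value(CAST_PLAN[k], v) if k in CAST_PLAN else v for k, v in row.items()}

raw_row = {"Date": "2015-12-31", "Code": "'000422", "Open": "7.93", "Volume": "13915200", "MA5": "n/a"}
typed = cast_row(raw_row)
print(typed)
```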
Converting between Spark DataFrames and Spark RDDs to compute new columns
Write a function that adds a field and its value to a pyspark.sql Row object:
import pyspark
def MapFunc_SparkSQL_Row_Add_Field(SrcRow:pyspark.sql.types.Row, FldName:str, FldVal:object) -> pyspark.sql.types.Row:
    """
    [Require] import pyspark
    [Example] >>> SrcRow = Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Income=5432.10)
              >>> NewRow = MapFunc_SparkSQL_Row_Add_Field(SrcRow=SrcRow, FldName='Weekday', FldVal=SrcRow['Date'].weekday())
              >>> print(NewRow)
              Row(Date=datetime.date(2023, 12, 1), Clerk='Bob', Income=5432.1, Weekday=4)
    """
    # Convert Obj "pyspark.sql.types.Row" to Dict.
    # ----------------------------------------------
    Row_Dict = SrcRow.asDict()
    # Add a New Key in the Dictionary With the New Column Name and Value.
    # ----------------------------------------------
    Row_Dict[FldName] = FldVal
    # Convert Dict to Obj "pyspark.sql.types.Row" (a Row is immutable, so a new one is built).
    # ----------------------------------------------
    NewRow = pyspark.sql.types.Row(**Row_Dict)
    # ==============================================
    return NewRow
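pyspark.sql.Row is a subclass of tuple and therefore immutable, which is why the helper copies the row into a dict, adds the key and builds a new Row. The same copy-then-extend idea can be checked without Spark using plain dicts. In this sketch add_field is a hypothetical stand-in and the dict only imitates a Row; it reproduces the docstring example, since 1 December 2023 was a Friday and weekday() returns 4.

```python
import datetime

def add_field(src: dict, fld_name: str, fld_val: object) -> dict:
    """Plain-dict stand-in for MapFunc_SparkSQL_Row_Add_Field: copy the fields, then add one."""
    return {**src, fld_name: fld_val}  # builds a new dict; src itself is not modified

src = {"Date": datetime.date(2023, 12, 1), "Clerk": "Bob", "Income": 5432.1}
new = add_field(src, "Weekday", src["Date"].weekday())
print(new)
```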
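Write a function that decides whether the stock rose or fell:
def MapFunc_Stock_Judgement_Rise_or_Fall(ChgRate:float) -> int:
    # 1 = rise (a change of exactly 0.0 also counts as a rise), 0 = fall.
    if (ChgRate >= 0.0): return 1
    if (ChgRate < 0.0): return 0
    # ==============================================
    # End of Function.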
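A change of exactly 0.0 is labelled as a rise (1), which is what happens to the 2015-12-16 row (Change 0.000000) in this data. A compact equivalent using a conditional expression, checked against Change values copied from the CSV output above (rise_or_fall is a hypothetical name):

```python
def rise_or_fall(chg_rate: float) -> int:
    """1 = rise (a change of exactly 0.0 also counts as a rise), 0 = fall."""
    return 1 if chg_rate >= 0.0 else 0

# Change values copied from the CSV output above.
samples = {"2015-12-31": -0.020177, "2015-12-30": 0.011480, "2015-12-16": 0.0}
labels = {day: rise_or_fall(chg) for day, chg in samples.items()}
print(labels)
```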
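Write a function that describes the relationship between the short-term and long-term moving averages:
def MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA:float, Long_MA:float) -> int:
    # 1: short-term MA above long-term MA; 0: equal; -1: short-term MA below long-term MA.
    if (Short_MA > Long_MA): return 1
    if (Short_MA == Long_MA): return 0
    if (Short_MA < Long_MA): return -1
    # ==============================================
    # End of Function.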
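The three cases can also be written as the difference of two boolean comparisons, a common Python idiom for a three-way comparison that makes it easy to see that each outcome covers exactly one case. A sketch (ma_relationship is a hypothetical name), checked against MA5/MA10 pairs from the data; keep in mind that exact float equality is rare for computed averages, so the 0 case will seldom occur in practice.

```python
def ma_relationship(short_ma: float, long_ma: float) -> int:
    """Return 1 if short > long, 0 if equal, -1 if short < long."""
    # int(True) is 1 and int(False) is 0, so the difference is 1, 0 or -1.
    return int(short_ma > long_ma) - int(short_ma < long_ma)

pairs = [(7.86, 7.85), (7.66, 7.67), (7.85, 7.85)]  # first two from 2015-12-31 and 2015-12-21
print([ma_relationship(s, l) for s, l in pairs])  # [1, -1, 0]
```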
Write a function that returns the day of the week (in Chinese):
import datetime
def DtmFunc_Weekday_Return_String_CN(SrcDtm:datetime.datetime) -> str:
    """
    [Require] import datetime
    [Explain] In Python 3, the .weekday() method of a datetime.datetime (or datetime.date) object returns a number from 0 to 6 (0 = Monday, 6 = Sunday).
    """
    # Chinese names for Monday ... Sunday (周一 = Monday, 周日 = Sunday).
    Weekday_Str_Chinese:list = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
    # ==============================================
    return Weekday_Str_Chinese[SrcDtm.weekday()]
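Python's weekday() numbering (0 = Monday) differs from isoweekday() (1 = Monday ... 7 = Sunday) and from Spark SQL's dayofweek function (1 = Sunday ... 7 = Saturday). If the weekday were computed with the DataFrame API instead of in the RDD, a conversion like the one sketched below would be needed. The helper spark_dayofweek_to_python_weekday is a hypothetical name; the date check uses the first row of the data (2015-12-31, a Thursday).

```python
import datetime

WEEKDAY_CN = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]  # Monday ... Sunday

def spark_dayofweek_to_python_weekday(spark_dow: int) -> int:
    """Convert Spark SQL dayofweek (1 = Sunday ... 7 = Saturday) to Python weekday (0 = Monday ... 6 = Sunday)."""
    return (spark_dow + 5) % 7

d = datetime.date(2015, 12, 31)
print(d.weekday(), d.isoweekday(), WEEKDAY_CN[d.weekday()])  # 3 4 周四
```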
Convert the DataFrame into a Spark RDD and apply the custom functions:
import pprint
# Convert the DataFrame into an RDD in Spark.
CalcRDD = SDF.rdd
# --------------------------------------------------
# Custom function: extract the weekday index.
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(Idx)", X["Date"].weekday()))
# ..................................................
# Custom function: return the day of the week (in Chinese).
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Weekday(CN)", DtmFunc_Weekday_Return_String_CN(X["Date"])))
# ..................................................
# Custom function: decide whether the stock rose or fell.
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "Rise_Fall", MapFunc_Stock_Judgement_Rise_or_Fall(X["Change"])))
# ..................................................
# Custom function: relationship between the short-term and long-term moving averages.
CalcRDD = CalcRDD.map(lambda X: MapFunc_SparkSQL_Row_Add_Field(X, "MA_Relationship", MapFunc_Stock_Judgement_Short_MA_and_Long_MA_Relationship(Short_MA=X["MA5"], Long_MA=X["MA10"])))
# Show the first 5 rows of the computed RDD.
print("[Message] Calculated RDD Top 5 Rows:")
pprint.pprint(CalcRDD.take(5))
Output:
[Message] Calculated RDD Top 5 Rows:
[Row(Date=datetime.date(2015, 12, 31), Code="'000422", Open=7.93, High=7.95, Low=7.76, Close=7.77, Pre_Close=7.93, Change=-0.020177, Turnover_Rate=0.015498, Volume=13915200, MA5=7.86, MA10=7.85, Weekday(Idx)=3, Weekday(CN)='周四', Rise_Fall=0, MA_Relationship=1),
Row(Date=datetime.date(2015, 12, 30), Code="'000422", Open=7.86, High=7.93, Low=7.75, Close=7.93, Pre_Close=7.84, Change=0.01148, Turnover_Rate=0.018662, Volume=16755900, MA5=7.9, MA10=7.85, Weekday(Idx)=2, Weekday(CN)='周三', Rise_Fall=1, MA_Relationship=1),
Row(Date=datetime.date(2015, 12, 29), Code="'000422", Open=7.72, High=7.85, Low=7.69, Close=7.84, Pre_Close=7.71, Change=0.016861, Turnover_Rate=0.015886, Volume=14263800, MA5=7.9, MA10=7.81, Weekday(Idx)=1, Weekday(CN)='周二', Rise_Fall=1, MA_Relationship=1),
Row(Date=datetime.date(2015, 12, 28), Code="'000422", Open=8.03, High=8.08, Low=7.7, Close=7.71, Pre_Close=8.03, Change=-0.039851, Turnover_Rate=0.030821, Volume=27672800, MA5=7.91, MA10=7.78, Weekday(Idx)=0, Weekday(CN)='周一', Rise_Fall=0, MA_Relationship=1),
Row(Date=datetime.date(2015, 12, 25), Code="'000422", Open=8.03, High=8.05, Low=7.93, Close=8.03, Pre_Close=7.99, Change=0.005006, Turnover_Rate=0.021132, Volume=18974000, MA5=7.93, MA10=7.78, Weekday(Idx)=4, Weekday(CN)='周五', Rise_Fall=1, MA_Relationship=1)]
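Consecutive map calls are narrow transformations, which Spark pipelines within a single stage, so the four maps above do not mean four passes over the data. The enrichment can also be written as one function applied in a single map. In the plain-Python sketch below, enrich is a hypothetical name and plain dicts stand in for Row objects; it reproduces the derived fields for two of the rows shown above.

```python
import datetime

WEEKDAY_CN = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]  # Monday ... Sunday

def enrich(row: dict) -> dict:
    """Add all four derived fields in one step (plain-dict stand-in for a Row)."""
    out = dict(row)  # copy, so the input row is left unchanged
    out["Weekday(Idx)"] = row["Date"].weekday()
    out["Weekday(CN)"] = WEEKDAY_CN[row["Date"].weekday()]
    out["Rise_Fall"] = 1 if row["Change"] >= 0.0 else 0
    out["MA_Relationship"] = int(row["MA5"] > row["MA10"]) - int(row["MA5"] < row["MA10"])
    return out

rows = [
    {"Date": datetime.date(2015, 12, 31), "Change": -0.020177, "MA5": 7.86, "MA10": 7.85},
    {"Date": datetime.date(2015, 12, 21), "Change": 0.026212, "MA5": 7.66, "MA10": 7.67},
]
result = list(map(enrich, rows))
for r in result:
    print(r["Date"], r["Weekday(Idx)"], r["Weekday(CN)"], r["Rise_Fall"], r["MA_Relationship"])
```

In PySpark, the equivalent function would return pyspark.sql.Row(**out) and be passed to a single SDF.rdd.map(...) call.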
After the computation, convert the Spark RDD back into a Spark DataFrame:
# Convert the RDD into a DataFrame in Spark.
NewSDF = CalcRDD.toDF()
print("[Message] Convert RDD to DataFrame and Select Key Columns for Display:")
NewSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship"]).show()
Output:
[Message] Convert RDD to DataFrame and Select Key Columns for Display:
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+
| Date| Code|High| Low|Close| Change| MA5|MA10|Weekday(CN)|Rise_Fall|MA_Relationship|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+
|2015-12-31|'000422|7.95|7.76| 7.77|-0.020177|7.86|7.85| 周四| 0| 1|
|2015-12-30|'000422|7.93|7.75| 7.93| 0.01148| 7.9|7.85| 周三| 1| 1|
|2015-12-29|'000422|7.85|7.69| 7.84| 0.016861| 7.9|7.81| 周二| 1| 1|
|2015-12-28|'000422|8.08| 7.7| 7.71|-0.039851|7.91|7.78| 周一| 0| 1|
|2015-12-25|'000422|8.05|7.93| 8.03| 0.005006|7.93|7.78| 周五| 1| 1|
|2015-12-24|'000422|8.16|7.87| 7.99| 0.008838|7.85|7.72| 周四| 1| 1|
|2015-12-23|'000422|8.11|7.88| 7.92| 0.003802| 7.8|7.69| 周三| 1| 1|
|2015-12-22|'000422|7.93|7.76| 7.89| 0.007663|7.73|7.68| 周二| 1| 1|
|2015-12-21|'000422|7.89|7.56| 7.83| 0.026212|7.66|7.67| 周一| 1| -1|
|2015-12-18|'000422|7.74|7.57| 7.63|-0.014212|7.62|7.71| 周五| 0| -1|
|2015-12-17|'000422|7.75|7.57| 7.74| 0.025166|7.59|7.77| 周四| 1| -1|
|2015-12-16|'000422|7.62|7.53| 7.55| 0.0|7.58|7.79| 周三| 1| -1|
|2015-12-15|'000422|7.66|7.52| 7.55|-0.009186|7.64|7.78| 周二| 0| -1|
|2015-12-14|'000422|7.64|7.36| 7.62| 0.014647|7.68|7.76| 周一| 1| -1|
|2015-12-11|'000422| 7.7|7.41| 7.51| -0.02086| 7.8|7.73| 周五| 0| 1|
|2015-12-10|'000422|7.87|7.65| 7.67|-0.020434|7.95|7.69| 周四| 0| 1|
|2015-12-09|'000422| 8.0|7.75| 7.83| 0.007722| 8.0|7.68| 周三| 1| 1|
|2015-12-08|'000422|8.18|7.76| 7.77|-0.057039|7.92|7.66| 周二| 0| 1|
|2015-12-07|'000422|8.39|7.94| 8.24| 0.001215|7.84|7.64| 周一| 1| 1|
|2015-12-04|'000422|8.48| 7.8| 8.23| 0.039141|7.65|7.58| 周五| 1| 1|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+
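StringIndexer demonstration (demo only)
StringIndexer (string-to-index transformation) is an estimator that encodes a string column into a column of label indices. The indices lie in [0, numLabels) and, by default, are ordered by label frequency: the most frequent label gets index 0, the next most frequent gets 1, and so on.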
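The ordering rule can be sketched in plain Python. This assumes the default stringOrderType="frequencyDesc"; per the Spark documentation for recent versions, labels with equal frequency are further sorted alphabetically. string_index is a hypothetical helper, not Spark's implementation, and the indices are floats because StringIndexer outputs a double column (hence values such as 3.0 in StrIdx).

```python
from collections import Counter

def string_index(values):
    """Map each distinct label to an index: most frequent -> 0.0; ties broken alphabetically."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda s: (-counts[s], s))  # frequency desc, then alphabetical
    return {label: float(i) for i, label in enumerate(ordered)}

labels = ["b", "a", "b", "c", "a", "b", "d"]
mapping = string_index(labels)
print(mapping)  # {'b': 0.0, 'a': 1.0, 'c': 2.0, 'd': 3.0}
```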
from pyspark.ml.feature import StringIndexer
# Use StringIndexer to convert the Weekday(CN) column.
MyStringIndexer = StringIndexer(inputCol="Weekday(CN)", outputCol="StrIdx")
# Fit the indexer, then transform the data.
IndexedSDF = MyStringIndexer.fit(NewSDF).transform(NewSDF)
# Select the Date, Weekday(Idx), Weekday(CN) and StrIdx columns to show the effect of StringIndexer.
print("[Message] The Effect of StringIndexer:")
IndexedSDF.select(["Date", "Weekday(Idx)", "Weekday(CN)", "StrIdx"]).show()
Output:
[Message] The Effect of StringIndexer:
+----------+------------+-----------+------+
| Date|Weekday(Idx)|Weekday(CN)|StrIdx|
+----------+------------+-----------+------+
|2015-12-31| 3| 周四| 3.0|
|2015-12-30| 2| 周三| 1.0|
|2015-12-29| 1| 周二| 2.0|
|2015-12-28| 0| 周一| 0.0|
|2015-12-25| 4| 周五| 4.0|
|2015-12-24| 3| 周四| 3.0|
|2015-12-23| 2| 周三| 1.0|
|2015-12-22| 1| 周二| 2.0|
|2015-12-21| 0| 周一| 0.0|
|2015-12-18| 4| 周五| 4.0|
|2015-12-17| 3| 周四| 3.0|
|2015-12-16| 2| 周三| 1.0|
|2015-12-15| 1| 周二| 2.0|
|2015-12-14| 0| 周一| 0.0|
|2015-12-11| 4| 周五| 4.0|
|2015-12-10| 3| 周四| 3.0|
|2015-12-09| 2| 周三| 1.0|
|2015-12-08| 1| 周二| 2.0|
|2015-12-07| 0| 周一| 0.0|
|2015-12-04| 4| 周五| 4.0|
+----------+------------+-----------+------+
only showing top 20 rows
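Extracting the Label column and the Features column
Creating the Features column uses VectorAssembler, which merges several feature columns into a single feature-vector column.
Extract the Label column:
from pyspark.sql.functions import col
# Copy the Rise_Fall column as the Label column.
NewSDF = NewSDF.withColumn("Label", col("Rise_Fall"))
Create the Features column:
from pyspark.ml.feature import VectorAssembler
# VectorAssembler merges several feature columns into one feature vector.
FeaColsName:list = ["High", "Low", "Turnover_Rate", "Volume", "Weekday(Idx)", "MA_Relationship"]
MyAssembler = VectorAssembler(inputCols=FeaColsName, outputCol="Features")
# Transform the data (optional: if a Pipeline is used for training, skip this step, in which case the assembled data cannot be previewed here).
AssembledSDF = MyAssembler.transform(NewSDF)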
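Conceptually, VectorAssembler concatenates the values of inputCols, in the listed order, into one numeric vector per row. In the plain-Python sketch below, assemble is a hypothetical name and a list of floats stands in for a Spark vector; it reproduces the start of the first Features value in the preview, [7.95,7.76,0.0154...]. The real transformer also accepts vector-typed input columns and may store the result in sparse form.

```python
def assemble(row: dict, input_cols: list) -> list:
    """Concatenate the input columns, in order, into one list of floats (a dense vector)."""
    return [float(row[c]) for c in input_cols]

FEA_COLS = ["High", "Low", "Turnover_Rate", "Volume", "Weekday(Idx)", "MA_Relationship"]
# First data row (2015-12-31); columns not in FEA_COLS, such as Close, are ignored.
first_row = {"High": 7.95, "Low": 7.76, "Close": 7.77, "Turnover_Rate": 0.015498,
             "Volume": 13915200, "Weekday(Idx)": 3, "MA_Relationship": 1}
features = assemble(first_row, FEA_COLS)
print(features)  # [7.95, 7.76, 0.015498, 13915200.0, 3.0, 1.0]
```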
Preview the output:
print("[Message] Assembled Label and Features for RandomForestClassifier:")
AssembledSDF.select(["Date", "Code", "High", "Low", "Close", "Change", "MA5", "MA10", "Weekday(CN)", "Rise_Fall", "MA_Relationship", "Label", "Features"]).show()
Preview:
[Message] Assembled Label and Features for RandomForestClassifier:
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
| Date| Code|High| Low|Close| Change| MA5|MA10|Weekday(CN)|Rise_Fall|MA_Relationship|Label| Features|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
|2015-12-31|'000422|7.95|7.76| 7.77|-0.020177|7.86|7.85| 周四| 0| 1| 0|[7.95,7.76,0.0154...|
|2015-12-30|'000422|7.93|7.75| 7.93| 0.01148| 7.9|7.85| 周三| 1| 1| 1|[7.93,7.75,0.0186...|
|2015-12-29|'000422|7.85|7.69| 7.84| 0.016861| 7.9|7.81| 周二| 1| 1| 1|[7.85,7.69,0.0158...|
|2015-12-28|'000422|8.08| 7.7| 7.71|-0.039851|7.91|7.78| 周一| 0| 1| 0|[8.08,7.7,0.03082...|
|2015-12-25|'000422|8.05|7.93| 8.03| 0.005006|7.93|7.78| 周五| 1| 1| 1|[8.05,7.93,0.0211...|
|2015-12-24|'000422|8.16|7.87| 7.99| 0.008838|7.85|7.72| 周四| 1| 1| 1|[8.16,7.87,0.0264...|
|2015-12-23|'000422|8.11|7.88| 7.92| 0.003802| 7.8|7.69| 周三| 1| 1| 1|[8.11,7.88,0.0423...|
|2015-12-22|'000422|7.93|7.76| 7.89| 0.007663|7.73|7.68| 周二| 1| 1| 1|[7.93,7.76,0.0269...|
|2015-12-21|'000422|7.89|7.56| 7.83| 0.026212|7.66|7.67| 周一| 1| -1| 1|[7.89,7.56,0.0307...|
|2015-12-18|'000422|7.74|7.57| 7.63|-0.014212|7.62|7.71| 周五| 0| -1| 0|[7.74,7.57,0.0247...|
|2015-12-17|'000422|7.75|7.57| 7.74| 0.025166|7.59|7.77| 周四| 1| -1| 1|[7.75,7.57,0.0280...|
|2015-12-16|'000422|7.62|7.53| 7.55| 0.0|7.58|7.79| 周三| 1| -1| 1|[7.62,7.53,0.0207...|
|2015-12-15|'000422|7.66|7.52| 7.55|-0.009186|7.64|7.78| 周二| 0| -1| 0|[7.66,7.52,0.0259...|
|2015-12-14|'000422|7.64|7.36| 7.62| 0.014647|7.68|7.76| 周一| 1| -1| 1|[7.64,7.36,0.0210...|
|2015-12-11|'000422| 7.7|7.41| 7.51| -0.02086| 7.8|7.73| 周五| 0| 1| 0|[7.7,7.41,0.02047...|
|2015-12-10|'000422|7.87|7.65| 7.67|-0.020434|7.95|7.69| 周四| 0| 1| 0|[7.87,7.65,0.0199...|
|2015-12-09|'000422| 8.0|7.75| 7.83| 0.007722| 8.0|7.68| 周三| 1| 1| 1|[8.0,7.75,0.02513...|
|2015-12-08|'000422|8.18|7.76| 7.77|-0.057039|7.92|7.66| 周二| 0| 1| 0|[8.18,7.76,0.0366...|
|2015-12-07|'000422|8.39|7.94| 8.24| 0.001215|7.84|7.64| 周一| 1| 1| 1|[8.39,7.94,0.0645...|
|2015-12-04|'000422|8.48| 7.8| 8.23| 0.039141|7.65|7.58| 周五| 1| 1| 1|[8.48,7.8,0.10010...|
+----------+-------+----+----+-----+---------+----+----+-----------+---------+---------------+-----+--------------------+
only showing top 20 rows
Training a RandomForestClassifier model
Split the data set into a training set and a test set:
(TrainingData, TestData) = AssembledSDF.randomSplit([0.8, 0.2], seed=42)
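randomSplit normalises the weights and assigns rows at random, so the resulting sizes are only approximately 80/20 (the test set shown below has 17 rows), and reusing the same seed on the same input reproduces the split. The plain-Python sketch below illustrates only the per-row idea: one uniform draw per row, compared against cumulative weight bounds. random_split is a hypothetical name; Spark's actual sampler works per partition and differs in detail.

```python
import random

def random_split(rows, weights, seed):
    """Assign each row to a split by one uniform draw per row against cumulative weight bounds."""
    total = sum(weights)  # weights are normalised, so [4, 1] behaves like [0.8, 0.2]
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    rng = random.Random(seed)  # a fixed seed makes the split reproducible
    splits = [[] for _ in weights]
    for row in rows:
        x = rng.random()
        for i, b in enumerate(bounds):
            if x < b:
                splits[i].append(row)
                break
        else:
            splits[-1].append(row)  # guard against rounding of the last bound
    return splits

train, test = random_split(list(range(100)), [0.8, 0.2], seed=42)
print(len(train), len(test))
```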
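Create the RandomForestClassifier:
from pyspark.ml.classification import RandomForestClassifier
RFC = RandomForestClassifier(labelCol="Label", featuresCol="Features", numTrees=10)
Create a Pipeline (optional):
from pyspark.ml import Pipeline
# Create a Pipeline that chains the feature-vector assembly and the random forest model.
# Note: when using the Pipeline, do not run MyAssembler.transform beforehand, and split NewSDF (not AssembledSDF) into training and test sets; otherwise VectorAssembler fails with "Output column Features already exists."
MyPipeline = Pipeline(stages=[MyAssembler, RFC])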
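Train the RandomForestClassifier model:
If the Features column was already created (AssembledSDF was split above):
# Train the model (plain mode).
Model = RFC.fit(TrainingData)
If the Features column was not created beforehand (the training and test sets were split from NewSDF):
# Train the model (Pipeline mode).
Model = MyPipeline.fit(TrainingData)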
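A Pipeline's fit() walks through its stages in order: transformers (such as VectorAssembler) only transform the data, while estimators (such as RandomForestClassifier) are fitted first and the resulting model takes their place. The toy sketch below imitates that flow in plain Python to show why feeding already-assembled data into the pipeline fails. Every class in it (ToyAssembler, ToyMajorityClassifier, ToyPipeline and so on) is hypothetical, and the error text only mimics the message quoted above.

```python
class ToyAssembler:
    """Hypothetical transformer imitating VectorAssembler on a list of dict rows."""
    def __init__(self, input_cols, output_col):
        self.input_cols = input_cols
        self.output_col = output_col

    def transform(self, rows):
        if rows and self.output_col in rows[0]:
            raise ValueError(f"Output column {self.output_col} already exists.")
        return [{**r, self.output_col: [float(r[c]) for c in self.input_cols]} for r in rows]


class ToyMajorityModel:
    """Hypothetical fitted model: always predicts the label it learned."""
    def __init__(self, label):
        self.label = label

    def transform(self, rows):
        return [{**r, "prediction": float(self.label)} for r in rows]


class ToyMajorityClassifier:
    """Hypothetical estimator: fit() learns the most frequent Label and returns a model."""
    def fit(self, rows):
        labels = [r["Label"] for r in rows]
        return ToyMajorityModel(max(set(labels), key=labels.count))


class ToyPipelineModel:
    """Applies every fitted stage in order."""
    def __init__(self, stages):
        self.stages = stages

    def transform(self, rows):
        for stage in self.stages:
            rows = stage.transform(rows)
        return rows


class ToyPipeline:
    """fit(): estimators are fitted, transformers are used as-is; data flows stage by stage."""
    def __init__(self, stages):
        self.stages = stages

    def fit(self, rows):
        fitted = []
        for i, stage in enumerate(self.stages):
            if hasattr(stage, "fit"):
                stage = stage.fit(rows)
            if i < len(self.stages) - 1:  # no need to transform after the last stage
                rows = stage.transform(rows)
            fitted.append(stage)
        return ToyPipelineModel(fitted)


data = [{"High": 7.95, "Low": 7.76, "Label": 0},
        {"High": 7.93, "Low": 7.75, "Label": 1},
        {"High": 7.85, "Low": 7.69, "Label": 1}]
pipe = ToyPipeline([ToyAssembler(["High", "Low"], "Features"), ToyMajorityClassifier()])
preds = pipe.fit(data).transform(data)

# Feeding data that already has a Features column reproduces the error.
already_assembled = ToyAssembler(["High", "Low"], "Features").transform(data)
try:
    pipe.fit(already_assembled)
    error_msg = None
except ValueError as exc:
    error_msg = str(exc)
print(error_msg)  # Output column Features already exists.
```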
Predicting with the RandomForestClassifier model
# Make predictions on the test set.
Predictions = Model.transform(TestData)
# Drop the columns that are not needed (with too many columns the output is crowded and hard to read).
Predictions = Predictions.drop("Open", "High", "Low", "Close", "Pre_Close", "Turnover_Rate", "Volume", "Weekday(Idx)", "Weekday(CN)")
print("[Message] Prediction Results on The Test Data Set for RandomForestClassifier:")
Predictions.show()
Output:
[Message] Prediction Results on The Test Data Set for RandomForestClassifier:
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+--------------------+--------------------+----------+
| Date| Code| Change| MA5|MA10|Rise_Fall|MA_Relationship|Label| Features| rawPrediction| probability|prediction|
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+--------------------+--------------------+----------+
|2015-08-10|'000422| 0.034105| 8.2|7.92| 1| 1| 1|[8.58,8.18,0.0412...|[3.83333333333333...|[0.38333333333333...| 1.0|
|2015-08-14|'000422| 0.009479|8.43|8.24| 1| 1| 1|[8.65,8.43,0.0411...|[6.33333333333333...|[0.63333333333333...| 0.0|
|2015-08-18|'000422|-0.095455|8.39|8.32| 0| 1| 0|[8.86,7.92,0.0561...|[4.83333333333333...|[0.48333333333333...| 1.0|
|2015-08-25|'000422|-0.099424|7.52|7.96| 0| -1| 0|[6.77,6.25,0.0294...|[1.24468211527035...|[0.12446821152703...| 1.0|
|2015-09-02|'000422|-0.053412|6.73|6.91| 0| -1| 0|[6.88,6.3,0.02228...|[2.39316696375519...|[0.23931669637551...| 1.0|
|2015-09-10|'000422|-0.031161|6.76|6.74| 0| 1| 0|[7.01,6.76,0.0174...|[2.40476190476190...|[0.24047619047619...| 1.0|
|2015-09-18|'000422| 0.0|6.39|6.62| 1| -1| 1|[6.58,6.3,0.01662...|[4.22700534759358...|[0.42270053475935...| 1.0|
|2015-09-28|'000422| 0.009464|6.48|6.47| 1| 1| 1|[6.42,6.25,0.0088...|[3.83333333333333...|[0.38333333333333...| 1.0|
|2015-10-19|'000422|-0.007062|6.94|6.72| 0| 1| 0|[7.13,6.92,0.0312...|[1.44220779220779...|[0.14422077922077...| 1.0|
|2015-10-20|'000422| 0.008535|6.98|6.81| 1| 1| 1|[7.09,6.94,0.0244...|[2.59069264069264...|[0.25906926406926...| 1.0|
|2015-10-21|'000422|-0.062059|6.96|6.85| 0| 1| 0|[7.11,6.61,0.0393...|[3.42857142857142...|[0.34285714285714...| 1.0|
|2015-10-23|'000422| 0.054412|6.95|6.93| 1| 1| 1|[7.22,6.81,0.0471...|[2.47857142857142...|[0.24785714285714...| 1.0|
|2015-10-27|'000422| 0.033426|7.04|7.01| 1| 1| 1|[7.48,7.08,0.0576...|[2.81190476190476...|[0.28119047619047...| 1.0|
|2015-11-02|'000422|-0.027548|7.23| 7.1| 0| 1| 0|[7.26,7.05,0.0168...|[1.62402597402597...|[0.16240259740259...| 1.0|
|2015-11-11|'000422| 0.005284|7.54|7.37| 1| 1| 1|[7.64,7.52,0.0261...|[3.29902597402597...|[0.32990259740259...| 1.0|
|2015-11-20|'000422| 0.002635|7.52|7.53| 1| -1| 1|[7.71,7.53,0.0282...|[5.74068627450980...|[0.57406862745098...| 0.0|
|2015-12-02|'000422| 0.009511|7.37|7.49| 1| -1| 1|[7.48,7.2,0.01596...|[7.54901960784313...|[0.75490196078431...| 0.0|
+----------+-------+---------+----+----+---------+---------------+-----+--------------------+--------------------+--------------------+----------+
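In the output above, the first entry of rawPrediction is ten times the first entry of probability (3.8333... vs 0.38333...), consistent with numTrees=10. In recent Spark versions the random forest classifier sums each tree's class-probability vector into rawPrediction, normalises that sum to get probability, and predicts the class with the largest probability (a class-0 probability of 0.3833 gives prediction 1.0). A plain-Python sketch of that aggregation, using ten made-up tree outputs (rf_aggregate and the tree values are hypothetical):

```python
def rf_aggregate(tree_class_probs):
    """Sum per-tree class-probability vectors (raw), normalise them (prob), take the argmax (prediction)."""
    n_classes = len(tree_class_probs[0])
    raw = [sum(t[k] for t in tree_class_probs) for k in range(n_classes)]
    total = sum(raw)
    prob = [r / total for r in raw]
    prediction = float(max(range(n_classes), key=lambda k: prob[k]))
    return raw, prob, prediction

# Ten hypothetical trees: [P(class 0), P(class 1)] from the leaf each tree reaches.
trees = [[1.0, 0.0]] * 3 + [[0.0, 1.0]] * 6 + [[0.5, 0.5]]
raw, prob, pred = rf_aggregate(trees)
print(raw, prob, pred)
```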
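Evaluating model performance with BinaryClassificationEvaluator
from pyspark.ml.evaluation import BinaryClassificationEvaluator
# Evaluate the model with BinaryClassificationEvaluator (area under the ROC curve).
MyEvaluator = BinaryClassificationEvaluator(labelCol="Label", metricName="areaUnderROC")
auc = MyEvaluator.evaluate(Predictions)
print("Area Under ROC (AUC):", auc)
Output: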
Area Under ROC (AUC): 0.15714285714285714
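AUC is the probability that a randomly chosen positive row (Label 1) receives a higher score than a randomly chosen negative row (Label 0). As a sanity check, the plain-Python sketch below recomputes it from the 17 test rows shown above, using 1 minus the displayed (truncated) probability of class 0 as the score; the evaluator itself ranks rows by rawPrediction, which for this model orders these rows the same way. It finds 11 winning pairs out of 10 x 7 = 70, i.e. 11/70 ≈ 0.157, matching the reported value. A value this far below 0.5 means the model ranks the test rows mostly in the wrong order, and with only 17 test rows the estimate is very noisy. auc_rank is a hypothetical name.

```python
# (Label, probability of class 0) for the 17 test rows above, as displayed (truncated).
rows = [
    (1, 0.38333), (1, 0.63333), (0, 0.48333), (0, 0.12446), (0, 0.23931),
    (0, 0.24047), (1, 0.42270), (1, 0.38333), (0, 0.14422), (1, 0.25906),
    (0, 0.34285), (1, 0.24785), (1, 0.28119), (0, 0.16240), (1, 0.32990),
    (1, 0.57406), (1, 0.75490),
]

def auc_rank(scored):
    """AUC as P(score_pos > score_neg), counting ties as 1/2 (Mann-Whitney form)."""
    pos = [s for y, s in scored if y == 1]
    neg = [s for y, s in scored if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

# Score for the positive class = 1 - P(class 0).
auc = auc_rank([(y, 1.0 - p0) for y, p0 in rows])
print(auc)
```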
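Complete code
#!/usr/bin/python3
# Created by GF 2024-01-07
# In this example, VectorAssembler merges several feature columns into one feature vector, and RandomForestClassifier builds a random forest model.
# Finally, BinaryClassificationEvaluator evaluates the model using the area under the ROC curve (AUC), a common evaluation metric.
# Adjust the feature columns, the label column and the other parameters to your own data and problem. In practice you may need more feature engineering, hyperparameter tuning and model evaluation.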
import datetime
import pprint
# --------------------------------------------------
import pyspark
# --------------------------------------------------
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import DateType, IntegerType, DoubleType
# --------------------------------------------------
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml import Pipeline
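# Write the "add a field and its value to a spark.sql Row object" function.
def MapFunc_SparkSQL_Row_Add_Field(SrcRow:pyspark.sql.types.Row, FldName:str, FldVal:object) -> pyspark.sql.types.Row:
    # Convert Obj "pyspark.sql.types.Row" to Dict.
    Row_Dict = SrcRow.asDict()
    # Add a New Key in the Dictionary With the New Column Name and Value.
    Row_Dict[FldName] = FldVal
    # Convert Dict to Obj "pyspark.sql.types.Row".
    NewRow = pyspark.sql.types.Row(**Row_Dict)
    return NewRow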