A Simple Data Mining Application with Spark Machine Learning (a Hobby Prediction Problem)

The Data Mining Process

A data mining task breaks down into the following six steps:

  • 1. Data preprocessing
  • 2. Feature transformation
  • 3. Feature selection
  • 4. Model training
  • 5. Model prediction
  • 6. Evaluation of the predictions

Preparing the Data

Here is a dataset of 20 people's hobbies (saved as hobby.csv), covering different regions, sexes, heights, weights, and so on:

id,hobby,sex,address,age,height,weight
1,football,male,dalian,12,168,55
2,pingpang,female,yangzhou,21,163,60
3,football,male,dalian,,172,70
4,football,female,,13,167,58
5,pingpang,female,shanghai,63,170,64
6,football,male,dalian,30,177,76
7,basketball,male,shanghai,25,181,90
8,football,male,dalian,15,172,71
9,basketball,male,shanghai,25,179,80
10,pingpang,male,shanghai,55,175,72
11,football,male,dalian,13,169,55
12,pingpang,female,yangzhou,22,164,61
13,football,male,dalian,23,170,71
14,football,female,,12,164,55
15,pingpang,female,shanghai,64,169,63
16,football,male,dalian,30,177,76
17,basketball,male,shanghai,22,180,80
18,football,male,dalian,16,173,72
19,basketball,male,shanghai,23,176,73
20,pingpang,male,shanghai,56,171,71
  • Task analysis
    Predict a person's hobby from the five features sex, address, age, height, and weight.

Data Preprocessing

Before we can read any data, we must first create a SparkSession.

Creating the SparkSession

Build it with SparkSession.builder(), then set the appName and master, and finish with getOrCreate():

    // create the SparkSession
    val spark = SparkSession.builder().appName("兴趣预测").master("local[*]").getOrCreate()

Reading the Data

Use spark.read to load the data: specify the format as "CSV", treat the first row as the header, and give the file path:

val df=spark.read.format("CSV").option("header",true).load("C:/Users/35369/Desktop/hobby.csv")

Inspect the data with df.show() and df.printSchema():

    df.show()
    df.printSchema()

    spark.stop()  // stop Spark

Output:

+---+----------+------+--------+----+------+------+
| id|     hobby|   sex| address| age|height|weight|
+---+----------+------+--------+----+------+------+
|  1|  football|  male|  dalian|  12|   168|    55|
|  2|  pingpang|female|yangzhou|  21|   163|    60|
|  3|  football|  male|  dalian|null|   172|    70|
|  4|  football|female|    null|  13|   167|    58|
|  5|  pingpang|female|shanghai|  63|   170|    64|
|  6|  football|  male|  dalian|  30|   177|    76|
|  7|basketball|  male|shanghai|  25|   181|    90|
|  8|  football|  male|  dalian|  15|   172|    71|
|  9|basketball|  male|shanghai|  25|   179|    80|
| 10|  pingpang|  male|shanghai|  55|   175|    72|
| 11|  football|  male|  dalian|  13|   169|    55|
| 12|  pingpang|female|yangzhou|  22|   164|    61|
| 13|  football|  male|  dalian|  23|   170|    71|
| 14|  football|female|    null|  12|   164|    55|
| 15|  pingpang|female|shanghai|  64|   169|    63|
| 16|  football|  male|  dalian|  30|   177|    76|
| 17|basketball|  male|shanghai|  22|   180|    80|
| 18|  football|  male|  dalian|  16|   173|    72|
| 19|basketball|  male|shanghai|  23|   176|    73|
| 20|  pingpang|  male|shanghai|  56|   171|    71|
+---+----------+------+--------+----+------+------+

root
 |-- id: string (nullable = true)
 |-- hobby: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- address: string (nullable = true)
 |-- age: string (nullable = true)
 |-- height: string (nullable = true)
 |-- weight: string (nullable = true)
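
Every column comes back as a string, because the reader was not asked to infer a schema. The casts later in this section fix the types by hand; as an aside, the csv source can also infer them at load time with the standard inferSchema option (a small optional sketch, dfTyped being an illustrative name):

    // optional: let Spark infer the column types while reading
    val dfTyped = spark.read.format("CSV")
      .option("header", true)
      .option("inferSchema", true)
      .load("C:/Users/35369/Desktop/hobby.csv")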

Filling the Missing Ages

Filling a numeric column takes three steps:
(1) Take the column with its null rows dropped
(2) Compute the mean of that column
(3) Fill the mean back into the original DataFrame

  • (1) Take the column with its null rows dropped
    val ageNaDF = df.select("age").na.drop()
    ageNaDF.show()
+---+
|age|
+---+
| 12|
| 21|
| 13|
| 63|
| 30|
| 25|
| 15|
| 25|
| 55|
| 13|
| 22|
| 23|
| 12|
| 64|
| 30|
| 22|
| 16|
| 23|
| 56|
+---+
  • (2) Compute the mean of the column from (1)

Look at the summary statistics of ageNaDF:

ageNaDF.describe("age").show()

Output:

+-------+-----------------+
|summary|              age|
+-------+-----------------+
|  count|               19|
|   mean|28.42105263157895|
| stddev|17.48432882286206|
|    min|               12|
|    max|               64|
+-------+-----------------+

The mean is 28.42105263157895; we need to pull this value out (an alternative way to compute it is sketched at the end of this subsection):

    val mean = ageNaDF.describe("age").select("age").collect()(1)(0).toString
    print(mean) //28.42105263157895
  • (3) Fill the mean back into the original DataFrame
    The df.na.fill() method fills null values; to restrict it to the "age" column, pass List("age") as the second argument
    val ageFilledDF = df.na.fill(mean,List("age"))
    ageFilledDF.show()

Output:

+---+----------+------+--------+-----------------+------+------+
| id|     hobby|   sex| address|              age|height|weight|
+---+----------+------+--------+-----------------+------+------+
|  1|  football|  male|  dalian|               12|   168|    55|
|  2|  pingpang|female|yangzhou|               21|   163|    60|
|  3|  football|  male|  dalian|28.42105263157895|   172|    70|
|  4|  football|female|    null|               13|   167|    58|
|  5|  pingpang|female|shanghai|               63|   170|    64|
|  6|  football|  male|  dalian|               30|   177|    76|
|  7|basketball|  male|shanghai|               25|   181|    90|
|  8|  football|  male|  dalian|               15|   172|    71|
|  9|basketball|  male|shanghai|               25|   179|    80|
| 10|  pingpang|  male|shanghai|               55|   175|    72|
| 11|  football|  male|  dalian|               13|   169|    55|
| 12|  pingpang|female|yangzhou|               22|   164|    61|
| 13|  football|  male|  dalian|               23|   170|    71|
| 14|  football|female|    null|               12|   164|    55|
| 15|  pingpang|female|shanghai|               64|   169|    63|
| 16|  football|  male|  dalian|               30|   177|    76|
| 17|basketball|  male|shanghai|               22|   180|    80|
| 18|  football|  male|  dalian|               16|   173|    72|
| 19|basketball|  male|shanghai|               23|   176|    73|
| 20|  pingpang|  male|shanghai|               56|   171|    71|
+---+----------+------+--------+-----------------+------+------+

The null ages have been filled with the mean.
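
As promised above, the mean can also be computed with agg and avg instead of indexing into describe()'s output, which is fragile if the summary row order ever changes. A small sketch (meanAge and filledDF are illustrative names):

    import org.apache.spark.sql.functions.{avg, col}

    // compute the mean directly; cast first, since age is still a string column
    val meanAge = df.select(col("age").cast("double"))
      .na.drop()
      .agg(avg("age"))
      .first().getDouble(0)        // 28.42105263157895
    val filledDF = df.na.fill(meanAge.toString, List("age"))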

Dropping Rows with a Null City

There is no sensible value to fill in for a missing city, so any row with a null address is simply deleted,
using the .na.drop() method:

    val addressDf = ageFilledDF.na.drop()
    addressDf.show()

Output:

+---+----------+------+--------+-----------------+------+------+
| id|     hobby|   sex| address|              age|height|weight|
+---+----------+------+--------+-----------------+------+------+
|  1|  football|  male|  dalian|               12|   168|    55|
|  2|  pingpang|female|yangzhou|               21|   163|    60|
|  3|  football|  male|  dalian|28.42105263157895|   172|    70|
|  5|  pingpang|female|shanghai|               63|   170|    64|
|  6|  football|  male|  dalian|               30|   177|    76|
|  7|basketball|  male|shanghai|               25|   181|    90|
|  8|  football|  male|  dalian|               15|   172|    71|
|  9|basketball|  male|shanghai|               25|   179|    80|
| 10|  pingpang|  male|shanghai|               55|   175|    72|
| 11|  football|  male|  dalian|               13|   169|    55|
| 12|  pingpang|female|yangzhou|               22|   164|    61|
| 13|  football|  male|  dalian|               23|   170|    71|
| 15|  pingpang|female|shanghai|               64|   169|    63|
| 16|  football|  male|  dalian|               30|   177|    76|
| 17|basketball|  male|shanghai|               22|   180|    80|
| 18|  football|  male|  dalian|               16|   173|    72|
| 19|basketball|  male|shanghai|               23|   176|    73|
| 20|  pingpang|  male|shanghai|               56|   171|    71|
+---+----------+------+--------+-----------------+------+------+

Rows 4 and 14 have been removed.

Casting Each Column to a Sensible Type

    //adjust the schema of df
    val formatDF = addressDf.select(
      col("id").cast("int"),
      col("hobby").cast("String"),
      col("sex").cast("String"),
      col("address").cast("String"),
      col("age").cast("Double"),
      col("height").cast("Double"),
      col("weight").cast("Double")
    )
    formatDF.printSchema()

Output:

root
 |-- id: integer (nullable = true)
 |-- hobby: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- address: string (nullable = true)
 |-- age: double (nullable = true)
 |-- height: double (nullable = true)
 |-- weight: double (nullable = true)

That completes the data preprocessing.

Feature Transformation

To make the data easier to train on, we transform the age, weight, height, address, and sex features: age is bucketized, height and weight are binarized, and the string features are encoded.

Bucketizing Age

  • under 18
  • 18-35
  • 35-60
  • over 60

The Bucketizer class does the bucketing: set the input and output column names, pass the bucket boundaries defined above as its splits, and finally hand it the DataFrame to transform.

    //2.1 bucketize age
    //define an array of bucket boundaries
    val ageSplits = Array(Double.NegativeInfinity,18,35,60,Double.PositiveInfinity)
    val bucketizerDF = new Bucketizer()
      .setInputCol("age")
      .setOutputCol("ageFeature")
      .setSplits(ageSplits)
      .transform(formatDF)
    bucketizerDF.show()

The bucketing result:

+---+----------+------+--------+-----------------+------+------+----------+
| id|     hobby|   sex| address|              age|height|weight|ageFeature|
+---+----------+------+--------+-----------------+------+------+----------+
|  1|  football|  male|  dalian|             12.0| 168.0|  55.0|       0.0|
|  2|  pingpang|female|yangzhou|             21.0| 163.0|  60.0|       1.0|
|  3|  football|  male|  dalian|28.42105263157895| 172.0|  70.0|       1.0|
|  5|  pingpang|female|shanghai|             63.0| 170.0|  64.0|       3.0|
|  6|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|
|  7|basketball|  male|shanghai|             25.0| 181.0|  90.0|       1.0|
|  8|  football|  male|  dalian|             15.0| 172.0|  71.0|       0.0|
|  9|basketball|  male|shanghai|             25.0| 179.0|  80.0|       1.0|
| 10|  pingpang|  male|shanghai|             55.0| 175.0|  72.0|       2.0|
| 11|  football|  male|  dalian|             13.0| 169.0|  55.0|       0.0|
| 12|  pingpang|female|yangzhou|             22.0| 164.0|  61.0|       1.0|
| 13|  football|  male|  dalian|             23.0| 170.0|  71.0|       1.0|
| 15|  pingpang|female|shanghai|             64.0| 169.0|  63.0|       3.0|
| 16|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|
| 17|basketball|  male|shanghai|             22.0| 180.0|  80.0|       1.0|
| 18|  football|  male|  dalian|             16.0| 173.0|  72.0|       0.0|
| 19|basketball|  male|shanghai|             23.0| 176.0|  73.0|       1.0|
| 20|  pingpang|  male|shanghai|             56.0| 171.0|  71.0|       2.0|
+---+----------+------+--------+-----------------+------+------+----------+

Binarizing Height

The threshold is 170, using the Binarizer class:

    //2.2 binarize height
    val heightDF = new Binarizer()
      .setInputCol("height")
      .setOutputCol("heightFeature")
      .setThreshold(170) // threshold
      .transform(bucketizerDF)
    heightDF.show()

The result:

+---+----------+------+--------+-----------------+------+------+----------+-------------+
| id|     hobby|   sex| address|              age|height|weight|ageFeature|heightFeature|
+---+----------+------+--------+-----------------+------+------+----------+-------------+
|  1|  football|  male|  dalian|             12.0| 168.0|  55.0|       0.0|          0.0|
|  2|  pingpang|female|yangzhou|             21.0| 163.0|  60.0|       1.0|          0.0|
|  3|  football|  male|  dalian|28.42105263157895| 172.0|  70.0|       1.0|          1.0|
|  5|  pingpang|female|shanghai|             63.0| 170.0|  64.0|       3.0|          0.0|
|  6|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|
|  7|basketball|  male|shanghai|             25.0| 181.0|  90.0|       1.0|          1.0|
|  8|  football|  male|  dalian|             15.0| 172.0|  71.0|       0.0|          1.0|
|  9|basketball|  male|shanghai|             25.0| 179.0|  80.0|       1.0|          1.0|
| 10|  pingpang|  male|shanghai|             55.0| 175.0|  72.0|       2.0|          1.0|
| 11|  football|  male|  dalian|             13.0| 169.0|  55.0|       0.0|          0.0|
| 12|  pingpang|female|yangzhou|             22.0| 164.0|  61.0|       1.0|          0.0|
| 13|  football|  male|  dalian|             23.0| 170.0|  71.0|       1.0|          0.0|
| 15|  pingpang|female|shanghai|             64.0| 169.0|  63.0|       3.0|          0.0|
| 16|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|
| 17|basketball|  male|shanghai|             22.0| 180.0|  80.0|       1.0|          1.0|
| 18|  football|  male|  dalian|             16.0| 173.0|  72.0|       0.0|          1.0|
| 19|basketball|  male|shanghai|             23.0| 176.0|  73.0|       1.0|          1.0|
| 20|  pingpang|  male|shanghai|             56.0| 171.0|  71.0|       2.0|          1.0|
+---+----------+------+--------+-----------------+------+------+----------+-------------+

Binarizing Weight

The threshold is set to 65:

    //2.3 binarize weight
    val weightDF = new Binarizer()
      .setInputCol("weight")
      .setOutputCol("weightFeature")
      .setThreshold(65)
      .transform(heightDF)

    weightDF.show()

Handling the Sex, City, and Hobby Fields

These three fields are all strings, and raw strings are not suitable input for machine learning, so they also need a feature transformation (encoding).

    //2.4 label-encode sex
    val sexIndex = new StringIndexer()
      .setInputCol("sex")
      .setOutputCol("sexIndex")
      .fit(weightDF)
      .transform(weightDF)

    //2.5 label-encode the home address
    val addIndex = new StringIndexer()
      .setInputCol("address")
      .setOutputCol("addIndex")
      .fit(sexIndex)
      .transform(sexIndex)

    //2.6 one-hot encode the address
    val addOneHot = new OneHotEncoder()
      .setInputCol("addIndex")
      .setOutputCol("addOneHot")
      .fit(addIndex)
      .transform(addIndex)

    //2.7 label-encode the hobby field
    val hobbyIndexDF = new StringIndexer()
      .setInputCol("hobby")
      .setOutputCol("hobbyIndex")
      .fit(addOneHot)
      .transform(addOneHot)

    hobbyIndexDF.show()

The address additionally gets a one-hot encoding, since its index values have no meaningful order. In the output below, (2,[0],[1.0]) is a sparse vector of size 2 with a 1.0 at position 0; the last index (yangzhou) is dropped by default, which is why it shows up as (2,[],[]).

Rename the hobbyIndex column to label, since hobby is used as the label during model training.

    //2.8 rename the column
    val resultDF = hobbyIndexDF.withColumnRenamed("hobbyIndex","label")
    resultDF.show()

The final result after feature transformation (StringIndexer assigns indices by descending frequency, so football, the most common hobby, becomes label 0.0):

+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
| id|     hobby|   sex| address|              age|height|weight|ageFeature|heightFeature|weightFeature|sexIndex|addIndex|    addOneHot|label|
+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
|  1|  football|  male|  dalian|             12.0| 168.0|  55.0|       0.0|          0.0|          0.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  2|  pingpang|female|yangzhou|             21.0| 163.0|  60.0|       1.0|          0.0|          0.0|     1.0|     2.0|    (2,[],[])|  1.0|
|  3|  football|  male|  dalian|28.42105263157895| 172.0|  70.0|       1.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  5|  pingpang|female|shanghai|             63.0| 170.0|  64.0|       3.0|          0.0|          0.0|     1.0|     1.0|(2,[1],[1.0])|  1.0|
|  6|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  7|basketball|  male|shanghai|             25.0| 181.0|  90.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
|  8|  football|  male|  dalian|             15.0| 172.0|  71.0|       0.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  9|basketball|  male|shanghai|             25.0| 179.0|  80.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
| 10|  pingpang|  male|shanghai|             55.0| 175.0|  72.0|       2.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  1.0|
| 11|  football|  male|  dalian|             13.0| 169.0|  55.0|       0.0|          0.0|          0.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 12|  pingpang|female|yangzhou|             22.0| 164.0|  61.0|       1.0|          0.0|          0.0|     1.0|     2.0|    (2,[],[])|  1.0|
| 13|  football|  male|  dalian|             23.0| 170.0|  71.0|       1.0|          0.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 15|  pingpang|female|shanghai|             64.0| 169.0|  63.0|       3.0|          0.0|          0.0|     1.0|     1.0|(2,[1],[1.0])|  1.0|
| 16|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 17|basketball|  male|shanghai|             22.0| 180.0|  80.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
| 18|  football|  male|  dalian|             16.0| 173.0|  72.0|       0.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 19|basketball|  male|shanghai|             23.0| 176.0|  73.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
| 20|  pingpang|  male|shanghai|             56.0| 171.0|  71.0|       2.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  1.0|
+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+

Feature Selection

Feature transformation leaves us with many columns, but not all of them belong in model training; feature selection picks out the ones that do.

Selecting Features

VectorAssembler() gathers the chosen columns into a single feature vector:

    //3.1 select features (the label column must not be assembled into the features)
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("ageFeature","heightFeature","weightFeature","sexIndex","addOneHot"))
      .setOutputCol("features")

Standardizing the Features

    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("featureScaler")
      .setWithStd(true) // scale to unit standard deviation
      .setWithMean(false)  // do not center on the mean

Feature Screening

    // feature screening, using a chi-squared test
    val selector = new ChiSqSelector()
      .setFeaturesCol("featureScaler")
      .setLabelCol("label")
      .setOutputCol("featuresSelector")

Building the Logistic Regression Model and the Pipeline

    // logistic regression model
    val lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("featuresSelector")

    // build the pipeline
    val pipeline = new Pipeline().setStages(Array(vectorAssembler,scaler,selector,lr))

Setting Up a Grid Search over the Parameters

    // grid of candidate parameters for the search
    val params = new ParamGridBuilder()
      .addGrid(lr.regParam,Array(0.1,0.01))  //regularization parameter
      .addGrid(selector.numTopFeatures,Array(5,10))  //number of top features kept by the chi-squared selector
      .build()

Setting Up Cross-Validation

    // cross-validation
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new MulticlassClassificationEvaluator()) // hobby has three classes, so a binary evaluator won't work
      .setEstimatorParamMaps(params)
      .setNumFolds(5)

Model Training and Prediction

Before training, the data needs to be split into a training set and a test set

val Array(trainDF,testDF) = resultDF.randomSplit(Array(0.8,0.2))

The randomSplit method handles the split.
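
If you want the split to be reproducible across runs, randomSplit also accepts a seed (optional; 42 here is an arbitrary choice):

    val Array(trainDF, testDF) = resultDF.randomSplit(Array(0.8, 0.2), seed = 42L)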

  • Training and prediction
    val model = cv.fit(trainDF)

    // model prediction
    val prediction = model.bestModel.transform(testDF)
    prediction.show()

An Unresolved Error (Help Wanted)

The call to cv.fit(trainDF) throws the error below, and I couldn't find anything about it online:

Exception in thread "main" java.lang.NoClassDefFoundError: org/apache/spark/sql/catalyst/trees/BinaryLike
	at java.lang.ClassLoader.defineClass1(Native Method)
	at java.lang.ClassLoader.defineClass(ClassLoader.java:756)
	at java.security.SecureClassLoader.defineClass(SecureClassLoader.java:142)
	at java.net.URLClassLoader.defineClass(URLClassLoader.java:473)
	at java.net.URLClassLoader.access$100(URLClassLoader.java:74)
	at java.net.URLClassLoader$1.run(URLClassLoader.java:369)
	at java.net.URLClassLoader$1.run(URLClassLoader.java:363)
	at java.security.AccessController.doPrivileged(Native Method)
	at java.net.URLClassLoader.findClass(URLClassLoader.java:362)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:418)
	at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:355)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:351)
	at org.apache.spark.ml.stat.SummaryBuilderImpl.summary(Summarizer.scala:251)
	at org.apache.spark.ml.stat.SummaryBuilder.summary(Summarizer.scala:54)
	at org.apache.spark.ml.feature.StandardScaler.fit(StandardScaler.scala:112)
	at org.apache.spark.ml.feature.StandardScaler.fit(StandardScaler.scala:84)
	at org.apache.spark.ml.Pipeline.$anonfun$fit$5(Pipeline.scala:151)
	at org.apache.spark.ml.MLEvents.withFitEvent(events.scala:130)
	at org.apache.spark.ml.MLEvents.withFitEvent$(events.scala:123)
	at org.apache.spark.ml.util.Instrumentation.withFitEvent(Instrumentation.scala:42)
	at org.apache.spark.ml.Pipeline.$anonfun$fit$4(Pipeline.scala:151)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at org.apache.spark.ml.Pipeline.$anonfun$fit$2(Pipeline.scala:147)
	at org.apache.spark.ml.MLEvents.withFitEvent(events.scala:130)
	at org.apache.spark.ml.MLEvents.withFitEvent$(events.scala:123)
	at org.apache.spark.ml.util.Instrumentation.withFitEvent(Instrumentation.scala:42)
	at org.apache.spark.ml.Pipeline.$anonfun$fit$1(Pipeline.scala:133)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.Pipeline.fit(Pipeline.scala:133)
	at org.apache.spark.ml.Pipeline.fit(Pipeline.scala:93)
	at org.apache.spark.ml.Estimator.fit(Estimator.scala:59)
	at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$7(CrossValidator.scala:174)
	at scala.runtime.java8.JFunction0$mcD$sp.apply(JFunction0$mcD$sp.java:23)
	at scala.concurrent.Future$.$anonfun$apply$1(Future.scala:659)
	at scala.util.Success.$anonfun$map$1(Try.scala:255)
	at scala.util.Success.map(Try.scala:213)
	at scala.concurrent.Future.$anonfun$map$1(Future.scala:292)
	at scala.concurrent.impl.Promise.liftedTree1$1(Promise.scala:33)
	at scala.concurrent.impl.Promise.$anonfun$transform$1(Promise.scala:33)
	at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:64)
	at org.sparkproject.guava.util.concurrent.MoreExecutors$SameThreadExecutorService.execute(MoreExecutors.java:293)
	at scala.concurrent.impl.ExecutionContextImpl$$anon$4.execute(ExecutionContextImpl.scala:138)
	at scala.concurrent.impl.CallbackRunnable.executeWithValue(Promise.scala:72)
	at scala.concurrent.impl.Promise$KeptPromise$Kept.onComplete(Promise.scala:372)
	at scala.concurrent.impl.Promise$KeptPromise$Kept.onComplete$(Promise.scala:371)
	at scala.concurrent.impl.Promise$KeptPromise$Successful.onComplete(Promise.scala:379)
	at scala.concurrent.impl.Promise.transform(Promise.scala:33)
	at scala.concurrent.impl.Promise.transform$(Promise.scala:31)
	at scala.concurrent.impl.Promise$KeptPromise$Successful.transform(Promise.scala:379)
	at scala.concurrent.Future.map(Future.scala:292)
	at scala.concurrent.Future.map$(Future.scala:292)
	at scala.concurrent.impl.Promise$KeptPromise$Successful.map(Promise.scala:379)
	at scala.concurrent.Future$.apply(Future.scala:659)
	at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$6(CrossValidator.scala:182)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
	at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:198)
	at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$4(CrossValidator.scala:172)
	at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
	at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
	at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
	at scala.collection.TraversableLike.map(TraversableLike.scala:286)
	at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
	at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:198)
	at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$1(CrossValidator.scala:166)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.tuning.CrossValidator.fit(CrossValidator.scala:137)
	at org.example.SparkML.SparkMl01$.main(SparkMl01.scala:147)
	at org.example.SparkML.SparkMl01.main(SparkMl01.scala)
Caused by: java.lang.ClassNotFoundException: org.apache.spark.sql.catalyst.trees.BinaryLike
	at java.net.URLClassLoader.findClass(URLClassLoader.java:387)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:418)
	at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:355)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:351)
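
A likely cause: org.apache.spark.sql.catalyst.trees.BinaryLike was only introduced in Spark 3.1, and the pom below mixes spark-mllib 3.5.0 with spark-core and spark-sql at 3.0.0-preview2 (and spark-hive at 3.1.2), so a class the mllib jar expects is simply missing from the classpath at runtime. Aligning every spark-* artifact to a single version, for example 3.5.0 across the board, should let cv.fit run.

Once it does, the remaining step from the outline, evaluating the predictions, could look roughly like this sketch (the exact accuracy depends on the random split):

    // step 6: evaluate the best model's predictions on the test set
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

    val prediction = model.bestModel.transform(testDF)
    val accuracy = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
      .evaluate(prediction)
    println(s"accuracy = $accuracy")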

Full Source Code and pom File

package org.example.SparkML


import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{Binarizer, Bucketizer, ChiSqSelector, OneHotEncoder, StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

/**
 * The data mining process
 * 1. data preprocessing
 * 2. feature transformation (encoding, ...)
 * 3. feature selection
 * 4. model training
 * 5. model prediction
 * 6. evaluation of the predictions
 */
object SparkMl01 {
  def main(args: Array[String]): Unit = {
    // create the SparkSession
    val spark = SparkSession.builder().appName("兴趣预测").master("local").getOrCreate()
    import spark.implicits._
    val df=spark.read.format("CSV").option("header",true).load("C:/Users/35369/Desktop/hobby.csv")
    //1. data preprocessing: fill in the missing ages
    val ageNaDF = df.select("age").na.drop()
    val mean = ageNaDF.describe("age").select("age").collect()(1)(0).toString
    val ageFilledDF = df.na.fill(mean,List("age"))
    //rows with a null address are dropped outright
    val addressDf = ageFilledDF.na.drop()

    //adjust the schema of df
    val formatDF = addressDf.select(
      col("id").cast("int"),
      col("hobby").cast("String"),
      col("sex").cast("String"),
      col("address").cast("String"),
      col("age").cast("Double"),
      col("height").cast("Double"),
      col("weight").cast("Double")
    )

    //2. feature transformation
    //2.1 bucketize age
    //define an array of bucket boundaries
    val ageSplits = Array(Double.NegativeInfinity,18,35,60,Double.PositiveInfinity)
    val bucketizerDF = new Bucketizer()
      .setInputCol("age")
      .setOutputCol("ageFeature")
      .setSplits(ageSplits)
      .transform(formatDF)


    //2.2 binarize height
    val heightDF = new Binarizer()
      .setInputCol("height")
      .setOutputCol("heightFeature")
      .setThreshold(170) // threshold
      .transform(bucketizerDF)

    //2.3 binarize weight
    val weightDF = new Binarizer()
      .setInputCol("weight")
      .setOutputCol("weightFeature")
      .setThreshold(65)
      .transform(heightDF)

    //2.4 label-encode sex
    val sexIndex = new StringIndexer()
      .setInputCol("sex")
      .setOutputCol("sexIndex")
      .fit(weightDF)
      .transform(weightDF)

    //2.5 label-encode the home address
    val addIndex = new StringIndexer()
      .setInputCol("address")
      .setOutputCol("addIndex")
      .fit(sexIndex)
      .transform(sexIndex)

    //2.6 one-hot encode the address
    val addOneHot = new OneHotEncoder()
      .setInputCol("addIndex")
      .setOutputCol("addOneHot")
      .fit(addIndex)
      .transform(addIndex)

    //2.7 label-encode the hobby field
    val hobbyIndexDF = new StringIndexer()
      .setInputCol("hobby")
      .setOutputCol("hobbyIndex")
      .fit(addOneHot)
      .transform(addOneHot)

    //2.8 rename the column
    val resultDF = hobbyIndexDF.withColumnRenamed("hobbyIndex","label")


    //3. feature selection
    //3.1 select features
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("ageFeature","heightFeature","weightFeature","sexIndex","addOneHot"))
      .setOutputCol("features")

    //3.2 standardize the features
    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("featureScaler")
      .setWithStd(true) // scale to unit standard deviation
      .setWithMean(false)  // do not center on the mean


    // feature screening, using a chi-squared test
    val selector = new ChiSqSelector()
      .setFeaturesCol("featureScaler")
      .setLabelCol("label")
      .setOutputCol("featuresSelector")


    // logistic regression model
    val lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("featuresSelector")

    // build the pipeline
    val pipeline = new Pipeline()
      .setStages(Array(vectorAssembler,scaler,selector,lr))

    // grid of candidate parameters for the search
    val params = new ParamGridBuilder()
      .addGrid(lr.regParam,Array(0.1,0.01))  //regularization parameter
      .addGrid(selector.numTopFeatures,Array(5,10))  //number of top features kept by the chi-squared selector
      .build()

    // cross-validation
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new MulticlassClassificationEvaluator()) // hobby has three classes, so a binary evaluator won't work
      .setEstimatorParamMaps(params)
      .setNumFolds(5)

    // model training
    val Array(trainDF,testDF) = resultDF.randomSplit(Array(0.8,0.2))
    trainDF.show()
    testDF.show()
    val model = cv.fit(trainDF)

    //alternatively: fit the pipeline directly, without cross-validation
//    val model = pipeline.fit(trainDF)
//    val prediction = model.transform(testDF)
//    prediction.show()



    // model prediction
//    val prediction = model.bestModel.transform(testDF)
//    prediction.show()



    spark.stop()
  }
}

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>untitled</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.12.18</version>
        </dependency>


        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.12</artifactId>
            <version>3.0.0-preview2</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.12</artifactId>
            <version>3.1.2</version>
<!--            <scope>provided</scope>-->
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.12</artifactId>
            <version>3.0.0-preview2</version>
<!--            <scope>compile</scope>-->
        </dependency>


<!--        <dependency>-->
<!--            <groupId>mysql</groupId>-->
<!--            <artifactId>mysql-connector-java</artifactId>-->
<!--            <version>8.0.16</version>-->
<!--        </dependency>-->

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.12</artifactId>
            <version>3.5.0</version>
<!--            <scope>compile</scope>-->
        </dependency>




    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>2.4.1</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <transformers>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                    <mainClass>com.xxg.Main</mainClass>
                                </transformer>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>

        </plugins>
    </build>


</project>