package com.dream.ml.features
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, LogisticRegressionTrainingSummary}
import org.apache.spark.ml.feature.{StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}
/**
* @title: IrisFeaturesDemo
* @projectName SparkStudy
* @description: TODO
* @author MXH
* @date 2023/9/3 10:44
*/
object IrisFeaturesDemo {
def main(args: Array[String]): Unit = {
// 1.创建SparkSQL的运行环境
val spark: SparkSession = SparkSession.builder()
.appName(this.getClass.getSimpleName.stripSuffix("$"))
.master("local[4]")
.config("spark.sql.shuffle.partitions", 4)
.getOrCreate()
// 导入隐式转换
import spark.implicits._
// 2.加载鸢尾花数据集iris.data
val irisSchema: StructType = new StructType()
.add("sepal_length", dataType = DoubleType, nullable = true)
.add("sepal_width", dataType = DoubleType, nullable = true)
.add("petal_length", dataType = DoubleType, nullable = true)
.add("petal_width", dataType = DoubleType, nullable = true)
.add("class", dataType = StringType, nullable = true)
val rawIrisDF: DataFrame = spark.read
// 查看csv源码来设置options
.option("sep", ",") // 分隔符
.option("header", "false") // 默认为false
.option("inferSchema", "false") // 默认为false
// 当csv文件首行不是列名称时,需要自定义Schema
.schema(irisSchema)
.csv("datas/iris/iris.data")
/*
root
|-- sepal_length: double (nullable = true)
|-- sepal_width: double (nullable = true)
|-- petal_length: double (nullable = true)
|-- petal_width: double (nullable = true)
|-- class: string (nullable = true)
*/
rawIrisDF.printSchema()
rawIrisDF.show(10,truncate = false)
// 3.数据转换
// 3.1 转换1:将萼片长度、宽度及花瓣长度、宽度封装到一个特征向量中
// https://spark.apache.org/docs/latest/ml-features.html#vectorassembler
val assembler = new VectorAssembler()
// 把需要组合的列名称枚举出来进行组合
// .setInputCols(Array("hour", "mobile", "userFeatures"))
// 本例中除最后一列不要外,需要其他列
.setInputCols(rawIrisDF.columns.dropRight(1))
.setOutputCol("features") // 添加一列,类型为向量
val df1 = assembler.transform(rawIrisDF)
/*
root
|-- sepal_length: double (nullable = true)
|-- sepal_width: double (nullable = true)
|-- petal_length: double (nullable = true)
|-- petal_width: double (nullable = true)
|-- class: string (nullable = true)
|-- features: vector (nullable = true)
*/
df1.printSchema()
df1.show(10, truncate = false)
// 3.2 转换2: 转换类别字符串数据为数值数据
// https://spark.apache.org/docs/latest/ml-features.html#stringindexer
val indexer = new StringIndexer()
.setInputCol("class") // 需要索引化的列名
.setOutputCol("label") // 数据索引化后列名
.fit(df1)
val df2 = indexer.transform(df1)
/*
root
|-- sepal_length: double (nullable = true)
|-- sepal_width: double (nullable = true)
|-- petal_length: double (nullable = true)
|-- petal_width: double (nullable = true)
|-- class: string (nullable = true)
|-- features: vector (nullable = true) // 特征 x
|-- label: double (nullable = false) // 标签 y
算法: y = kx + b
*/
df2.printSchema()
df2.show(10,truncate = false)
// 3.3 数据标准化
// 在实际开发中,特征数据features经常需要进行各个转换操作,比如归一化、标准化和正则化等
// 为什么要进行归一化、标准化或正则化等数据预处理?原因在于不同维度特征值,值的范围跨度不一样,导致模型异常
// 比如影响房价的因素有地段、面积、楼层、新旧等特征数据
// 数据标准化 https://spark.apache.org/docs/latest/ml-features.html#standardscaler
val scaler = new StandardScaler()
.setInputCol("features")
.setOutputCol("scale_features")
.setWithStd(true) // 使用标准差缩放
.setWithMean(false) //使用平均值缩放
// Compute summary statistics by fitting the StandardScaler.
val scalerModel = scaler.fit(df2)
// Normalize each feature to have unit standard deviation.
val irisDF = scalerModel.transform(df2)
irisDF.show(10, truncate = false)
// 4.分类算法
// https://spark.apache.org/docs/latest/ml-classification-regression.html
/*
分类算法有:
(1)决策树(DecisionTree)分类算法
(2)朴素贝叶斯(Native Bayes)分类算法-适合构建文本数据特征分类,比如垃圾邮件、情感分析
(3)逻辑回归(Logistics Regression)分类算法
(4)线性支持向量机(Linear SVM)分类算法
(5)神经网络相关分类算法,比如多层感知机算法-》深度学习算法
(6)集成融合算法,随机森林(RF)分类算法、梯度提升树(GBT)算法
Classification
Logistic regression
Binomial logistic regression
Multinomial logistic regression
Decision tree classifier
Random forest classifier
Gradient-boosted tree classifier
Multilayer perceptron classifier
Linear Support Vector Machine
One-vs-Rest classifier (a.k.a. One-vs-All)
Naive Bayes
Factorization machines classifier
*/
// 4.1 创建模型
val lr: LogisticRegression = new LogisticRegression()
// 设置特征值列名称和标签列名称
.setFeaturesCol("scale_features") // x -> 特征
.setLabelCol("label") // y-> 标签
// 每个算法都有自己超参数要设置,合理设置,获取较好的模型
.setMaxIter(20) // 模型训练迭代次数,默认100
.setStandardization(true) //是否数据标准化,默认为true
.setFamily("multinomial") //设置分类属于二分类(标签label只有2个值)还是多分类(大于2个值)
.setRegParam(0) // 正则化参数,默认值为0.0 优化
.setElasticNetParam(0) // 弹性化参数,优化
// 4.2训练模型
val lrModel: LogisticRegressionModel = lr.fit(irisDF)
// 4.3评估模型
println(s"多分类混淆矩阵: ${lrModel.coefficientMatrix}")
val summary: LogisticRegressionTrainingSummary = lrModel.summary
// 准确度: 0.9733333333333334
println(s"准确度: ${summary.accuracy}")
// 精确度是针对每一个分类的
println(s"精确度: ${summary.precisionByLabel.mkString(",")}")
// 关闭环境
// spark.close() 源码中调用stop()
spark.stop()
}
}
Spark逻辑回归分类算法-鸢尾花分类
最新推荐文章于 2024-08-10 21:38:29 发布