Spark逻辑回归分类算法-鸢尾花分类

package com.dream.ml.features

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, LogisticRegressionTrainingSummary}
import org.apache.spark.ml.feature.{StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}

/**
 * @title: IrisFeaturesDemo
 * @projectName SparkStudy
 * @description: TODO
 * @author MXH
 * @date 2023/9/3 10:44
 */
object IrisFeaturesDemo {
    def main(args: Array[String]): Unit = {
        // 1.创建SparkSQL的运行环境
        val spark: SparkSession = SparkSession.builder()
            .appName(this.getClass.getSimpleName.stripSuffix("$"))
            .master("local[4]")
            .config("spark.sql.shuffle.partitions", 4)
            .getOrCreate()

        // 导入隐式转换
        import spark.implicits._

        // 2.加载鸢尾花数据集iris.data
        val irisSchema: StructType = new StructType()
            .add("sepal_length", dataType = DoubleType, nullable = true)
            .add("sepal_width", dataType = DoubleType, nullable = true)
            .add("petal_length", dataType = DoubleType, nullable = true)
            .add("petal_width", dataType = DoubleType, nullable = true)
            .add("class", dataType = StringType, nullable = true)

        val rawIrisDF: DataFrame = spark.read
            // 查看csv源码来设置options
            .option("sep", ",") // 分隔符
            .option("header", "false") // 默认为false
            .option("inferSchema", "false") // 默认为false
            // 当csv文件首行不是列名称时,需要自定义Schema
            .schema(irisSchema)
            .csv("datas/iris/iris.data")

        /*
        root
         |-- sepal_length: double (nullable = true)
         |-- sepal_width: double (nullable = true)
         |-- petal_length: double (nullable = true)
         |-- petal_width: double (nullable = true)
         |-- class: string (nullable = true)
                 */
        rawIrisDF.printSchema()
        rawIrisDF.show(10,truncate = false)

        // 3.数据转换
        // 3.1 转换1:将萼片长度、宽度及花瓣长度、宽度封装到一个特征向量中
        // https://spark.apache.org/docs/latest/ml-features.html#vectorassembler
        val assembler = new VectorAssembler()
            // 把需要组合的列名称枚举出来进行组合
            // .setInputCols(Array("hour", "mobile", "userFeatures"))
            // 本例中除最后一列不要外,需要其他列
            .setInputCols(rawIrisDF.columns.dropRight(1))
            .setOutputCol("features") // 添加一列,类型为向量

        val df1 = assembler.transform(rawIrisDF)

        /*
        root
         |-- sepal_length: double (nullable = true)
         |-- sepal_width: double (nullable = true)
         |-- petal_length: double (nullable = true)
         |-- petal_width: double (nullable = true)
         |-- class: string (nullable = true)
         |-- features: vector (nullable = true)
         */
        df1.printSchema()
        df1.show(10, truncate = false)

        // 3.2 转换2: 转换类别字符串数据为数值数据
        // https://spark.apache.org/docs/latest/ml-features.html#stringindexer
        val indexer = new StringIndexer()
            .setInputCol("class") // 需要索引化的列名
            .setOutputCol("label") // 数据索引化后列名
            .fit(df1)
        val df2 = indexer.transform(df1)

        /*
        root
        |-- sepal_length: double (nullable = true)
        |-- sepal_width: double (nullable = true)
        |-- petal_length: double (nullable = true)
        |-- petal_width: double (nullable = true)
        |-- class: string (nullable = true)
        |-- features: vector (nullable = true)  // 特征 x
        |-- label: double (nullable = false) // 标签 y
        算法: y = kx + b
        */
        df2.printSchema()
        df2.show(10,truncate = false)

        // 3.3 数据标准化
        // 在实际开发中,特征数据features经常需要进行各个转换操作,比如归一化、标准化和正则化等
        // 为什么要进行归一化、标准化或正则化等数据预处理?原因在于不同维度特征值,值的范围跨度不一样,导致模型异常
        // 比如影响房价的因素有地段、面积、楼层、新旧等特征数据
        // 数据标准化 https://spark.apache.org/docs/latest/ml-features.html#standardscaler
        val scaler = new StandardScaler()
            .setInputCol("features")
            .setOutputCol("scale_features")
            .setWithStd(true)  // 使用标准差缩放
            .setWithMean(false) //使用平均值缩放

        // Compute summary statistics by fitting the StandardScaler.
        val scalerModel = scaler.fit(df2)

        // Normalize each feature to have unit standard deviation.
        val irisDF = scalerModel.transform(df2)

        irisDF.show(10, truncate = false)

        // 4.分类算法
        // https://spark.apache.org/docs/latest/ml-classification-regression.html
        /*
            分类算法有:
            (1)决策树(DecisionTree)分类算法
            (2)朴素贝叶斯(Native Bayes)分类算法-适合构建文本数据特征分类,比如垃圾邮件、情感分析
            (3)逻辑回归(Logistics Regression)分类算法
            (4)线性支持向量机(Linear SVM)分类算法
            (5)神经网络相关分类算法,比如多层感知机算法-》深度学习算法
            (6)集成融合算法,随机森林(RF)分类算法、梯度提升树(GBT)算法
            Classification
                Logistic regression
                Binomial logistic regression
                Multinomial logistic regression
                Decision tree classifier
                Random forest classifier
                Gradient-boosted tree classifier
                Multilayer perceptron classifier
                Linear Support Vector Machine
                One-vs-Rest classifier (a.k.a. One-vs-All)
                Naive Bayes
                Factorization machines classifier
         */

        // 4.1 创建模型
        val lr: LogisticRegression = new LogisticRegression()
                // 设置特征值列名称和标签列名称
                .setFeaturesCol("scale_features")  // x -> 特征
                .setLabelCol("label") // y-> 标签
                // 每个算法都有自己超参数要设置,合理设置,获取较好的模型
                .setMaxIter(20)  // 模型训练迭代次数,默认100
                .setStandardization(true) //是否数据标准化,默认为true
                .setFamily("multinomial") //设置分类属于二分类(标签label只有2个值)还是多分类(大于2个值)
                .setRegParam(0) // 正则化参数,默认值为0.0 优化
                .setElasticNetParam(0) // 弹性化参数,优化

        // 4.2训练模型
        val lrModel: LogisticRegressionModel = lr.fit(irisDF)

        // 4.3评估模型
        println(s"多分类混淆矩阵: ${lrModel.coefficientMatrix}")

        val summary: LogisticRegressionTrainingSummary = lrModel.summary
        // 准确度: 0.9733333333333334
        println(s"准确度: ${summary.accuracy}")

        // 精确度是针对每一个分类的
        println(s"精确度: ${summary.precisionByLabel.mkString(",")}")



        // 关闭环境
        // spark.close() 源码中调用stop()
        spark.stop()
    }
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值