【Spark ML系列】Spark GeneralizedLinearRegression广义线性回归原理用法示例源码详解

本文详细介绍了Spark中的GeneralizedLinearRegression,包括原理、示例、参数设置和源码分析。广义线性回归是线性回归的扩展,通过链接函数和错误分布适应非正态响应变量,使用IRLS算法估计模型参数。文章涵盖了不同链接函数、方差函数及Tweedie分布,并提供了参数总结和源码解读。
摘要由CSDN通过智能技术生成

Spark GeneralizedLinearRegression广义线性回归原理用法示例源码权威详解点击这里看全文

原理

Spark中的广义线性回归(Generalized Linear Regression)是一种统计模型,用于建立因变量与自变量之间的关系。它是线性回归的扩展,通过引入链接函数和错误分布来拟合非正态分布的响应变量。

在广义线性回归中,假设因变量Y的条件分布属于指数分布族,具体形式为:

P(Y=y | X; θ) = exp((y * θ - b(θ)) / a(φ) + c(y, φ))

其中,X表示自变量,θ表示模型参数,b(θ)是链接函数的关联函数,a(φ)是方差函数的比例因子,c(y, φ)是常数项,φ是方差函数的离散参数。

链接函数将自变量的线性组合映射到响应变量的均值。常见的链接函数有:恒等函数、对数函数、倒数函数、逻辑函数、Probit函数等。

方差函数描述了响应变量的方差与均值之间的关系。常见的方差函数有:高斯分布的平方函数、泊松分布的恒等函数、伽马分布的恒等函数、Tweedie分布的幂函数等。

在广义线性回归中,模型的目标是最小化负对数似然函数,通过迭代加权最小二乘(IRLS)算法来估计模型参数。IRLS算法通过反复迭代更新权重和模型参数,直到收敛。

广义线性回归在Spark中的实现使用了迭代的加权最小二乘算法,具体细节可以参考Spark官方文档。

总结起来,广义线性回归是通过引入链接函数和错误分布来拟合非正态分布的响应变量,通过最小化负对数似然函数来估计模型参数,并使用迭代加权最小二乘算法进行优化。

示例

package org.example.spark

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.{
   GeneralizedLinearRegression, GeneralizedLinearRegressionModel}
import org.apache.spark.sql.SparkSession

object GLRExample {
   
  def main(args: Array[String]): Unit = {
   
    // 创建SparkSession
    val spark = SparkSession.builder()
      .appName("GLRExample")
      .master("local[*]")
      .getOrCreate()

    // 加载数据集
    val data = spark.read.format("libsvm").load("D:\\work\\src\\sparkall\\spark-2.4.0\\spark-2.4.0\\data\\mllib\\sample_linear_regression_data.txt")

    // 创建特征向量列
    val assembler = new VectorAssembler()
      .setInputCols(Array("features"))
      .setOutputCol("featuresVector")
    val assembledData = assembler.transform(data)

    // 拆分训练集和测试集
    val Array(trainingData, testData) = assembledData.randomSplit(Array(0.7, 0.3))

    // 创建广义线性回归对象
    val glr = new GeneralizedLinearRegression()
      .setLabelCol("label")
      .setFeaturesCol("featuresVector")
      .setMaxIter(10)
      .setRegParam(0.3)
      .setFamily("gaussian") // 使用高斯家族

    // 训练广义线性回归模型
    val model = glr.fit(trainingData)

    // 在测试集上进行预测
    val predictions = model.transform(testData)

    // 打印预测结果
    predictions.select("features", "label", "prediction").show()

    // 保存模型
    model.save("path/to/save/model")

    // 加载模型
    val loadedModel = GeneralizedLinearRegressionModel.load("path/to/save/model")

    // 关闭SparkSession
    spark.stop()
  }
}

参数

参数总结

广义线性回归(GeneralizedLinearRegression)包含了以下参数:

  1. family:错误分布的描述,用于模型中使用的 family 名称参数。支持的选项有:“gaussian”、“binomial”、“poisson”、“gamma"和"tweedie”。默认值为"gaussian"。

  2. variancePower:Tweedie 分布方差函数中的幂参数,描述方差与分布均值之间的关系。仅适用于 Tweedie 家族。支持的值:0 和 [1, Inf)。

  3. link:链接函数的名称,提供线性预测器与分布函数均值之间的关系。仅当 family 不是"tweedie"时使用。支持的选项有:“identity”、“log”、“inverse”、“logit”、“probit”、“cloglog"和"sqrt”。

  4. linkPower:幂链接函数中的索引参数。仅适用于 Tweedie 家族。注意,链接幂 0、1、-1 或 0.5 分别对应于对数、恒等、倒数或平方根链接。如果未设置此值,默认为 1-[[variancePower]],与 R 中的"statmod"包匹配。

  5. linkPredictionCol:链接预测(线性预测器)列名参数。默认未设置,表示我们不输出链接预测。

  6. offsetCol:偏置列名参数。如果未设置或为空,我们将所有实例偏置视为 0.0。偏置特征具有常量系数 1.0。

  7. solver:优化的求解算法。支持的选项有:“irls”(迭代加权最小二乘)。默认值:“irls”
    此外,该类还包含了以下私有辅助方法:

  8. hasWeightCol:检查权重列是否设置且非空。

  9. hasOffsetCol:检查偏置列是否设置且非空。

  10. hasLinkPredictionCol:检查是否应输出链接预测。

  11. validateAndTransformSchema:验证并转换数据框结构,根据参数设置进行相应的调整。

中文源码

/**
 * 广义线性回归的参数。
 */
private[regression] trait GeneralizedLinearRegressionBase extends PredictorParams
  with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol
  with HasSolver with Logging {
   

  import GeneralizedLinearRegression._

  /**
   * 错误分布的描述,用于模型中使用的family名称参数。
   * 支持的选项有:"gaussian"、"binomial"、"poisson"、"gamma"和"tweedie"。
   * 默认值为"gaussian"。
   *
   * @group param
   */
  @Since("2.0.0")
  final val family: Param[String] = new Param(this, "family",
    "错误分布的描述,用于模型中使用的family名称参数。" +
      s"支持的选项有:${
     supportedFamilyNames.mkString(", ")}。",
    (value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT)))

  /** @group getParam */
  @Since("2.0.0")
  def getFamily: String = $(family)

  /**
   * Tweedie分布方差函数中的幂参数,描述方差与分布均值之间的关系。
   * 仅适用于Tweedie家族。
   * (参见 <a href="https://en.wikipedia.org/wiki/Tweedie_distribution">
   * Tweedie Distribution (Wikipedia)</a>)
   * 支持的值:0和[1, Inf)。
   * 注意,方差幂0、1或2对应于高斯、泊松或伽马家族。
   *
   * @group param
   */
  @Since("2.2.0")
  final val variancePower: DoubleParam = new DoubleParam(this, "variancePower",
    "Tweedie分布方差函数中的幂参数,描述方差与分布均值之间的关系。" +
    "仅适用于Tweedie家族。支持的值:0和[1, Inf)。",
    (x: Double) => x >= 1.0 || x == 0.0)

  /** @group getParam */
  @Since("2.2.0")
  def getVariancePower: Double = $(variancePower)

  /**
   * link函数的名称,提供线性预测器与分布函数均值之间的关系。
   * 支持的选项有:"identity"、"log"、"inverse"、"logit"、"probit"、"cloglog"和"sqrt"。
   * 仅当family不是"tweedie"时使用。对于"tweedie"家族,必须通过[[linkPower]]指定link函数。
   *
   * @group param
   */
  @Since("2.0.0")
  final val link: Param[String] = new Param(this, "link", "link函数的名称," +
    "提供线性预测器与分布函数均值之间的关系。" +
    s"支持的选项有:${
     supportedLinkNames.mkString(", ")}",
    (value: String) => supportedLinkNames.contains(value.toLowerCase(Locale.ROOT)))

  /** @group getParam */
  @Since("2.0.0")
  def getLink: String = $(link)

  /**
   * 幂链接函数中的索引参数。仅适用于Tweedie家族。
   * 注意,链接幂0、1、-1或0.5分别对应于对数、恒等、倒数或平方根链接。
   * 如果未设置此值,默认为1-[[variancePower]],与R中的"statmod"包匹配。
   *
   * @group param
   */
  @Since("2.2.0")
  final val linkPower: DoubleParam = new DoubleParam(this, "linkPower",
    "幂链接函数中的索引参数。仅适用于Tweedie家族。")

  /** @group getParam */
  @Since("2.2.0")
  def getLinkPower: Double = $(linkPower)

  /**
   * 链接预测(线性预测器)列名参数。
   * 默认未设置,表示我们不输出链接预测。
   *
   * @group param
   */
  @Since("2.0.0")
  final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol",
    "链接预测(线性预测器)列名参数")

  /** @group getParam */
  @Since("2.0.0")
  def getLinkPredictionCol: String = $(linkPredictionCol)

  /**
   * 偏置列名参数。如果未设置或为空,我们将所有实例偏置视为0.0。
   * 偏置特征具有常量系数1.0。
   *
   * @group param
   */
  @Since("2.3.0")
  final val offsetCol: Param[String] = new Param[String](this, "offsetCol", "偏置列名参数。" +
    "如果未设置或为空,我们将所有实例偏置视为0.0")

  /** @group getParam */
  @Since("2.3.0")
  def getOffsetCol: String = $(offsetCol)

  /** 检查权重列是否设置且非空。 */
  private[regression] def hasWeightCol: Boolean =
    isSet(weightCol) && $(weightCol).nonEmpty

  /** 检查偏置列是否设置且非空。 */
  private[regression] def hasOffsetCol: Boolean =
    isSet(offsetCol) && $(offsetCol).nonEmpty

  /** 检查是否应输出链接预测。 */
  private[regression] def hasLinkPredictionCol: Boolean = {
   
    isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
  }

  /**
   * 优化的求解算法。
   * 支持的选项有:"irls"(迭代加权最小二乘)。
   * 默认值:"irls"
   *
   * @group param
   */
  @Since("2.0.0")
  final override val solver: Param[String] = new Param[String](this, "solver",
    "优化的求解算法。支持的选项有:" +
      s"${
     supportedSolvers.mkString(", ")}。 (默认为irls)",
    ParamValidators.inArray[String](supportedSolvers))

  @Since("2.0.0")
  override def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
   
    if ($(family).toLowerCase(Locale.ROOT) == "tweedie") {
   
      if (isSet(link)) {
   
        logWarning("当family为tweedie时,请使用linkPower参数指定链接函数。" +
          "设置link参数将不起作用。")
      }
    } else {
   
      if (isSet(variancePower)) {
   
        logWarning("当family不是tweedie时,设置variancePower参数将不起作用。")
      }
      if (isSet(linkPower)) {
   
        logWarning("当family不是tweedie时,请使用link参数指定链接函数。" +
          "设置linkPower参数将不起作用。")
      }
      if (isSet(link)) {
   
        require(supportedFamilyAndLinkPairs.contains(
          Family.fromParams(this) -> Link.fromParams(this)),
          s"不支持${
     $(link)}链接函数的广义线性回归模型,家族为${
     $(family)}。")
      }
    }

    val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)

    if (hasOffsetCol) {
   
      SchemaUtils.checkNumericType(schema, $(offsetCol))
    }

    if (hasLinkPredictionCol) {
   
      SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
    } else {
   
      newSchema
    }
  }
}

源码分析

源码是GeneralizedLinearRegressionBase trait的定义,它包含了广义线性回归的参数。

下面是对代码片段的分析:

  • GeneralizedLinearRegressionBase 是一个trait,继承了PredictorParams和一些其他特质,并实现了一些参数相关的方法。它表示广义线性回归模型的基本参数。

  • family: Param[String]:错误分布的描述,用于模型中使用的family名称参数。

  • variancePower: DoubleParam:Tweedie分布方差函数中的幂参数,描述方差与分布均值之间的关系。

  • link: Param[String]:link函数的名称,提供线性预测器与分布函数均值之间的关系。

  • linkPower: DoubleParam:幂链接函数中的索引参数。

  • linkPredictionCol: Param[String]:链接预测(线性预测器)列名参数。

  • offsetCol: Param[String]:偏置列名参数。

  • solver: Param[String]:优化的求解算法。

  • hasWeightCol: Boolean:检查权重列是否设置且非空。

  • hasOffsetCol: Boolean:检查偏置列是否设置且非空。

  • hasLinkPredictionCol: Boolean:检查是否应输出链接预测。

  • validateAndTransformSchema:验证并转换模式,确保输入数据的正确性。

该trait定义了广义线性回归模型所需的参数,并提供了参数的访问方法和验证逻辑。

拟合类参数

用法总结:

  1. family:设置广义线性模型的家族。支持的家族有"gaussian"、“binomial”、“poisson”、“gamma"和"tweedie”。每个家族支持的链接函数如下所示。每个家族的第一个链接函数是默认的。
    • “gaussian” : “identity”, “log”, “inverse”
    • “binomial” : “logit”, “probit”, “cloglog”
    • “poisson” : “log”, “identity”, “sqrt”
    • “gamma” : “inverse”, “identity”, “log”
    • “tweedie” : 使用"linkPower"参数指定的幂链接函数。Tweedie 家族的默认链接幂为 1 - variancePower。
  2. variancePower:仅当family为"tweedie"时使用。设置 Tweedie 家族的链接幂。默认值为 0.0,对应于"gaussian"家族。
  3. linkPower:仅当family为"tweedie"时使用。设置 Tweedie 家族的链接幂。
  4. link:设置链接函数的名称。仅当family不是"tweedie"时使用。
  5. fitIntercept:设置是否拟合截距。默认值为 true。
  6. maxIter:设置最大迭代次数(适用于 solver “irls”)。默认值为 25。
  7. tol:设置迭代的收敛容忍度。较小的时间会导致更高的准确性,但也会导致更多的迭代次数。默认值为 1E-6。
  8. regParam:设置 L2 正则化的正则化参数。默认值为 0.0。
  9. weightCol:设置权重列名。如果未设置或为空,我们将所有实例权重视为 1.0。
  10. offsetCol:设置偏移列名。如果未设置或为空,我们将所有实例的偏移量视为 0.0。
  11. predictionCol:设置预测结果的列名。
  12. linkPredictionCol:设置链接预测(线性预测器)列名。
  13. solver:设置用于优化的求解器算法。目前只支持"irls",也是默认求解器。
  14. standardizeFeaturesstandardizeLabel:这两个参数目前未使用,但保留用于未来扩展。

中文源码

/**
 * :: Experimental ::
 *
 * 广义线性模型的拟合类。
 * (参见 <a href="https://en.wikipedia.org/wiki/Generalized_linear_model">
 * 广义线性模型 (Wikipedia)</a>)
 * 通过给出线性预测器(链接函数)的符号描述和误差分布(family)的描述来指定广义线性模型。
 * 支持的family有"gaussian"、"binomial"、"poisson"、"gamma"和"tweedie"。
 * 每个family支持的链接函数如下所示。每个family的第一个链接函数是默认的。
 *  - "gaussian" : "identity", "log", "inverse"
 *  - "binomial" : "logit", "probit", "cloglog"
 *  - "poisson"  : "log", "identity", "sqrt"
 *  - "gamma"    : "inverse", "identity", "log"
 *  - "tweedie"  : 使用"linkPower"参数指定的幂链接函数。Tweedie家族的默认链接幂为1 - variancePower。
 */
@Experimental
@Since("2.0.0")
class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String)
  extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel]
  with GeneralizedLinearRegressionBase with DefaultParamsWritable with Logging {
   

  import GeneralizedLinearRegression._

  @Since("2.0.0")
  def this() = this(Identifiable.randomUID("glm"))

  /**
   * 设置参数[[family]]的值。
   * 默认值为"gaussian"。
   *
   * @group setParam
   */
  @Since("2.0.0")
  def setFamily(value: String): this.type = set(family, value)
  setDefault(family -> Gaussian.name)

  /**
   * 设置参数[[variancePower]]的值。
   * 仅当family为"tweedie"时使用。
   * 默认值为0.0,对应于"gaussian"家族。
   *
   * @group setParam
   */
  @Since("2.2.0")
  def setVariancePower(value: Double): this.type = set(variancePower, value)
  setDefault(variancePower -> 0.0)

  /**
   * 设置参数[[linkPower]]的值。
   * 仅当family为"tweedie"时使用。
   *
   * @group setParam
   */
  @Since("2.2.0")
  def setLinkPower(value: Double): this.type = set(linkPower, value)

  /**
   * 设置参数[[link]]的值。
   * 仅当family不是"tweedie"时使用。
   *
   * @group setParam
   */
  @Since("2.0.0")
  def setLink(value: String): this.type = set(link, value)

  /**
   * 设置是否拟合截距。
   * 默认值为true。
   *
   * @group setParam
   */
  @Since("2.0.0")
  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)

  /**
   * 设置最大迭代次数(适用于solver "irls")。
   * 默认值为25。
   *
   * @group setParam
   */
  @Since("2.0.0")
  def setMaxIter(value: Int): this.type = set(maxIter, value)
  setDefault(maxIter -> 25)

  /**
   * 设置迭代的收敛容忍度。
   * 较小的值会带来更高的准确性,但也会导致更多的迭代次数。
   * 默认值为1E-6。
   *
   * @group setParam
   */
  @Since("2.0.0")
  def setTol(value: Double): this.type = set(tol, value)
  setDefault(tol -> 1E-6)

  /**
   * 设置L2正则化的正则化参数。
   * 正则化项为
   * <blockquote>
   *    $$
   *    0.5 * regParam * L2norm(coefficients)^2
   *    $$
   * </blockquote>
   * 默认值为0.0。
   *
   * @group setParam
   */
  @Since("2.0.0")
  def setRegParam(value: Double): this.type = set(regParam, value)
  setDefault(regParam -> 0.0)

  /**
   * 设置参数[[weightCol]]的值。
   * 如果未设置或为空,我们将所有实例权重视为1.0。
   * 默认未设置,因此所有实例的权重都是1.0。
   * 在二项家族中,权重对应于试验次数,应为整数。
   * 非整数权重在AIC计算中四舍五入为整数。
   *
   * @group setParam
   */
  @Since("2.0.0")
  def setWeightCol(value: String): this.type = set(weightCol, value)

  /**
   * 设置参数[[offsetCol]]的值。
   * 如果未设置或为空,我们将所有实例偏移视为0.0。
   * 默认未设置,因此所有实例的偏移量都是0.0。
   *
   * @group setParam
   */
  @Since("2.3.0")
  def setOffsetCol(value: String): this.type = set(offsetCol, value)

  /**
   * 设置用于优化的求解器算法。
   * 目前只支持"irls",也是默认求解器。
   *
   * @group setParam
   */
  @Since("2.0.0")
  def setSolver(value: String): this.type = set(solver, value)
  setDefault(solver -> IRLS)

  /**
   * 设置链接预测(线性预测器)列名。
   *
   * @group setParam
   */
  @Since("2.0.0")
  def setLinkPredictionCol(value: String): this.type = set(linkPredictionCol, value)

  override protected def train(
      dataset: Dataset[_]): GeneralizedLinearRegressionModel = instrumented {
    instr =>
    val familyAndLink = FamilyAndLink(this)

    val numFeatures = dataset.select(col($(feat
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值