Spark ml ReadWriter原理用途源码分析含逻辑回归调用示例分析点击这里看全文
原理用途
在Spark ML中,ReadWriter
类是一个用于模型的读写操作的辅助工具。它提供了一种机制来读取和写入训练好的机器学习模型。
ReadWriter
的设计思想主要基于Java的序列化机制,并结合了Spark的分布式计算框架特性。其背后的原理是将模型的参数以二进制的形式进行序列化,并使用分布式文件系统(如HDFS)或本地文件系统进行存储和读取。
ReadWriter
的主要用途包括:
-
保存模型:通过
ReadWriter
可以将训练好的机器学习模型保存到文件系统中。这样,在需要使用该模型进行预测或加载到其他环境中时,可以直接从文件系统中读取模型。 -
加载模型:使用
ReadWriter
可以从文件系统中读取已保存的模型,并将其加载到内存中。这使得可以在不同的Spark应用程序或Spark任务之间共享和重复使用模型。 -
模型版本控制:
ReadWriter
还支持对模型进行版本控制,可以为每个模型保存多个版本。这有助于追踪和管理模型的演化过程,方便回溯和对比不同版本的模型效果。
总之,ReadWriter
提供了一种方便而灵活的方式来读取和写入训练好的模型,使得模型的存储、加载和管理更加便捷和高效。
示例(逻辑回归)
object LogisticRegression
@Since("1.6.0")
object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
@Since("1.6.0")
override def load(path: String): LogisticRegression = super.load(path)
private[classification] val supportedFamilyNames =
Array("auto", "binomial", "multinomial").map(_.toLowerCase(Locale.ROOT))
}
LogisticRegression
是一个对象,它扩展了**DefaultParamsReadable[LogisticRegression]
**特质。这意味着LogisticRegression
类实现了读取和加载ML模型参数的功能。
-
load(path: String): LogisticRegression
:重写父类的load
方法,调用父类的load
方法来加载保存的LogisticRegression
实例。 -
supportedFamilyNames
:一个私有字段,包含支持的逻辑回归族名称的数组。
class LogisticRegressionModel
@Since("1.4.0")
class LogisticRegressionModel private[spark] (
@Since("1.4.0") override val uid: String,
@Since("2.1.0") val coefficientMatrix: Matrix,
@Since("2.1.0") val interceptVector: Vector,
@Since("1.3.0") override val numClasses: Int,
private val isMultinomial: Boolean)
extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with MLWritable
**MLWritable
**特质定义了两个方法:
write: MLWriter
:返回一个MLWriter
实例,用于将ML实例保存到磁盘。save(path: String): Unit
:将ML实例保存到指定的路径。这是write.save(path)
的快捷方式。
object LogisticRegressionModel
/**
* 逻辑回归模型
*/
object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
/**
* 读取模型
*
* @return [[MLReader]]实例
*/
override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader
/**
* 加载模型
*
* @param path 模型路径
* @return 加载的逻辑回归模型
*/
override def load(path: String): LogisticRegressionModel = super.load(path)
/** [[LogisticRegressionModel]]的[[MLWriter]]实例 */
private[LogisticRegressionModel]
class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
extends MLWriter with Logging {
/**
* 模型数据
*
* @param numClasses 类别数量
* @param numFeatures 特征数量
* @param interceptVector 截距向量
* @param coefficientMatrix 系数矩阵
* @param isMultinomial 是否多项式逻辑回归
*/
private case class Data(
numClasses: Int,
numFeatures: Int,
interceptVector: Vector,
coefficientMatrix: Matrix,
isMultinomial: Boolean)
/**
* 保存模型
*
* @param path 模型保存路径
*/
override protected def saveImpl(path: String): Unit = {
// 保存元数据和参数
DefaultParamsWriter.saveMetadata(instance, path, sc)
// 保存模型数据:numClasses, numFeatures, intercept, coefficients
val data = Data(instance.numClasses, instance.numFeatures, instance.interceptVector,
instance.coefficientMatrix, instance.isMultinomial)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
/**
* 逻辑回归模型读取器
*/
private class LogisticRegressionModelReader extends MLReader[LogisticRegressionModel] {
/** 加载模型时与元数据进行验证的类名 */
private val className = classOf[LogisticRegressionModel].getName
/**
* 加载模型
*
* @param path 模型路径
* @return 加载的逻辑回归模型
*/
override def load(path: String): LogisticRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
val model = if (major < 2 || (major == 2 && minor == 0)) {
// 2.0及之前版本
val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) =
MLUtils.convertVectorColumnsToML(data, "coefficients")
.select("numClasses", "numFeatures", "intercept", "coefficients")
.head()
val coefficientMatrix =
new DenseMatrix(1, coefficients.size, coefficients.toArray, isTransposed = true)
val interceptVector = Vectors.dense(intercept)
new LogisticRegressionModel(metadata.uid, coefficientMatrix,
interceptVector, numClasses, isMultinomial = false)
} else {
// 2.1及之后版本
val Row(numClasses: Int, numFeatures: Int, interceptVector: Vector,
coefficientMatrix: Matrix, isMultinomial: Boolean) = data
.select("numClasses", "numFeatures", "interceptVector", "coefficientMatrix",
"isMultinomial").head()
new LogisticRegressionModel(metadata.uid, coefficientMatrix, interceptVector,
numClasses, isMultinomial)
}
metadata.getAndSetParams(model)
model
}
}
}
逻辑回归模型的源码主要包括以下几个部分:
-
LogisticRegressionModel
对象:定义了逻辑回归模型的读取和加载方法。实现了MLReadable
接口,重写了read
和load
方法。内部还定义了一个LogisticRegressionModelWriter
类和一个LogisticRegressionModelReader
类。 -
LogisticRegressionModelWriter
类:继承自MLWriter
,负责保存逻辑回归模型。其中,私有内部类Data
用于保存模型数据,包括类别数量、特征数量、截距向量、系数矩阵以及是否多项式逻辑回归。saveImpl
方法将元数据和参数保存到路径下,并将模型数据保存为Parquet格式文件。 -
LogisticRegressionModelReader
类:继承自MLReader
,负责加载逻辑回归模型。其中,load
方法根据不同的Spark版本从保存路径中加载元数据和模型数据,并构建对应的逻辑回归模型对象。
总结起来,逻辑回归模型的源码提供了模型的保存和加载功能,可以方便地将训练好的模型保存到文件中,并在需要的时候重新加载使用。
Spark ml ReadWrite源码
BaseReadWrite
/**
* 用于`MLWriter`和`MLReader`的特质。
*/
private[util] sealed trait BaseReadWrite {
private var optionSparkSession: Option[SparkSession] = None
/**
* 设置用于保存/加载的Spark Session。
*/
@Since("2.0.0")
def session(sparkSession: SparkSession): this.type = {
optionSparkSession = Option(sparkSession)
this
}
/**
* 返回用户指定的Spark Session或默认值。
*/
protected final def sparkSession: SparkSession = {
if (optionSparkSession.isEmpty) {
optionSparkSession = Some(SparkSession.builder().getOrCreate())
}
optionSparkSession.get
}
/**
* 返回用户指定的SQL Context或默认值。
*/
protected final def sqlContext: SQLContext = sparkSession.sqlContext
/** 返回底层的`SparkContext`。 */
protected final def sc: SparkContext = sparkSession.sparkContext
}
这段源码提供了一个用于保存和加载机器学习模型的基础功能。具体来说,它定义了一个BaseReadWrite
特质,该特质包含以下方法:
session(sparkSession: SparkSession): this.type
:设置用于保存/加载的Spark Session。sparkSession: SparkSession
:返回用户指定的Spark Session或默认值。sqlContext: SQLContext
:返回用户指定的SQL Context或默认值。sc: SparkContext
:返回底层的SparkContext。
通过实现这个特质,可以方便地为自定义的机器学习模型编写保存和加载的方法,并使用Spark的API进行操作。例如,可以创建一个自定义的MLWriter
类,继承自BaseReadWrite
并实现save
方法,用于将模型保存到磁盘上。然后,可以创建一个自定义的MLReader
类,同样继承自BaseReadWrite
并实现load
方法,用于从磁盘上加载模型。这样,就可以使用统一的接口来保存和加载机器学习模型,无论是本地文件系统还是分布式存储系统都可以适用。
此外,该特质还提供了默认的Spark Session和SQL Context,如果用户没有显式指定,将使用默认的Spark配置。这样可以确保在没有额外配置的情况下,仍然能够正常进行保存和加载操作。
MLWriterFormat MLFormatRegister
/**
* 提供ML模型导出功能的抽象类。
*
* 每次调用保存方法时,都会实例化该类的一个新实例。
*
* 必须有一个有效的零参数构造函数,将被调用来实例化。
*
* @since 2.4.0
*/
@Unstable
@Since("2.4.0")
trait MLWriterFormat {
/**
* 将提供的流水线阶段写入。
*
* @param path 要写入结果的路径。
* @param session 与写请求关联的SparkSession。
* @param optionMap 用户提供的选项,以字符串形式存储。
* @param stage 要保存的流水线阶段。
*/
@Since("2.4.0")
def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String],
stage: PipelineStage): Unit
}
/**
* ML导出格式应实现此特质,以便用户可以指定导出器的简短名称而不是完全限定的类名。
*
* 每次调用保存方法时,都会实例化该类的一个新实例。
*
* @since 2.4.0
*/
@Unstable
@Since("2.4.0")
trait MLFormatRegister extends MLWriterFormat {
/**
* 表示该格式提供程序使用的格式的字符串。这个字符串与stageName一起被子类覆盖,为写入器提供了一个漂亮的别名。例如:
*
* {
{
{
* override def format(): String =
* "pmml"
* }}}
* 表示这个格式能够保存pmml模型。
*
* 必须有一个有效的零参数构造函数,将被调用来实例化。
*
* 格式发现是使用ServiceLoader完成的,请确保在META-INF/services中列出您的格式。
* @since 2.4.0
*/
@Since("2.4.0")
def format(): String
/**
* 表示该写入器支持的阶段类型的字符串。这个字符串与format一起被子类覆盖,为写入器提供了一个漂亮的别名。例如:
*
* {
{
{
* override def stageName(): String =
* "org.apache.spark.ml.regression.LinearRegressionModel"
* }}}
* 表示这个格式能够保存Spark自带的PMML模型。
*
* 格式发现是使用ServiceLoader完成的,请确保在META-INF/services中列出您的格式。
* @since 2.4.0
*/
@Since("2.4.0")
def stageName(): String
private[ml] def shortName(): String = s"${
format()}+${
stageName()}"
}
这段源码提供了用于导出ML模型的功能。具体来说,它定义了两个特质:MLWriterFormat
和MLFormatRegister
。
MLWriterFormat
特质是一个抽象类,需要实现一个名为write
的方法,该方法用于将给定的流水线阶段保存到指定路径。它还包含一个与写请求关联的SparkSession,以及用户提供的选项。
MLFormatRegister
特质是一个扩展自MLWriterFormat
的特质,它定义了两个额外的方法:format()
和stageName()
。这些方法分别返回导出格式的字符串表示和支持的阶段类型的字符串表示。子类需要覆盖这些方法来提供导出器的别名。此外,MLFormatRegister
特质还定义了一个私有方法shortName()
,用于生成格式的简短名称,格式为format()+stageName()
。
通过实现这两个特质,可以创建自定义的ML导出器,并使用统一的接口将ML模型保存到指定路径。可以根据需要定义不同的导出格式,并在format()
和stageName()
方法中指定格式和阶段类型的字符串表示。然后,可以使用MLWriter
类的save
方法来保存模型,指定导出格式的别名作为参数。这样,就可以轻松地将ML模型以不同的格式导出,而无需关心底层实现细节。
此外,该源码还提供了对SparkSession、SQLContext和SparkContext的访问方法,以便在保存过程中使用它们。这些方法可以确保在没有显式指定的情况下,仍然能够使用默认的Spark配置进行保存操作。
MLWriter
/**
* 用于以Spark内部格式保存ML实例的实用程序类的抽象类。
*/
@Since("1.6.0")
abstract class MLWriter extends BaseReadWrite with Logging {
protected var shouldOverwrite: Boolean = false
/**
* 将ML实例保存到输入路径。
*/
@Since("1.6.0")
@throws[IOException]("如果输入路径已经存在但未启用覆盖功能。")
def save(path: String): Unit = {
new FileSystemOverwrite().handleOverwrite(path, shouldOverwrite, sparkSession)
saveImpl(path)
}
/**
* `save()` 处理覆盖操作,然后调用这个方法。子类应该重写这个方法来实现实例的实际保存。
*/
@Since("1.6.0")
protected def saveImpl(path: String): Unit
/**
* 如果输出路径已经存在,则覆盖它。
*/
@Since("1.6.0")
def overwrite(): this.type = {
shouldOverwrite = true
this
}
/**
* 用于存储此写入器的额外选项的映射。
*/
protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]()
/**
* 向底层MLWriter添加一个选项。有关可能的选项,请参阅特定模型的写入器的文档。选项名称(键)不区分大小写。
*/
@Since("2.3.0")
def option(key: String, value: String): this.type = {
require(key != null && !key.isEmpty)
optionMap.put(key.toLowerCase(Locale.ROOT), value)
this
}
// 为了与Java兼容性而覆盖
@Since("1.6.0")
override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
}
MLWriter
是一个抽象类,用于实现将ML实例以Spark内部格式保存到磁盘的功能。它继承自BaseReadWrite
特质,并添加了一些额外的方法和属性。
save(path: String): Unit
:将ML实例保存到指定路径。如果路径已经存在但未启用覆盖功能,则会抛出IOException
异常。saveImpl(path: String): Unit
:子类需要重写该方法来实现实际的保存逻辑。overwrite(): this.type
:启用覆盖功能,即如果输出路径已经存在,则覆盖它。option(key: String, value: String): this.type
:向MLWriter添加一个选项。可以使用这个方法来设置特定模型写入器的选项。选项名称不区分大小写。optionMap: mutable.Map[String, String]
:用于存储额外选项的映射。
此外,MLWriter
还覆盖了session(sparkSession: SparkSession): this.type
方法,以确保在保存过程中使用正确的SparkSession。
通过继承MLWriter
并实现saveImpl
方法,可以创建自定义的ML保存器,并使用统一的接口将ML实例保存到磁盘。可以使用save
方法来触发保存操作,使用overwrite
方法启用覆盖功能,使用option
方法设置特定模型写入器的选项。
注意:该类是一个抽象类,需要根据具体的ML模型进行实现才能正常工作。
GeneralMLWriter
/**
* 根据请求的格式委托的ML Writer。
*/
@Unstable
@Since("2.4.0")
class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
private var source: String =