spark 朴素贝叶斯(naive bayes)模型save与load优化

Spark MLLIB中Naive Bayes(朴素贝叶斯)分类模型的保存与加载速度在实际应用场景中,比较慢,先对朴素贝叶斯模型save与load进行优化。优化后,save与load速度提高很多倍(优化前需要4-5分钟,而且比较容易出现问题而失败,优化后只需要几秒钟),模型文件占用空间也减小了50%。

先简单介绍下Naive Bayes模型机制

数据结构:

    case class Data(
                     labels: Array[Double],
                     pi: Array[Double],
                     theta: Array[Array[Double]],
                     modelType: String)

参数说明:

labels: 类目标签-数组
pi:各类目出现概率-数组
theta:每个特征值在每个类目下出现的概率-矩阵
modelType:模型类型-字符串

数据存储:

通过上述结构可以看出,模型的数据都保存成了一行数据,一共4个字段,每个字段是所有类目的相关数据。这样就会遇到些问题,如果类目数量特别多并且特征数量也特别多的话,这一样数据就特别的大了,读写性能会比较低。

优化:

思路:因为加载速度慢的原因是一行数据量大,导致读写慢,所以考虑增加并行度,将一行数据拆分成多行数据,然后读写的时候就能并发的读写,进而提高速度。

代码实现:

代码位置:org.apache.spark.mllib.classification.NaiveBayes.scala   --spark1.6.0

save代码

原代码

 @Since("1.3.0")
  override def save(sc: SparkContext, path: String): Unit = {
    val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
    NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
  }

    def save(sc: SparkContext, path: String, data: Data): Unit = {
      val sqlContext = SQLContext.getOrCreate(sc)
      import sqlContext.implicits._

      // Create JSON metadata.
      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
          ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

      // Create Parquet data.
      -- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
      dataRDD.write.parquet(dataPath(path))
    }
修改上述“- -”部分,将数据拆分成多行
修改代码
def save(sc: SparkContext, path: String, data: Data): Unit = {
      val sqlContext = SQLContext.getOrCreate(sc)
      import sqlContext.implicits._

      // Create JSON metadata.
      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
          ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

     ++ val labels = data.labels   
     ++ val pi = data.pi
     ++ val theta = data.theta
     ++ val modelType = data.modelType

     ++ var i = 0
     ++ var dateArray: ArrayBuffer[WKData]= new ArrayBuffer()
     ++ labels.foreach(label => {
     ++  dateArray += WKData(labels(i), pi(i), theta(i), modelType)
     ++   i += 1
     ++ })
      // Create Parquet data.
     ++ val dataRDD: DataFrame = sc.parallelize(dateArray, 200).toDF()
      dataRDD.write.parquet(dataPath(path))
    }
load代码

原代码

    @Since("1.3.0")
    def load(sc: SparkContext, path: String): NaiveBayesModel = {
      val sqlContext = SQLContext.getOrCreate(sc)
      // Load Parquet data.
      val dataRDD = sqlContext.read.parquet(dataPath(path))
      // Check schema explicitly since erasure makes it hard to use match-case for checking.
      checkSchema[Data](dataRDD.schema)
      val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
      assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
      -- val data = dataArray(0)
      -- val labels = data.getAs[Seq[Double]](0).toArray
      -- val pi = data.getAs[Seq[Double]](1).toArray
      -- val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
      -- val modelType = data.getString(3)
      new NaiveBayesModel(labels, pi, theta, modelType)
    }
修改- -“”部分,读多行数据

修改代码

    def load(sc: SparkContext, path: String): WkNaiveBayesModel = {
      val sqlContext = SQLContext.getOrCreate(sc)
      // Load Parquet data.
      val dataRDD = sqlContext.read.parquet(dataPath(path))
      // Check schema explicitly since erasure makes it hard to use match-case for checking.
      checkSchema[Data](dataRDD.schema)
      ++ val dataDF = dataRDD.select("labels", "pi", "theta", "modelType")
      ++ dataDF.persist()
      ++ val labels = dataDF.map(_.getAs[Double](0)).collect()
      ++ val pi = dataDF.map(_.getAs[Double](1)).collect()
      ++ val theta = dataDF.map(_.getAs[Seq[Double]](2).toArray).collect()
      ++ val modelType = dataDF.first().getString(3)
      new WkNaiveBayesModel(labels, pi, theta, modelType)
    }
只需修改save(sc: SparkContext, path: String)与load(sc: SparkContext, path: String)即可。逻辑很简单。



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

sunyang098

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值