Naive Bayes - spark.mllib的实现

前面通过自己编写朴素贝叶斯算法的源码,对贝叶斯算法也有了一定的了解,今天就直接调用spark的NaiveBayes对数据进行处理。
一,训练样本:
第一位表示买不买电脑,第二,三位表示影响买的属性,生成模型
1,1,1
1,1,1
0,1,0
0,0,1
1,0,1
1,1,0
0,0,0
0,0,1
1,0,0

二,测试样本:
这两个表示影响买电脑的属性,放入模型判断其是否会买电脑
1,1
0,1
1,0
1,1
0,0

源代码:

object MyNaiveBayes {
  def main(args: Array[String]) {
    myNaiveBayes("file///F:/1/newtest3.txt","file///F:/1/newtest2.txt","multinomial")
  }
  def myNaiveBayes(testurl:String,forecasturl:String,modelType:String): Unit ={
    val conf =new SparkConf().setAppName("WordCount").setMaster("local");
    val sc = new SparkContext(conf)
    val data = sc.textFile(testurl)
    val parsedData = data.map { line =>
      val parts = line.split(',')
      /*  println(parts(0)+"***")
        parts(1).split(" ").map{x=>
          println("==="+x)
        }*/
      //LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(" ").map(_.toDouble)))
      val row1=parts(0).toDouble
      val row2=parts.drop(1).map(_.toDouble)
    /*  var count=0
      for(i<- 0 to row2.length-1){
        count+=1

      }*/
     // print(count+"*********")
      //var row2=parts.drop(0))
      LabeledPoint(row1, Vectors.dense(row2))

    }

    // Split data into training (60%) and test (40%).
    val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
    val training = splits(0)
    val test = splits(1)

    val model = NaiveBayes.train(training, lambda = 1.0, modelType = modelType)
    //对模型精确度进行分析
    val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
    val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
   // println(accuracy.toFloat+"*************")
    // Save and load model

    /* model.save(sc, "target/tmp/myNaiveBayesModel")
     val sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel")*/
    //预测部分:
    //读取需要预测的文件
    val forecastdata = sc.textFile(forecasturl)
    val forecastresult=forecastdata.map{line=>
      val parts=line.split(",")
      Vectors.dense(parts.map(_.toDouble))
    }

    val forecastresult1= model.predict(forecastresult)//使用模型计算

    val resultarray=new Array[String](forecastresult1.collect().length)
    for(i<- 0 to forecastresult1.collect().length-1 ){
      var t=forecastresult.collect()(i)+"====>"+forecastresult1.collect()(i)
      resultarray(i)=t
     // println(forecastresult.collect()(i)+"====>"+forecastresult1.collect()(i))
    }
   // println("Prediction of (1,1):" + model.predict(Vectors.dense(1.0,1.0)))
    for(i<- 0 to resultarray.length-1 ){
      println(resultarray(i)) //输出预测结果
    }
  }
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值