前面通过自己编写朴素贝叶斯算法的源码,对贝叶斯算法已经有了一定的了解,今天就直接调用 Spark MLlib 的 NaiveBayes 对数据进行处理。
一,训练样本:
每行第一位是标签,表示买不买电脑;第二、三位是影响购买的属性。用这些样本训练生成模型。
1,1,1
1,1,1
0,1,0
0,0,1
1,0,1
1,1,0
0,0,0
0,0,1
1,0,0
二,测试样本:
每行的两个值表示影响买电脑的属性,把它们放入模型,判断其是否会买电脑
1,1
0,1
1,0
1,1
0,0
源代码:
/**
 * Small driver that trains a Spark MLlib naive Bayes classifier from a CSV
 * file and prints the predicted label for every row of a second CSV file.
 */
object MyNaiveBayes {

  /** Runs the demo against the two hard-coded local sample files. */
  def main(args: Array[String]): Unit = {
    // A local-file URI needs the colon after the scheme ("file:///..."); without
    // it the string is not recognised as a URI and is treated as a plain path.
    myNaiveBayes("file:///F:/1/newtest3.txt", "file:///F:/1/newtest2.txt", "multinomial")
  }

  /**
   * Trains a naive Bayes model and prints one `features====>prediction` line
   * per row of the forecast file.
   *
   * @param testurl     location of the labelled samples used for training; each
   *                    non-blank line is `label,feature1,feature2,...`
   * @param forecasturl location of the unlabelled samples to classify; each
   *                    non-blank line is `feature1,feature2,...`
   * @param modelType   model type string handed unchanged to `NaiveBayes.train`
   */
  def myNaiveBayes(testurl: String, forecasturl: String, modelType: String): Unit = {
    val conf = new SparkConf().setAppName("WordCount").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      // Blank lines are skipped so that a stray empty line does not fail in toDouble.
      val labelledPoints = sc.textFile(testurl)
        .map(_.trim)
        .filter(_.nonEmpty)
        .map { line =>
          val fields = line.split(',')
          // First column is the label, every remaining column is a feature.
          LabeledPoint(fields(0).toDouble, Vectors.dense(fields.drop(1).map(_.toDouble)))
        }

      // Only the 60% part of the seeded 60/40 split is used to fit the model.
      // The held-out 40% is intentionally not materialised here: the accuracy
      // figure that used to be derived from it was never reported.
      val training = labelledPoints.randomSplit(Array(0.6, 0.4), seed = 11L)(0)
      val model = NaiveBayes.train(training, lambda = 1.0, modelType = modelType)

      // Rows to classify. Cached because the RDD is evaluated twice below
      // (once for the features, once for the predictions).
      val forecastFeatures = sc.textFile(forecasturl)
        .map(_.trim)
        .filter(_.nonEmpty)
        .map(line => Vectors.dense(line.split(",").map(_.toDouble)))
        .cache()
      val forecastLabels = model.predict(forecastFeatures)

      // Collect each RDD exactly once and pair the rows up on the driver,
      // instead of launching new Spark jobs for every printed line.
      val featureRows = forecastFeatures.collect()
      val labelRows = forecastLabels.collect()
      featureRows.zip(labelRows).foreach { case (features, label) =>
        println(s"${features}====>${label}")
      }
    } finally {
      // Release the context even when reading, training or prediction fails.
      sc.stop()
    }
  }
}