在官方的 API 文档中可以查到用法。
def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel
input: RDD of (label, array of features) pairs. Every vector should be a frequency vector or a count vector.
lambda: The smoothing parameter.
modelType: The type of NB model to fit from the enumeration NaiveBayesModels; can be "multinomial" or "bernoulli",缺省为 "multinomial"。
调用方法很简单,用Iris数据集进行测试,代码如下。
package classify
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.evaluation.MulticlassMetrics
/** Trains an MLlib Naive Bayes classifier on the Iris data set and prints its test-set precision.
  *
  * Expected program arguments: the Spark master URL followed by the path of the Iris CSV file.
  */
object bayes {

  /** Fields per record: four numeric features followed by the species name. */
  private val FieldCount = 5

  /** Tells whether `line` splits into exactly five comma-separated fields.
    *
    * Only the field count is checked; the field contents are validated by [[parseLine]].
    *
    * @param line one raw line of the input file
    * @return true when the line has exactly five comma-separated fields
    */
  def isValid(line: String): Boolean =
    line.split(",").length == FieldCount

  /** Converts one CSV record into a `LabeledPoint`.
    *
    * The first four fields become the dense feature vector; the fifth field (the species
    * name) is mapped to the label: Iris-setosa -> 1, Iris-versicolor -> 2, Iris-virginica -> 3.
    *
    * @param line a record of the form `f1,f2,f3,f4,species`
    * @return the labeled feature vector for the record
    * @throws NumberFormatException    if one of the first four fields is not a number
    * @throws IllegalArgumentException if the line has fewer than five fields or the species
    *                                  name is not one of the three known values
    */
  def parseLine(line: String): LabeledPoint = {
    val parts = line.split(",")
    require(
      parts.length >= FieldCount,
      s"Expected $FieldCount comma-separated fields but got ${parts.length}: '$line'"
    )
    val features: Vector = Vectors.dense(parts.take(FieldCount - 1).map(_.toDouble))
    // `match` is used as an expression; the final case turns an unknown species into a
    // descriptive error instead of an opaque scala.MatchError inside a Spark task.
    val label = parts(FieldCount - 1).trim match {
      case "Iris-setosa"     => 1.0
      case "Iris-versicolor" => 2.0
      case "Iris-virginica"  => 3.0
      case other =>
        throw new IllegalArgumentException(s"Unknown Iris species '$other' in line: '$line'")
    }
    LabeledPoint(label, features)
  }

  /** Program entry point: loads the data, trains on a 70% split, evaluates on the other 30%.
    *
    * @param args `args(0)` is the Spark master URL, `args(1)` is the input file path
    * @throws IllegalArgumentException if fewer than two arguments are supplied
    */
  def main(args: Array[String]): Unit = {
    require(args.length >= 2, "Usage: bayes <master> <input-path>")

    // NOTE(review): the application name is empty; consider a descriptive name so the job
    // can be identified in the Spark UI.
    val conf = new SparkConf().setMaster(args(0)).setAppName("")
    val sc = new SparkContext(conf)
    try {
      val data = sc.textFile(args(1)).filter(isValid).map(parseLine)

      // Fixed seed keeps the 70/30 train/test split reproducible between runs.
      val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 11L)

      val model = NaiveBayes.train(trainData, lambda = 1.0)
      val predictionAndLabel = testData.map(p => (model.predict(p.features), p.label))
      predictionAndLabel.foreach(println)

      val metrics = new MulticlassMetrics(predictionAndLabel)
      // NOTE(review): newer Spark releases deprecate `precision` in favour of `accuracy`;
      // confirm against the Spark version this project targets before switching.
      val precision = metrics.precision
      println(s"Precision = $precision")
    } finally {
      // Release the context even when loading, training or evaluation fails.
      sc.stop()
    }
  }
}