一、简介
SVM(支持向量机)把分类问题转化为寻找分类超平面的问题,并通过最大化分类边界上的样本点(支持向量)到分类超平面的距离(即"间隔")来实现分类。
二、示例
1、数据
PS:以下是数据文件的一部分,文件名为sample_svm_data.txt,每行格式为"标签 特征1 特征2 …",各项以空格分隔。下载地址:机器学习文件数据包(Spark 源码包中也自带该文件:data/mllib/sample_svm_data.txt)。
2、代码
package com.svm
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}
/**
 * Demo: trains a linear SVM with Spark MLlib's `SVMWithSGD` on a text file and
 * evaluates it on a held-out split using the area under the ROC curve.
 *
 * Each input line is expected to be `label f1 f2 ... fn` (whitespace-separated
 * numbers). Blank lines are skipped.
 */
object TestSVMDemo {

  /** Data file read when no path is passed as the first command-line argument. */
  private val DefaultDataPath = "src/main/resources/svm/sample_svm_data.txt"

  /** Number of SGD iterations used for training. */
  private val NumIterations = 120

  /** Seed for the 80/20 train/test split, so runs are reproducible. */
  private val SplitSeed = 5L

  /**
   * Entry point.
   *
   * @param args optional; `args(0)` overrides the data file path
   */
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName))
    // Always stop the context so the local Spark runtime shuts down cleanly,
    // even when reading, parsing or training throws.
    try run(sc, args.headOption.getOrElse(DefaultDataPath))
    finally sc.stop()
  }

  /** Parses one `label f1 f2 ...` line; tolerates leading/trailing and repeated whitespace. */
  private def parseLine(line: String): LabeledPoint = {
    val tokens = line.trim.split("\\s+")
    LabeledPoint(tokens.head.toDouble, Vectors.dense(tokens.tail.map(_.toDouble)))
  }

  /** Loads the data, trains the model, prints sample predictions and the ROC AUC. */
  private def run(sc: SparkContext, path: String): Unit = {
    // Read and parse the data, ignoring blank lines (e.g. a trailing newline).
    val data = sc.textFile(path).filter(_.trim.nonEmpty).map(parseLine)

    // Split into training (80%) and test (20%) sets.
    val splits = data.randomSplit(Array(0.8, 0.2), seed = SplitSeed)
    val training = splits(0)
    val test = splits(1)
    // SGD makes a pass over the training set on every iteration; cache it so the
    // file is not re-read and re-parsed each time.
    training.cache()

    // Build and train the model.
    val model = SVMWithSGD.train(training, NumIterations)

    // By default predict() returns a thresholded 0/1 class, which collapses the ROC
    // curve to a single point. Remember the threshold for display, then clear it so
    // predict() returns the raw margin that BinaryClassificationMetrics needs.
    val threshold = model.getThreshold.getOrElse(0.0)
    model.clearThreshold()

    val scoreAndLabels = test.map(point => (model.predict(point.features), point.label))

    // Show a few examples: predicted class, true label, raw score.
    scoreAndLabels.take(10).foreach { case (score, label) =>
      val predicted = if (score > threshold) 1.0 else 0.0
      println(s"$predicted\t$label\t$score")
    }

    // Evaluation metric: area under the ROC curve.
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    val auROC = metrics.areaUnderROC()
    println(s"ROC曲线下面积(AUC): $auROC")
  }
}