算法小白的第一次尝试---SVM实现

23 篇文章 0 订阅
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.HingeGradient
import org.apache.spark.mllib.optimization.SquaredL2Updater
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SparkSession
/**
  * Binary linear-SVM demo on the Iris data set using Spark MLlib's `SVMWithSGD`.
  *
  * Only the `Iris-setosa` (label 0) and `Iris-versicolor` (label 1) rows are used. The data is
  * split into train/test sets, a model is trained, saved, reloaded, and then evaluated with
  * AUROC (from raw margins) and accuracy (from 0/1 predictions).
  *
  * Usage: `SVM [inputPath] [modelPath]` — both arguments are optional and default to the
  * original hard-coded paths.
  *
  * @author XiaoTangBao
  * @date 2019/3/6 21:20
  * @version 1.0
  */
object SVM {
  def main(args: Array[String]): Unit = {
    // Silence Spark's INFO/WARN logging.
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)

    // Optional positional CLI overrides; with no arguments the original paths are used.
    def arg(i: Int, default: String): String = if (args.length > i) args(i) else default
    val inputPath = arg(0, "G:\\mldata\\iris.txt")
    val modelPath = arg(1, "C:\\users\\Java_Man_China\\desktop\\model1")

    val sparkSession = SparkSession.builder().master("local[4]").appName("SVM").getOrCreate()
    try {
      val sc = sparkSession.sparkContext

      // Each line is parsed as `f1|f2|f3|f4|species`.
      // Spark's SVM requires labels 0 and 1 (it maps them to -1/+1 internally).
      val classLabels = Map("Iris-setosa" -> 0.0, "Iris-versicolor" -> 1.0)
      val points = sc.textFile(inputPath)
        .map(_.split('|').map(_.trim))
        // Skips blank/short lines (e.g. a trailing empty line) and species not in classLabels,
        // instead of failing with an out-of-bounds error.
        .filter(fields => fields.length >= 5 && classLabels.contains(fields(4)))
        .map { fields =>
          LabeledPoint(
            classLabels(fields(4)),
            Vectors.dense(fields(0).toDouble, fields(1).toDouble, fields(2).toDouble, fields(3).toDouble))
        }

      // Simple hold-out validation; a fixed seed makes the split (and the metrics) reproducible.
      val splits = points.randomSplit(Array(0.8, 0.2), seed = 11L)
      val trainData = splits(0).cache()
      val testData = splits(1)

      // Model parameters. HingeGradient + SquaredL2Updater are SVMWithSGD's defaults and are set
      // explicitly only for clarity. No intercept is fitted unless setIntercept(true) is called.
      val svm = new SVMWithSGD()
      svm.optimizer
        .setNumIterations(1000)
        .setRegParam(0.1)
        .setStepSize(0.3)
        .setMiniBatchFraction(0.5)
        .setGradient(new HingeGradient())
        .setUpdater(new SquaredL2Updater)
      val svmModel = svm.run(trainData)
      trainData.unpersist()

      // NOTE: saving fails if modelPath already exists; remove it (or pass a new path) before re-running.
      svmModel.save(sc, modelPath)
      val sameModel = SVMModel.load(sc, modelPath)

      // Hard 0/1 predictions (model threshold applied) -> used for accuracy.
      val predictionAndLabel = testData.map(p => (sameModel.predict(p.features), p.label))

      // AUROC needs raw margins: with a threshold set, predict() returns only 0/1 and the ROC curve
      // collapses to a single point. A separate copy is cleared so that sameModel is not mutated
      // before the (lazy) predictionAndLabel job actually runs.
      val rawScorer = new SVMModel(sameModel.weights, sameModel.intercept).clearThreshold()
      val scoreAndLabel = testData.map(p => (rawScorer.predict(p.features), p.label))

      // Collect so the (prediction, label) pairs print on the driver; the test split here is tiny.
      predictionAndLabel.collect().foreach(println)

      // Binary-classification evaluator (area under the ROC curve).
      val auroc = new BinaryClassificationMetrics(scoreAndLabel).areaUnderROC()
      println(auroc)

      // Multiclass evaluator, used here for overall accuracy.
      val accuracy = new MulticlassMetrics(predictionAndLabel).accuracy
      println(accuracy)
    } finally {
      sparkSession.stop()
    }
  }
}
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额 3.43 前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值