spark svm

data input

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
/**
  * Created by MingDong on 2016/11/16.
  */
object data {
  /**
   * Demonstrates the two common MLlib SVM input formats:
   * loads a comma+space delimited file manually and a LibSVM-format file
   * via MLUtils, then prints the LibSVM one.
   */
  def main(args: Array[String]) {

    val conf = new SparkConf().setMaster("local").setAppName("get_data")
    val sc = new SparkContext(conf)

    try {
      // $example on$
      // Load and parse the data
      // Line format: "label,feature1 feature2 feature3 ..."
      val example_data1 = sc.textFile("D:/data/mllib/sample_svm_data.txt").map {
        line =>
          val parts = line.split(",")
          LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
      }.cache()
      // Line format: "label featureId:value featureId:value ..."
      val example_data2 = MLUtils.loadLibSVMFile(sc, "D:/data/mllib/sample_libsvm_data.txt").cache()
      example_data2.foreach(println)
    } finally {
      // Stop the context so local runs release resources even on failure.
      sc.stop()
    }
  }
}

svm train

import org.apache.spark.ml.feature._
import org.apache.spark.{ml, mllib}
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._

/**
  * Created by MingDong on 2016/11/16.
  */
object svm_test2 {
  private val spark = SparkSession.builder().config("spark.sql.warehouse.dir", "file:///D:/spark-warehouse")
    .master("local")
    .appName("TFIDFExample ")
    .getOrCreate()
  private val sc = spark.sparkContext

  import spark.implicits._

  // NOTE(review): "lable" is a misspelling of "label", but the name is part of
  // this object's schema (referenced as a column name below), so it is kept.
  case class RawDataRecord(lable: String, message: Array[String])

  /**
   * Trains a linear SVM (SGD, 150 iterations) on TF-IDF features built from a
   * text file of "label,word word word ..." lines, then prints per-row
   * predictions and the overall accuracy on a held-out 20% split.
   */
  def main(args: Array[String]): Unit = {

    // Map the raw data into a DataFrame: column "lable" is the class id,
    // column "message" is the whitespace-tokenized text.
    val train_data = sc.textFile("file:///D:/data/mllib/s.txt").map {
      x =>
        val data = x.split(",")
        RawDataRecord(data(0), data(1).split(" "))
    }
    val msgDF = spark.createDataFrame(train_data).toDF()
    msgDF.show()
    // Alternative featurization (disabled): val result = get_word2Vec(msgDF)
    val result = get_tfidf(msgDF)
    result.show()
    val Array(train, test) = result.randomSplit(Array(0.8, 0.2), seed = 12L)

    val trainData = get_dataFrame_to_rdd(train)
    val model3 = SVMWithSGD.train(trainData, 150)

    val testDataRdd = get_dataFrame_to_rdd(test)
    val pres = testDataRdd.map { point =>
      val score = model3.predict(point.features)
      (score, point.label)
    }
    val p = pres.collect()
    println("prediction" + "\t" + "label")
    p.foreach { case (score, label) => println(score + "\t" + label) }
    // Compute accuracy from the already-collected predictions: the original
    // divided by test.count(), which recomputes the test DataFrame, and it
    // yielded NaN when the test split was empty.
    val accuracy = if (p.isEmpty) 0.0 else 1.0 * p.count { case (s, l) => s == l } / p.length
    println("正确率:" + accuracy)

    // Release local Spark resources.
    spark.stop()
  }

  /**
   * Adds a Word2Vec embedding column "feature" computed from the "message"
   * tokens. Currently unused by main (TF-IDF is used instead).
   */
  def get_word2Vec(dataDF: DataFrame): DataFrame = {
    val word2Vec = new Word2Vec()
      .setInputCol("message")
      .setOutputCol("feature")
    val model = word2Vec.fit(dataDF)
    val result = model.transform(dataDF)
    result
  }

  /**
   * Adds an "indexedLabel" column (StringIndexer over "lable") and a TF-IDF
   * "feature" column: hashing TF over "message" into 200 buckets, rescaled
   * by IDF.
   */
  def get_tfidf(dataDF: DataFrame): DataFrame = {

    val labelIndexer = new StringIndexer()
      .setInputCol("lable")
      .setOutputCol("indexedLabel")
      .fit(dataDF)
    val indexed = labelIndexer.transform(dataDF)
    val hashingTF = new HashingTF()
      .setInputCol("message")
      .setOutputCol("rawFeatures")
      .setNumFeatures(200)
    val featurizedData = hashingTF.transform(indexed)
    val idf = new IDF()
      .setInputCol("rawFeatures")
      .setOutputCol("feature")
    val idfModel = idf.fit(featurizedData)
    val tfidf_rescaledData = idfModel.transform(featurizedData)
    // Feature selection (disabled):
//    val selector = new ChiSqSelector()
//      .setNumTopFeatures(20)
//      .setLabelCol("indexedLabel")
//      .setFeaturesCol("feature")
//      .setOutputCol("features")
//    val transformer = selector.fit(tfidf_rescaledData)
//    val selector_feature = transformer.transform(tfidf_rescaledData)
//    selector_feature.show()
    tfidf_rescaledData
  }

  /**
   * Converts a DataFrame with "lable" (string) and "feature" (ml Vector)
   * columns into the mllib LabeledPoint RDD that SVMWithSGD expects.
   * NOTE(review): sparse TF-IDF vectors are densified here via toArray.
   */
  def get_dataFrame_to_rdd(dataDF: DataFrame): RDD[LabeledPoint] = {
    // First to the new-API (ml) LabeledPoint, then bridge to the old mllib one.
    val mlRdd = dataDF.select($"lable", $"feature").map {
      case Row(label: String, features: ml.linalg.Vector) =>
        ml.feature.LabeledPoint(label.toDouble, ml.linalg.Vectors.dense(features.toArray))
    }.rdd
    val mllibRdd = mlRdd.map { point =>
      mllib.regression.LabeledPoint(point.label, mllib.linalg.Vectors.dense(point.features.toArray))
    }
    mllibRdd
  }
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值