/**
 * Created by zxl on 2016/5/5.
 * Cosine-similarity computation: builds TF-IDF vectors for a corpus of
 * documents and, for every document, stores its top-5 most similar peers.
 */
import java.sql.{Connection, DriverManager, ResultSet}
import java.text.SimpleDateFormat
import java.util.Date

import kafka.serializer.StringDecoder
import kafka.producer._
import org.apache.log4j.{Level, Logger}
import org.apache.spark._
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.{Vector, Vectors, SparseVector => SV}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.Row
import org.apache.spark.streaming._
import org.apache.spark.streaming.kafka._

object Cossimi {

  /** One input record: a document id (category) and its raw text. */
  case class RawDataRecord(category: String, text: String)

  /**
   * Entry point. Pipeline:
   *   1. read "id,text" lines from /home/spark/art.txt
   *   2. tokenize -> hashed TF (500k buckets) -> IDF re-weighting
   *   3. all-pairs cosine similarity against a driver-side copy of the corpus
   *   4. persist the top-5 neighbours of each document via insert_sim (MySQL)
   */
  def main(args: Array[String]) {
    val sparkConf = new SparkConf().setAppName("cos")
    val sc = new SparkContext(sparkConf)
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._

    Logger.getRootLogger.setLevel(Level.ERROR)

    // Parse "id,text" lines. Malformed lines (no comma) are skipped instead of
    // crashing the job with an ArrayIndexOutOfBoundsException, as the original
    // data(1) access would. Note: split(",") keeps only the first text segment
    // if the text itself contains commas, matching the original behavior.
    val srcRDD = sc.textFile("/home/spark/art.txt").flatMap { line =>
      val fields = line.split(",")
      if (fields.length >= 2) Some(RawDataRecord(fields(0), fields(1))) else None
    }

    val trainingDF = srcRDD.toDF()
    trainingDF.take(2).foreach(println)

    // Tokenize raw text into a "words" column.
    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val wordsData = tokenizer.transform(trainingDF)
    println("output1:")

    // Term frequencies via the hashing trick (500k buckets keeps collisions low).
    val hashingTF = new HashingTF()
      .setNumFeatures(500000)
      .setInputCol("words")
      .setOutputCol("rawFeatures")
    val featurizedData = hashingTF.transform(wordsData)

    // IDF re-weighting of the raw term frequencies.
    val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
    val idfModel = idf.fit(featurizedData)
    val rescaledData = idfModel.transform(featurizedData)

    // Driver-side copy of every (id, tf-idf vector) pair for the all-pairs
    // comparison below. NOTE(review): this is O(n^2) and collects the whole
    // corpus to the driver — acceptable only for small corpora.
    val content_list = rescaledData.select($"category", $"features").collect().toList
    println("out put 2")

    // For each document, compute its cosine similarity against every other one.
    val docSims = rescaledData.select($"category", $"features").map {
      case Row(id1, idf1) =>
        import breeze.linalg._
        val sv1 = idf1.asInstanceOf[SV]
        val bsv1 = new SparseVector[Double](sv1.indices, sv1.values, sv1.size)
        // _(0) is the Row's "category" column; skip the document itself.
        content_list.filter(_(0) != id1).map {
          case Row(id2, idf2) =>
            val sv2 = idf2.asInstanceOf[SV]
            val bsv2 = new SparseVector[Double](sv2.indices, sv2.values, sv2.size)
            // Guard against zero vectors: 0/0 would otherwise produce NaN and
            // poison the sort below. dot() already returns Double — the
            // original asInstanceOf[Double] cast was redundant.
            val denom = norm(bsv1) * norm(bsv2)
            val cosSim = if (denom == 0.0) 0.0 else bsv1.dot(bsv2) / denom
            (id1, id2, cosSim)
        }
    }

    println("insert mysql ......")
    // Keep only the 5 most similar documents per source document and persist.
    docSims.foreach { cos_list =>
      cos_list.sortWith(_._3 > _._3).take(5).foreach { x =>
        val news_id = x._1
        val t_id = x._2
        val value = x._3
        insert_sim(news_id.toString, t_id.toString, value)
      }
    }

    sc.stop()
  }