package com.hx.data.collection.wx

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

/**
 * Trains a random-forest binary classifier on a space-separated text file
 * and prints (trueLabel, prediction) pairs for a second input file.
 *
 * Expected line format for BOTH files: "label f1 f2 ... fn" — doubles
 * separated by single spaces, the first value being the label.
 *
 * Created by Administrator on 2017/3/1.
 */
object RandomForestTest {

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("RandomFores")
    val sc = new SparkContext(sparkConf)

    // Training data, one LabeledPoint per line.
    // (Alternative kept from the original: MLUtils.loadLibSVMFile(sc, "/data/sample.txt")
    //  if the file were in libsvm format instead.)
    val data: RDD[LabeledPoint] =
      sc.textFile("/data/sample.txt").map(create_label_point)

    val numTrees = 3
    // "auto" lets MLlib choose how many features to consider per split.
    val featureSubsetStrategy = "auto"

    // defaultStrategy("classification") already configures a 2-class problem,
    // so no separate numClasses value is needed here.
    val model: RandomForestModel = RandomForest.trainClassifier(
      data,
      Strategy.defaultStrategy("classification"),
      numTrees,
      featureSubsetStrategy,
      new java.util.Random().nextInt())

    // Scoring data, same line format as the training file.
    // (Alternative: MLUtils.loadLibSVMFile(sc, "/data/input.txt") for libsvm input.)
    val input: RDD[LabeledPoint] =
      sc.textFile("/data/input.txt").map(create_label_point)

    // Pair each true label with the model's prediction for that point.
    val predictResult = input.map { point =>
      (point.label, model.predict(point.features))
    }

    predictResult.collect().foreach(x => println("res" + x))

    sc.stop()
  }

  /**
   * Parses one space-separated line of doubles into a LabeledPoint:
   * the first value is the label, the remaining values are the features.
   *
   * NOTE: splits on a single space character; tabs or runs of multiple
   * spaces would yield empty tokens and fail to parse — input must be clean.
   *
   * @param line one line of text, e.g. "1 0.5 2.0 3.5"
   * @return LabeledPoint(label = first value, features = the rest)
   * @throws NumberFormatException if any token is not a valid double
   */
  def create_label_point(line: String): LabeledPoint = {
    val values = line.trim().split(" ").map(_.toDouble)
    // head/tail split replaces the original mutable-buffer remove(0) detour;
    // behavior is identical, including failure on an empty/blank line.
    LabeledPoint(values.head, Vectors.dense(values.tail))
  }
}
Scala: converting a text line (string) into a labeled vector (LabeledPoint)
Latest recommended article published 2024-08-22 11:41:26