A Brief Introduction to Spark ML UDF Usage

I have been learning Spark recently, and a homework assignment involved Spark UDFs, so here is a short summary.

Below I walk through the two ways to use a UDF (both targeting Spark 2.x), then the decision tree and random forest classifiers, and finally cross-validation.

Spark SQL usage

  • Register the UDF:
spark.udf.register("stringcount", protein _)
  • Use it in a SQL query (a self-contained sketch follows below):
var date: DataFrame = spark.sql("select label, string, stringcount(string) as features from dna")
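Put together, a minimal self-contained sketch of the SQL style; the toy strlen function, the column names, and the sample rows here are my assumptions, not from the project code:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("UdfSqlDemo").master("local[*]").getOrCreate()
import spark.implicits._

// Toy stand-in for protein(): just counts characters
def strlen(s: String): Int = s.length

val df = Seq(("1", "ATGC"), ("0", "GGCA")).toDF("label", "string")
df.createOrReplaceTempView("dna")

// Register the Scala function under a SQL-visible name
spark.udf.register("stringcount", strlen _)

// Call it like any built-in SQL function
spark.sql("select label, string, stringcount(string) as features from dna").show()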

DataFrame usage

  • Register the UDF:
val stringcount: UserDefinedFunction = udf(dna _)
  • Use it with withColumn (see the sketch below):
pdf = pdf.withColumn("features", stringcount(col("string")))
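The equivalent DataFrame style, reusing the assumed df and strlen from the sketch above; udf comes from org.apache.spark.sql.functions:

import org.apache.spark.sql.functions.{col, udf}

// Wrap the Scala function as a column-level UDF
val stringcount = udf(strlen _)

// Apply it with withColumn; no temp view or SQL string needed
val withFeatures = df.withColumn("features", stringcount(col("string")))
withFeatures.show()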

Next come the decision tree and random forest. I won't go into detail here, just a brief note on data type conversion:

date.select(col("label").cast(DoubleType)) // returns a new DataFrame that is discarded; date itself is unchanged
date = date.withColumn("label", col("label").cast(DoubleType)) // the reassignment is what actually converts the column
date.printSchema()
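A quick sketch of the schema change, continuing the toy conventions from the sketches above (the sample row is an assumption):

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

var d = Seq(("1", "ATGC")).toDF("label", "string")
d.printSchema() // label: string
d = d.withColumn("label", col("label").cast(DoubleType))
d.printSchema() // label: double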

The full source code:

package cn.fly.ml
import java.util.Date

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier, RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
  * Created by hqf on 2019/4/12.
  */
object DNA_Protein {

  val sBuffer = new Array[String](400) // holds the 400 two-letter dipeptide keys used for protein features

  val strings = Array("A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y") // the 20 amino-acid letters

  def main(args: Array[String]): Unit = {

    val start1 = new Date()

    val spark = SparkSession.builder().appName("CsvDataSource")
      .master("local[8]")
      .getOrCreate()

    // Read the CSV file; an HDFS path also works, e.g. hdfs://192.168.178.129:9000/input/protein.arff
    val csv: DataFrame = spark.read.csv("C:\\Users\\hqf\\Desktop\\高性能实践\\实验2\\数据\\enhancer.arff")

    //csv.printSchema()

    // Name the columns, producing a DataFrame
    var pdf: DataFrame = csv.toDF("string", "label")

    //pdf.show()
    // Register a temp view (registerTempTable is the deprecated pre-2.0 API)
    //pdf.registerTempTable("dna")
    pdf.createOrReplaceTempView("dna")

    var dd: DataFrame = spark.sql("select string, label from dna")
    //dd.show() // the data before feature extraction
    val test1: DataFrame = spark.sql("select string from dna")

    ///-------------- UDF usage below -----------------------------------
    // Decide whether the input is protein or DNA
    spark.udf.register("test", test _)
    val firstString = test1.collect()(0).getString(0) // getString avoids the [brackets] that Row.toString adds
    println(firstString)
    val flag = test(firstString)
    println(flag)

    import org.apache.spark.sql.functions._
    // Bind both the DataFrame-style UDF and the SQL-registered UDF
    // to the right extractor, depending on the flag
    val stringcount: UserDefinedFunction =
      if (flag >= 0.9) { // DNA
        spark.udf.register("stringcount", dna _) // Spark SQL style
        udf(dna _) // DataFrame style
      } else { // protein
        spark.udf.register("stringcount", protein _) // Spark SQL style
        udf(protein _) // DataFrame style
      }

    /* Spark SQL style:
    var date: DataFrame = spark.sql("select label,string,stringcount(string) as features from dna")
    date.show(false) // the data after feature extraction
    */

    // Converting a single column's data type:
    /*
    date.select(col("label").cast(DoubleType)) // no-op on its own: the result is discarded
    date = date.withColumn("label", col("label").cast(DoubleType))
    date.printSchema() // print the column types
    */

    // DataFrame style
    pdf = pdf.withColumn("features", stringcount(col("string")))
    val date = pdf
    println("------------------------------------------------")
    date.show()

    //----------- Convert the data format before classification ---------------------------------------
    // In Spark, the hard part of most jobs is not the computation itself but the data formatting and preprocessing


    //--------------------- Random forest ----------------------------------------

    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(date)
    // Automatically identify categorical features, and index them.
    // Set maxCategories so features with > 4 distinct values are treated as continuous.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(date)

    // Split the data into training and test sets (30% held out for testing).
    val Array(trainingData, testData) = date.randomSplit(Array(0.7, 0.3))

    // Train a RandomForest model.
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setNumTrees(10)

    // Convert indexed labels back to original labels.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)

    // Chain indexers and forest in a Pipeline.
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))


    //==================== Cross-validation (commented out) ====================
   /* val paramGrid = new ParamGridBuilder()
      .build()
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(10)
    val cvModel = cv.fit(trainingData)

    val predictions=cvModel.transform(testData)
    predictions.show()
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    val accuracy = evaluator.evaluate(predictions)
    println(s"Test Error = ${(1.0 - accuracy)}")*/
    //===============================================
    // Train model. This also runs the indexers.
    val model = pipeline.fit(trainingData)

    // Make predictions.
    val predictions: DataFrame = model.transform(testData)

    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5)

    // Select (prediction, true label) and compute test error.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println(s"Test Error = ${(1.0 - accuracy)}")

    val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
    println(s"Learned classification forest model:\n ${rfModel.toDebugString}")


    //--------------------------------------------------------------------------------
    val end1 = new Date()
    println(s"${(end1.getTime - start1.getTime) / 1000.0} s") // elapsed seconds
    spark.stop()


  }

  // Given a DNA string, return the 16 dinucleotide frequencies as a vector
  def dna(string: String): DenseVector = {
    val dinucleotides = Array("aa","at","ag","ac","ta","tt","tg","tc","ga","gg","gt","gc","ca","ct","cg","cc")
    val sum = string.length() - 1.0 // number of dinucleotide positions
    val rate = new Array[Double](16)

    for (i <- dinucleotides.indices) {
      val ss = dinucleotides(i).toUpperCase()
      var string1 = string.toUpperCase()
      var count = 0.0
      var index = string1.indexOf(ss)
      // count non-overlapping occurrences of the dinucleotide
      while (index != -1) {
        string1 = string1.substring(index + ss.length())
        count += 1
        index = string1.indexOf(ss)
      }
      rate(i) = count / sum
    }
    new DenseVector(rate)
  }

  // Fill sBuffer with all 400 two-letter amino-acid combinations
  def sort(): Unit = {
    var ff = 0
    for (i <- strings.indices; j <- strings.indices) {
      sBuffer(ff) = strings(i) + strings(j)
      ff += 1
    }
  }

  // Given a protein string, return the 400 dipeptide frequencies as a vector
  def protein(string: String): DenseVector = {
    sort() // build the 400 dipeptide keys
    val sum = string.length() - 1.0 // number of dipeptide positions
    val rate = new Array[Double](400)

    for (i <- sBuffer.indices) {
      val ss = sBuffer(i).toUpperCase()
      var string1 = string.toUpperCase()
      var count = 0.0
      var index = string1.indexOf(ss)
      // count non-overlapping occurrences of the dipeptide
      while (index != -1) {
        string1 = string1.substring(index + ss.length())
        count += 1
        index = string1.indexOf(ss)
      }
      rate(i) = count / sum
    }
    new DenseVector(rate)
  }


  //-----------------------------------------------
  // Fraction of characters that are nucleotide letters (A/T/G/C/U/X);
  // a value >= 0.9 means the string is treated as DNA
  def test(string: String): Double = {
    val sum = string.length()
    val letters = Array("A", "T", "G", "C", "U", "X")
    var rate1 = 0.0

    for (ss <- letters) {
      var string1 = string.toUpperCase()
      var count = 0.0
      var index = string1.indexOf(ss)
      while (index != -1) {
        string1 = string1.substring(index + ss.length())
        count += 1
        index = string1.indexOf(ss)
      }
      rate1 += count / sum
    }
    rate1
  }


}
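For a quick local sanity check of the extractors, you can call them directly outside Spark (the input string is hypothetical):

// Hypothetical quick check, callable from any main or test:
println(DNA_Protein.test("ATGCATGC")) // 1.0, so the input is treated as DNA
println(DNA_Protein.dna("ATGCATGC")) // 16-dimensional vector of dinucleotide frequencies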

 
