最近在学习spark,同时作业有用到spark-udf,在这里简单总结一下:
下面我分别分析了udf的两种用法(当然都是针对spark2.x的)以及决策树和随机森林的使用以及交叉验证的使用。
spark sql用法
- 注册自定义函数:
spark.udf.register("stringcount", protein _)
- 使用:
var date: DataFrame =spark.sql("select label,string,stringcount(string) as features from dna")
dataframe的用法
- 注册自定义函数
val stringcount: UserDefinedFunction = udf(dna _)
- 使用:
pdf=pdf.withColumn("features", stringcount(col("string")))
再接着是决策树和随机森林的用法了,这里不细说了,只简单地说明数据类型转换
date.select(col("label").cast(DoubleType))
date = date.withColumn("label", col("label").cast(DoubleType))
date.printSchema()
源码如下:
package cn.fly.ml
import java.util.Date
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier, RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.feature._
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
* Created by hqf on 2019/4/12.
*/
/**
 * Feature extraction for DNA / protein sequences followed by a Spark ML
 * RandomForest classification pipeline.
 *
 * The input CSV has two columns: a raw sequence string and a class label.
 * A character-frequency heuristic (`test`) decides whether the sequences are
 * DNA or protein, and the matching k-mer frequency UDF (`dna` / `protein`)
 * builds the feature vector used by the pipeline.
 */
object DNA_Protein {
  // Holds the 400 amino-acid dipeptides ("AA", "AC", ...); filled by sort()
  // before every protein() extraction (redundant after the first fill).
  val sBuffer = new Array[String](400) // protein

  // The 20 standard amino-acid one-letter codes.
  // Fix: "v" was lowercase, inconsistent with the other 19 entries (latently
  // masked at runtime because protein() upper-cases each pattern before use).
  val strings = Array("A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y")

  def main(args: Array[String]): Unit = {
    val start1 = new Date()
    val spark = SparkSession.builder().appName("CsvDataSource")
      .master("local[8]")
      .getOrCreate()
    // Read the CSV file. Local path: C:\Users\hqf\Desktop\学习记录\新建文件夹
    // hdfs: hdfs://192.168.178.129:9000/input/protein.arff
    val csv: DataFrame = spark.read.csv("C:\\Users\\hqf\\Desktop\\高性能实践\\实验2\\数据\\enhancer.arff")
    //csv.printSchema()
    // Name the two raw columns: the sequence string and its label.
    var pdf: DataFrame = csv.toDF("string", "label")
    //pdf.show()
    // Register a temp view so the Spark-SQL UDF variant can be used.
    //pdf.registerTempTable("dna");
    pdf.createOrReplaceTempView("dna")
    var dd: DataFrame = spark.sql("select string ,label from dna ")
    //dd.show() // data before feature extraction
    val test1: DataFrame = spark.sql("select string from dna")
    ///-------------- UDF usage -----------------------------------
    // Decide whether the data set is protein or DNA.
    spark.udf.register("test", test _)
    println(test1.collect()(0).toString())
    // NOTE(review): Row.toString wraps the value in brackets ("[ACGT...]"),
    // so the heuristic sees two extra non-nucleotide characters — confirm
    // this is intended (it slightly lowers the computed ratio).
    var flag = test(test1.collect()(0).toString())
    println(flag)
    import org.apache.spark.sql.functions._
    val stringcount: UserDefinedFunction = udf(dna _) // DataFrame-style UDF
    if (flag >= 0.9) { // dna
      //spark.udf.register("stringcount", dna _)
      // val stringcount: UserDefinedFunction = udf(dna _) // DataFrame-style usage
    } else { // protein
      spark.udf.register("stringcount", protein _) // Spark-SQL-style UDF
      //val stringcount = udf(protein _) // DataFrame-style usage
    }
    /* var date: DataFrame =spark.sql("select label,string,stringcount(string) as features from dna") // Spark-SQL usage
    date.show(false) // data after feature extraction */
    // Casting a single column's type:
    /*date.select(col("label").cast(DoubleType))
    date = date.withColumn("label", col("label").cast(DoubleType))
    date.printSchema() // print the schema*/
    // DataFrame-style usage.
    pdf = pdf.withColumn("features", stringcount(col("string")))
    val date = pdf
    println("------------------------------------------------")
    date.show()
    //----------- data must be reformatted before classification ---------------------------------------=
    // In Spark the hard part is usually not the computation itself but the
    // data formatting and preprocessing.
    //--------------------- RandomForest ----------------------------------------
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(date)
    // Automatically identify categorical features, and index them.
    // Set maxCategories so features with > 4 distinct values are treated as continuous.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(date)
    // Split the data into training and test sets (30% held out for testing).
    val Array(trainingData, testData) = date.randomSplit(Array(0.7, 0.3))
    // Train a RandomForest model.
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setNumTrees(10)
    // Convert indexed labels back to original labels.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)
    // Chain indexers and forest in a Pipeline.
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
    //==================== cross-validation ====================
    /* val paramGrid = new ParamGridBuilder()
      .build()
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(10)
    val cvModel = cv.fit(trainingData)
    val predictions = cvModel.transform(testData)
    predictions.show()
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println(s"Test Error = ${(1.0 - accuracy)}")*/
    //===============================================
    // Train model. This also runs the indexers.
    val model = pipeline.fit(trainingData)
    // Make predictions.
    val predictions: DataFrame = model.transform(testData)
    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5)
    // Select (prediction, true label) and compute test error.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println(s"Test Error = ${(1.0 - accuracy)}")
    val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
    println(s"Learned classification forest model:\n ${rfModel.toDebugString}")
    //--------------------------------------------------------------------------------
    val end1 = new Date()
    println((end1.getTime - start1.getTime) / 1000.0, "s")
    spark.stop()
  }

  /**
   * Dinucleotide-frequency features for a DNA sequence.
   *
   * Counts each of the 16 dinucleotides ("AA", "AT", ...) in the upper-cased
   * input and divides by (length - 1), the number of dinucleotide positions.
   *
   * NOTE(review): indexOf/substring skips past each whole match, so
   * overlapping occurrences (e.g. "AA" appears twice in "AAA") are counted
   * only once, while the denominator assumes overlapping positions — the
   * rates therefore need not sum to 1. Confirm this is the intended feature
   * definition before changing it (it would alter the trained model).
   *
   * @param string raw DNA sequence (case-insensitive)
   * @return dense vector of 16 dinucleotide rates
   */
  def dna(string: String): DenseVector = {
    var string1 = string.toUpperCase()
    // The 16 possible dinucleotides.
    var string22 = Array("aa", "at", "ag", "ac", "ta", "tt", "tg", "tc", "ga", "gg", "gt", "gc", "ca", "ct", "cg", "cc")
    var sum = string1.length() - 1.0 // number of (overlapping) dinucleotide positions
    var index = 0
    var rate = new Array[Double](16)
    var count = 0.0
    for (i <- 0 to (string22.length - 1)) {
      var ss = string22(i).toUpperCase()
      string1 = string.toUpperCase() // reset to the full sequence for each pattern
      count = 0.0
      index = string1.indexOf(ss)
      while (index != -1) {
        string1 = string1.substring(index + ss.length())
        count = count + 1
        index = string1.indexOf(ss)
      }
      rate(i) = count / sum
    }
    val vector = new DenseVector(rate)
    return vector
  }

  /**
   * Fills sBuffer with all 400 ordered amino-acid pairs (dipeptides) built
   * from `strings`. Mutates shared state; protein() calls this before every
   * extraction, which is redundant after the first fill but harmless.
   */
  def sort(): Unit = {
    var ff = 0
    var i = 0
    while (i < strings.length) {
      var j = 0
      while (j < strings.length) {
        sBuffer(ff) = strings(i) + strings(j)
        ff += 1
        j += 1
      }
      i += 1
    }
  }

  /**
   * Dipeptide-frequency features for a protein sequence: the rate of each of
   * the 400 amino-acid pairs, analogous to dna().
   *
   * NOTE(review): same non-overlapping-count caveat as dna().
   *
   * @param string raw protein sequence (case-insensitive)
   * @return dense vector of 400 dipeptide rates
   */
  def protein(string: String): DenseVector = {
    sort() // (re)build the dipeptide table
    var string1 = string.toUpperCase()
    var sum = string1.length() - 1.0 // number of (overlapping) dipeptide positions
    var index = 0
    var rate = new Array[Double](400)
    var count = 0.0
    for (i <- 0 to (sBuffer.length - 1)) {
      var ss = sBuffer(i).toUpperCase()
      string1 = string.toUpperCase() // reset to the full sequence for each pattern
      count = 0.0
      index = string1.indexOf(ss)
      while (index != -1) {
        string1 = string1.substring(index + ss.length())
        count = count + 1
        index = string1.indexOf(ss)
      }
      rate(i) = count / sum
    }
    val vector = new DenseVector(rate)
    return vector
  }

  //-----------------------------------------------
  /**
   * Heuristic DNA detector: sums the per-character frequencies of
   * A, T, G, C, U and X over the whole string. A result close to 1.0 means
   * the sequence consists (almost) entirely of nucleotide letters, i.e. DNA;
   * main() uses a 0.9 threshold.
   *
   * @param string sequence to probe (case-insensitive)
   * @return summed character frequency in [0, 1]
   */
  def test(string: String): Double = {
    var string1 = string.toUpperCase()
    var sum = string1.length()
    var index = 0
    var rate = new Array[Double](6)
    var rate1 = 0.0d
    var count = 0.0
    var string2 = Array("A", "T", "G", "C", "U", "X")
    for (i <- 0 to (string2.length - 1)) {
      var ss = string2(i).toUpperCase()
      string1 = string.toUpperCase() // reset to the full sequence for each letter
      count = 0.0
      index = string1.indexOf(ss)
      while (index != -1) {
        string1 = string1.substring(index + ss.length())
        count = count + 1
        index = string1.indexOf(ss)
      }
      rate(i) = count / sum
      // println(rate(i))
      rate1 = rate1 + rate(i)
    }
    return rate1
  }
}