It has been a while since I last used Spark, so first let's get reacquainted with the SQL API.
package com.wtx.job014
import org.apache.spark.sql.SparkSession
object demo2 {
def main(args: Array[String]): Unit = {
val train = "file:\\C:\\Users\\86183\\Desktop\\scala_machine_leraning_projects\\ScalaMachineLearningData\\train.csv"
val test = "file:\\C:\\Users\\86183\\Desktop\\scala_machine_leraning_projects\\ScalaMachineLearningData\\test.csv"
val spark: SparkSession = SparkSessionCreate.createSession()
val trainInput = spark.read.option("header", "true").option("inferSchema", "true").format("com.databricks.spark.csv")
.load(train).cache()
trainInput.printSchema() //printSchema prints directly; wrapping it in print() would only print "()"
println(trainInput.count)
// trainInput.show()
// trainInput.select("id", "cat1","cat2","cat3","cont1","cont2","cont3","loss").show()
val newDF = trainInput.withColumnRenamed("loss", "label")
//if the target column is not renamed from "loss" to "label", the ML library will complain later, so the target column must be renamed
newDF.createOrReplaceTempView("insurance") //register the DataFrame as a temp view named "insurance" (insurance claims data)
spark.sql("SELECT avg(insurance.label) as AVG_LOSS FROM insurance").show()
spark.sql("SELECT min(insurance.label) as MIN_LOSS FROM insurance").show()
spark.sql("SELECT max(insurance.label) as MAX_LOSS FROM insurance").show()
}
}
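Both demos call a small helper, SparkSessionCreate.createSession(), that is not shown in this post. A minimal sketch of what such a helper could look like (the local[*] master and the app name are my assumptions, not the original configuration):

package com.wtx.job014
import org.apache.spark.sql.SparkSession
//hypothetical helper: builds (or reuses) a local SparkSession for the demos
object SparkSessionCreate {
  def createSession(): SparkSession = {
    SparkSession.builder()
      .master("local[*]")                  //assumption: run locally on all cores
      .appName("ScalaMachineLearningData") //assumption: the real app name is not given in the post
      .getOrCreate()
  }
}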
Data preprocessing
package com.wtx.job014
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.{ StringIndexer, StringIndexerModel }
import org.apache.spark.sql.SparkSession
object demo3 {
val train = "file:\\C:\\Users\\86183\\Desktop\\scala_machine_leraning_projects\\ScalaMachineLearningData\\train.csv"
val test = "file:\\C:\\Users\\86183\\Desktop\\scala_machine_leraning_projects\\ScalaMachineLearningData\\test.csv"
var trainSample = 1.0
var testSample = 1.0
val spark: SparkSession = SparkSessionCreate.createSession()
val trainInput = spark.read.option("header", "true").option("inferSchema", "true").format("com.databricks.spark.csv")
.load(train).cache()
val testInput = spark.read.option("header", "true").option("inferSchema", "true").format("com.databricks.spark.csv")
.load(test).cache() //note: load the test CSV here, not the training CSV
var data = trainInput.withColumnRenamed("loss", "label").sample(false, trainSample)
var DF = data.na.drop()
//DataFrame == is reference equality, so compare row counts to see whether na.drop removed any rows
if (data.count() == DF.count()) {
print("no rows contained null values")
} else {
print("null rows were dropped, rows before: " + data.count() + " ____ rows after: " + DF.count())
data = DF
}
val seed = 12345L
val splits = data.randomSplit(Array(0.75, 0.25), seed)
val (trainingData, validationData) = (splits(0), splits(1))
//trainingData = training set, validationData = validation set
trainingData.cache()
validationData.cache()
val testData = testInput.sample(false, testSample).cache()
//sample the test set
//the training, validation, and test sets are ready; now start preprocessing the data
def isCateg(c: String): Boolean = c.startsWith("cat")
def categNewCol(c: String): String = if (isCateg(c)) s"idx_${c}" else c
def removeTooManyCategs(c: String): Boolean = !(c matches "cat(109$|110$|112$|113$|116$)")
def onlyFeatureCols(c: String): Boolean = !(c matches "id|label")
//after dropping the unwanted columns, build the list of feature columns in the required format
val featureCols = trainingData.columns.filter(removeTooManyCategs).filter(onlyFeatureCols).map(categNewCol)
val stringIndexerStages = trainingData.columns.filter(isCateg).map(c => new StringIndexer()
.setInputCol(c)
.setOutputCol(categNewCol(c))
.fit(trainInput.select(c).union(testInput.select(c)))
)
val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
}
Data preprocessing: clean both CSV files and turn them into structured, model-ready data.
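The StringIndexer stages and the VectorAssembler above are the building blocks for an ML pipeline. As a rough sketch of the next step (this continuation is my assumption and would sit inside demo3 after the assembler; LinearRegression is used purely as an illustrative regressor, not necessarily the model used later):

import org.apache.spark.ml.{ Pipeline, PipelineStage }
import org.apache.spark.ml.regression.LinearRegression

//assumption: chain the indexers, the assembler, and a simple regressor into one pipeline
val lr = new LinearRegression()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setMaxIter(100)

val stages: Array[PipelineStage] = stringIndexerStages ++ Array[PipelineStage](assembler, lr)
val pipeline = new Pipeline().setStages(stages)

val model = pipeline.fit(trainingData)            //fit indexers + assembler + regressor on the training split
val predictions = model.transform(validationData) //adds a "prediction" column
predictions.select("label", "prediction").show(5)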
A note on sbt: when adding new dependencies, mind sbt's strict formatting; also, after running reload and compile in the sbt shell window, you still have to run the eclipse task before the MLlib library shows up in the IDE (see the plugins.sbt sketch after the build file below).
The sbt file:
ThisBuild / scalaVersion := "2.11.11"
ThisBuild / organization := "com.wtx.job014"
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-sql" % "2.3.2" % "provided",
"org.apache.spark" %% "spark-core" % "2.3.2" % "provided",
"org.apache.spark" %% "spark-streaming" % "2.3.2" % "provided",
"org.apache.spark" %% "spark-mllib" % "2.3.2" % "provided"
)
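The eclipse task mentioned above is not built into sbt; it comes from the sbteclipse plugin. A minimal project/plugins.sbt that provides it might look like this (the plugin version is an assumption):

//project/plugins.sbt - assumption: sbteclipse supplies the eclipse task used to regenerate the IDE project
addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.2.4")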