// Only the logistic-regression portion of a larger program is shown here, so some of the
// imported packages below are unused; parts of this code are adapted from online resources.
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.sql._
import java.io._
import java.util.Properties
import scala.io.Source
import org.apache.spark.sql.types._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.LogisticRegression
object scalaSpark {
/**
 * Fits a logistic-regression model on an inline sample of the classic "Affairs"
 * dataset and prints the training-set predictions.
 *
 * Target column: `affairs` (0/1). Features: age, religiousness, education,
 * occupation, rating — assembled into a single vector column by VectorAssembler.
 *
 * @param args unused command-line arguments
 */
def main(args: Array[String]): Unit = {
  // Configure and start Spark in-process ("local" master); chained builder style.
  val conf = new SparkConf().setAppName("Simple Application").setMaster("local")
  val sc = new SparkContext(conf)
  val sqlContext = new SQLContext(sc)
  try {
    // Inline sample rows. Column order:
    // affairs, gender, age, yearsmarried, children, religiousness, education, occupation, rating
    val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
      (0, "male", 37, 10, "no", 3, 18, 7, 4),
      (0, "female", 27, 4, "no", 4, 14, 6, 4),
      (0, "female", 32, 15, "yes", 1, 12, 1, 4),
      (0, "male", 57, 15, "yes", 5, 18, 6, 5),
      (0, "male", 22, 0.75, "no", 2, 17, 6, 3),
      (0, "female", 32, 1.5, "no", 2, 17, 5, 5),
      (0, "female", 22, 0.75, "no", 2, 12, 1, 3),
      (0, "male", 57, 15, "yes", 2, 14, 4, 4),
      (0, "female", 32, 15, "yes", 4, 16, 1, 2),
      (0, "male", 22, 1.5, "no", 4, 14, 4, 5),
      (0, "male", 37, 15, "yes", 2, 20, 7, 2),
      (0, "male", 27, 4, "yes", 4, 18, 6, 4),
      (0, "male", 47, 15, "yes", 5, 17, 6, 4),
      (0, "female", 22, 1.5, "no", 2, 17, 5, 4),
      (0, "female", 27, 4, "no", 4, 14, 5, 4),
      (0, "female", 37, 15, "yes", 1, 17, 5, 5),
      (0, "female", 37, 15, "yes", 2, 18, 4, 3),
      (0, "female", 22, 0.75, "no", 3, 16, 5, 4),
      (0, "female", 22, 1.5, "no", 2, 16, 5, 5),
      (0, "female", 27, 10, "yes", 2, 14, 1, 5),
      (1, "female", 32, 15, "yes", 3, 14, 3, 2),
      (1, "female", 27, 7, "yes", 4, 16, 1, 2),
      (1, "male", 42, 15, "yes", 3, 18, 6, 2),
      (1, "female", 42, 15, "yes", 2, 14, 3, 2),
      (1, "male", 27, 7, "yes", 2, 17, 5, 4),
      (1, "male", 32, 10, "yes", 4, 14, 4, 3),
      (1, "male", 47, 15, "yes", 3, 16, 4, 2),
      (0, "male", 37, 4, "yes", 2, 20, 6, 4))
    // FIX: the 4th column was previously mis-named "label" although it holds the
    // years-married values. Besides being misleading, "label" is Spark ML's default
    // label column name and could silently collide in later pipeline stages.
    val colArray1: Array[String] = Array("affairs", "gender", "age", "yearsmarried",
      "children", "religiousness", "education", "occupation", "rating")
    val data1 = sqlContext.createDataFrame(dataList).toDF(colArray1: _*)
    // Keep only the target and the numeric predictors used below.
    val data = data1.select("affairs", "age", "religiousness", "education", "occupation", "rating").toDF()
    data.show(5)
    // Assemble the predictor columns into a single vector column named "features".
    val colArray2 = Array("age", "religiousness", "education", "occupation", "rating")
    val vecDF: DataFrame = new VectorAssembler().setInputCols(colArray2).setOutputCol("features").transform(data)
    vecDF.show
    // 90/10 train/test split; fixed seed makes the split reproducible.
    val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.9, 0.1), seed = 12345)
    trainingDF.show()
    // Logistic regression capped at 10 iterations; "affairs" is the binary target.
    val lr = new LogisticRegression().setMaxIter(10)
    val model = lr.setLabelCol("affairs").setFeaturesCol("features").fit(trainingDF)
    // Show predictions on the training set. NOTE(review): testDF is never scored —
    // evaluate on it for an honest accuracy estimate.
    model.transform(trainingDF).show(50)
  } finally {
    // FIX: the SparkContext was previously never stopped (resource leak).
    sc.stop()
  }
}