// 参考文档: http://spark.apache.org/docs/latest/sql-programming-guide.html
// 在本地测试。(可自行编译源码 spark2.1 对应 hadoop2.7.2 版本的源码, 源码包里可以找到对应的代码)
// Find the full example code at "examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala" in the Spark repo.
package com.huihex.spark
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
/**
* Created by wall-e on 2017/4/1.
*/
object SparkSQLExample {

  /**
   * Schema class used for Dataset creation and DataFrame schema inference.
   * `age` is `Long` to match Spark's default numeric inference for JSON.
   */
  case class Person(name: String, age: Long)

  def main(args: Array[String]): Unit = {
    // SparkSession is the unified entry point; no separate SparkConf /
    // SparkContext needs to be created anymore.
    val spark = SparkSession
      .builder()
      .appName("Spark SQL basic example").master("local")
      .config("spark.some.config.option", "some-value")
      .getOrCreate()

    runBasicDataFrameExample(spark)
    runDatasetCreationExample(spark)
    runInferSchemaExample(spark)
    runProgrammaticSchemaExample(spark)

    spark.stop()
  }

  /**
   * Creates a simple DataFrame from a JSON file and exercises common
   * DataFrame operations, plus temporary and global temporary views.
   *
   * @param spark the active SparkSession
   */
  private def runBasicDataFrameExample(spark: SparkSession): Unit = {
    val df = spark.read.json("data/people.json")
    df.show()
    df.printSchema()
    df.select("name").show()
    import spark.implicits._
    df.select($"name", $"age" + 1).show() // Select everybody, but increment the age by 1
    df.filter($"age" > 21).show()         // Select people older than 21
    df.groupBy("age").count().show()      // Count people by age

    // Register the DataFrame as a SQL temporary view named "people1"
    df.createOrReplaceTempView("people1")
    val sqlDF = spark.sql("SELECT * FROM people1 where age > 18")
    sqlDF.show()

    // Register the DataFrame as a global temporary view
    df.createGlobalTempView("people2")
    // A global temporary view is tied to the system-preserved database `global_temp`
    spark.sql("SELECT * FROM global_temp.people2").show()
    // A global temporary view is cross-session; a plain temp view such as
    // "people1" is NOT visible from a new session and the query below fails:
    //spark.newSession().sql("SELECT * FROM global_temp.people1").show()
    spark.newSession().sql("SELECT * FROM global_temp.people2").show()
  }

  /**
   * Creates Datasets three ways: from a case class Seq, from a Seq of
   * primitives, and by reading JSON as a typed Dataset.
   *
   * @param spark the active SparkSession
   */
  private def runDatasetCreationExample(spark: SparkSession): Unit = {
    import spark.implicits._
    val caseClassDS = Seq(Person("Andy", 32)).toDS()
    caseClassDS.show()
    val primitiveDS = Seq(1, 2, 3).toDS()
    primitiveDS.map(x => x + 1).collect().foreach(x => print(x)) // Returns: Array(2, 3, 4)
    val path = "data/people.json"
    val peopleDS = spark.read.json(path).as[Person]
    peopleDS.show()
  }

  /**
   * Builds a DataFrame from a text file by inferring the schema from the
   * Person case class via reflection.
   *
   * @param spark the active SparkSession
   */
  private def runInferSchemaExample(spark: SparkSession): Unit = {
    import spark.implicits._
    val peopleDF = spark.sparkContext.textFile("data/people.txt")
      .map(_.split(","))
      // Person.age is Long, so parse directly with toLong
      .map(attribute => Person(attribute(0), attribute(1).trim.toLong))
      .toDF()
    peopleDF.createOrReplaceTempView("people")
    val teenagerDF = spark.sql("select name,age from people where age between 13 and 19")
    // Access columns by ordinal: index 0 is name, index 1 is age
    teenagerDF.map(teenager => "Name:" + teenager(0)).show()
    teenagerDF.map(teenager => "Age:" + teenager(1)).show()
    // Access columns by field name
    teenagerDF.map(teenager => "Name: " + teenager.getAs[String]("name")).show()
    // No pre-defined encoder exists for Map[String, Any]; fall back to Kryo
    implicit val mapEncoder = org.apache.spark.sql.Encoders.kryo[Map[String, Any]]
    teenagerDF.map(teenager => teenager.getValuesMap[Any](List("name", "age"))).collect()
      .foreach(x => println(x))
  }

  /**
   * Builds a DataFrame from a text file using a programmatically
   * constructed schema (StructType of StructFields) and an RDD of Rows.
   *
   * @param spark the active SparkSession
   */
  private def runProgrammaticSchemaExample(spark: SparkSession): Unit = {
    import spark.implicits._
    val peopleRDD = spark.sparkContext.textFile("data/people.txt")
    // The schema is encoded in a string
    val schemaString = "name age"
    // Generate the schema based on the string of schema
    val fields = schemaString.split(" ")
      .map(fieldName => StructField(fieldName, StringType, nullable = true))
    val schema = StructType(fields)
    // Convert records of the RDD (people) to Rows
    val rowRDD = peopleRDD.map(_.split(","))
      .map(attributes => Row(attributes(0), attributes(1).trim))
    // Create the DataFrame from the row RDD and the explicit schema
    val peopleDF = spark.createDataFrame(rowRDD, schema)
    peopleDF.createOrReplaceTempView("people")
    val results = spark.sql("select name from people")
    results.map(attribute => "Name: " + attribute(0)).show()
  }
}