接上篇
想要在输出男女生人数的基础上,同时输出每组对应的姓名列表,需要自定义聚合函数(UDAF)。
AggrNameUDF
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
/**
 * Custom aggregate function (UDAF) that concatenates all names in a group
 * into a single comma-separated string, e.g. "Tom,Jerry,Spike".
 */
class AggrNameUDF extends UserDefinedAggregateFunction {
  // Input schema: one string column holding a single name per row.
  override def inputSchema: StructType = StructType(List(StructField("name", StringType)))

  // Buffer schema: one string column accumulating the comma-separated names.
  override def bufferSchema: StructType = StructType(List(StructField("name", StringType)))

  // Result type: the final comma-separated name list.
  override def dataType: DataType = StringType

  // The same input rows always produce the same output.
  override def deterministic: Boolean = true

  // Initialize each partition's buffer to the empty string.
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, "")

  /**
   * Appends `value` to the accumulated string `acc` with a comma separator.
   * Null or empty values are skipped so a SQL NULL name neither throws nor
   * leaks the literal string "null" into the result.
   */
  private def appendName(acc: String, value: String): String = {
    if (value == null || value.isEmpty) acc
    else if (acc.isEmpty) value
    else acc + "," + value
  }

  // Called for every input row within a partition (executor side).
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Row.getString returns null for SQL NULL; appendName guards against it.
    buffer.update(0, appendName(buffer.getString(0), input.getString(0)))
  }

  // Merges the partial results of two partitions into buffer1.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, appendName(buffer1.getString(0), buffer2.getString(0)))
  }

  // Final result: the accumulated comma-separated name list.
  override def evaluate(buffer: Row): Any = buffer.getString(0)
}
HiveDriver
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
/**
 * Driver that queries the Hive table `driver.dr`, reporting the head count
 * per gender together with the comma-separated names via the `aggr` UDAF.
 */
object HiveDriver {
  def main(args: Array[String]): Unit = {
    // Silence Spark's verbose INFO logging.
    Logger.getLogger("org").setLevel(Level.WARN)
    // appName is mandatory when building a session programmatically;
    // without it getOrCreate() fails outside spark-shell/spark-submit.
    val session = SparkSession.builder()
      .appName("AggrNameUDF-demo")
      .master("local[2]")
      .enableHiveSupport()
      .getOrCreate()
    try {
      // Register the custom aggregate so SQL can call it as aggr(...).
      session.udf.register("aggr", new AggrNameUDF())
      session.sql("use driver")
      // Per-gender head count plus the name list for each group.
      val df = session.sql("select gender,count(*) count,aggr(name) names from dr group by gender")
      df.show()
    } finally {
      // Always release the session, even if a query above throws.
      session.stop()
    }
  }
}