import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
/**
 * UDAF that concatenates distinct string values within a group into a single
 * comma-separated string (e.g. inputs "a", "b", "a" aggregate to "a,b").
 *
 * Fix over previous version: dedup used `String.contains`, a *substring*
 * test, so e.g. a new value "Beijing" was wrongly dropped when the buffer
 * already held "Beijing-West". Dedup now compares whole comma-separated
 * tokens. Null/empty inputs are skipped instead of throwing NPE.
 */
class GroupConcatDistinct extends UserDefinedAggregateFunction{
  // Input: one nullable String column.
  override def inputSchema: StructType = StructType(List(StructField("cityInfo",StringType,true)))
  // Intermediate buffer: the comma-separated accumulation so far.
  override def bufferSchema: StructType = StructType(List(StructField("buffCityInfo",StringType,true)))
  // Output: the final comma-separated String.
  override def dataType: DataType = StringType
  // Same inputs always produce the same output.
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = ""
  }

  /**
   * Appends `item` to the comma-separated accumulator `acc` unless it is
   * null, empty, or already present as an exact token.
   */
  private def addDistinct(acc: String, item: String): String = {
    if (item == null || item.isEmpty) acc
    else if (acc.isEmpty) item
    else if (acc.split(",").contains(item)) acc  // exact-token membership, not substring
    else acc + "," + item
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val merged = addDistinct(buffer.getString(0), input.getString(0))
    buffer.update(0, merged)
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Fold every token of the second partial buffer into the first.
    var acc = buffer1.getString(0)
    for (cityInfo <- buffer2.getString(0).split(",")) {
      acc = addDistinct(acc, cityInfo)
    }
    buffer1.update(0, acc)
  }

  override def evaluate(buffer: Row): Any = {
    buffer.getString(0)
  }
}