1.创建一个测试用例的数据源(每隔12秒生成一个随机数)
import org.apache.flink.streaming.api.functions.source.{RichSourceFunction, SourceFunction}
class udafSource extends RichSourceFunction[Double] {
override def run(ctx: SourceFunction.SourceContext[Double]) = {
while (true) {
val d = scala.math.random
ctx.collect(d)
// 测试产生的每条数据以日志的格式打印出
val logger = Logger(this.getClass)
logger.error(s"当前值:$d")
Thread.sleep(12000)
}
}
override def cancel() = ???
}
自定义聚合函数,实现中位数计算
自定义聚合函数需要继承 AggregateFunction 类,继承AggregateFunction需要指定两个泛型:一个是返回值类型,一个是聚合过程中的中间结构类型;里面需要实现很多方法,有的方法是必须实现的,而有的方法可以根据自己的需要选择性的实现,我们可以从下图看到AggregateFunction的大概情况
import org.apache.flink.table.functions.AggregateFunction
import com.typesafe.scalalogging.Logger
import scala.collection.mutable.ListBuffer
class MedianUdaf extends AggregateFunction[Double, ListBuffer[Double]] {
/*
* 具有初始值的累加器
* 初始化AggregateFunction的accumulator。
* 系统在第一个做aggregate计算之前调用一次这个方法。
*/
override def createAccumulator(): ListBuffer[Double] = new ListBuffer[Double]()
/*
* 系统在每次aggregate计算完成后调用这个方法。
*/
override def getValue(accumulator: ListBuffer[Double]) = {
val length = accumulator.size
val med = (length / 2)
val seq = accumulator.sorted
try {
length % 2 match {
case 0 => (seq(med) + seq(med - 1)) / 2
case 1 => seq(med)
}
} catch {
case e: Exception => seq.head
}
}
/*
* UDAF必须包含1个accumulate方法。
* 您需要实现一个accumulate方法,来描述如何计算用户的输入的数据,并更新到accumulator中。
* accumulate方法的第一个参数必须是使用AggregateFunction的ACC类型的accumulator。
* 在系统运行过程中,底层runtime代码会把历史状态accumulator,
* 和指定的上游数据(支持任意数量,任意类型的数据)做为参数,一起发送给accumulate计算。
*/
def accumulate(accumulator: ListBuffer[Double], i: Double) = {
accumulator.append(i)
}
/*
* 使用merge方法把多个accumulator合为1个accumulator
* merge方法的第1个参数,必须是使用AggregateFunction的ACC类型的accumulator,而且第1个accumulator是merge方法完成之后,状态所存放的地方。
* merge方法的第2个参数是1个ACC type的accumulator遍历迭代器,里面有可能存在1个或者多个accumulator。
*/
def merge(accumulator: ListBuffer[Double], its: Iterable[ListBuffer[Double]]) = {
its.foreach(i => accumulator ++ i)
}
// 返回结果的类型(一般情况下可以不用自己实现,但是在涉及到更复杂的类型是可能会用到)
// override def getResultType = createTypeInformation[Double]
// 返回中间结果的类型
// override def getAccumulatorType = createTypeInformation[ListBuffer[Double]]
}
Flink Job
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.table.api.{Table, TableEnvironment}
import org.apache.flink.streaming.api.scala._
import org.apache.flink.table.api.scala._
object udaftest {
def main(args: Array[String]): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
val tableEnv = TableEnvironment.getTableEnvironment(env)
val ds = env.addSource(new udafSource)
//注册自定义函数
tableEnv.registerFunction("median", new MedianUdaf())
tableEnv.registerDataStream("tb_num", ds, 'num, 'proctime.proctime)
// 每一分钟聚合一次结果
val query: Table = tableEnv.sqlQuery(
"""
|SELECT
|median(num)
|FROM tb_num
|GROUP BY TUMBLE(proctime, INTERVAL '1' MINUTE)
""".stripMargin)
tableEnv.toAppendStream[Double](query).print()
env.execute(s"${this.getClass.getSimpleName}")
}
}
// 结果查看
ERROR [Source: Custom Source (1/1)] - 当前值:0.7732566761814615
ERROR [Source: Custom Source (1/1)] - 当前值:0.8715919354702197
ERROR [Source: Custom Source (1/1)] - 当前值:0.9973721871677008
ERROR [Source: Custom Source (1/1)] - 当前值:0.6653001489953874
3> 0.8224243058258407
ERROR [Source: Custom Source (1/1)] - 当前值:0.7871209641617365
ERROR [Source: Custom Source (1/1)] - 当前值:0.2327299813915178
ERROR [Source: Custom Source (1/1)] - 当前值:0.7257275254509521
ERROR [Source: Custom Source (1/1)] - 当前值:0.34564727587194566
ERROR [Source: Custom Source (1/1)] - 当前值:0.7117726278328883
4> 0.7117726278328883
参看资料:
https://help.aliyun.com/document_detail/69553.html?spm=a2c4g.11186623.6.661.4ff02ec0ilrdv9
https://ci.apache.org/projects/flink/flink-docs-release-1.8/dev/table/udfs.html