博主前端时间一个在看Spark-core的源码。最近因为项目上的事,赶鸭子上架被叫过去改bug了。因为项目组对spark源码做了很多修改,这次在迁移到spark2.3时,需要把之前修改的功能也一并迁移过来。遇到的第一个问题就是实现一个IsNumeric 。这是其他RDB中sql支持的函数,原生spark-sql暂不支持。
接下来我们看一下实现流程:
1. 定义函数表达式
先找到arithmetic.scala
在里面依葫芦画瓢添加如下的表达式定义:
/**
* if the input is numeric ,return 1, else return 0.
*/
@ExpressionDescription(
usage = "_FUNC_(input) - if the input is numeric ,return 1, else return 0.",
extended =
"""
Examples:
> SELECT isnumeric('1e2') result;
1
""")
case class IsNumeric(inputExpr: Expression)
extends Expression {
override def children: Seq[Expression] = Seq(inputExpr)
override def dataType: DataType = IntegerType
override def nullable: Boolean = true
def eval(input: InternalRow): Any = {
inputExpr.dataType match {
case DecimalType() | DoubleType | FloatType |
LongType | IntegerType | ShortType => return 1
case StringType => try {
val value = inputExpr.eval(input).toString.toDouble
1
} catch {
case _ => 0
}
case _ => 0
}
}
override protected def doGenCode(ctx: CodegenContext,
ev: ExprCode): ExprCode = {
val eval1 = inputExpr.genCode(ctx)
val other = inputExpr.dataType match {
case DecimalType() | DoubleType | FloatType |
LongType | IntegerType | ShortType =>
s"""${ev.value}= 1;"""
case StringType =>
s"""try {
java.lang.Double.parseDouble(${eval1.value}.toString());
${ev.value}= 1;
} catch (java.lang.Exception pe) {
${ev.value}= 0;
}"""
case _ => s"""${ev.value}= 0;"""
}
ev.copy(code = eval1.code +
s"""boolean ${ev.isNull} = ${eval1.isNull};
${ctx.javaType(IntegerType)} ${ev.value} = ${ctx.defaultValue(IntegerType)};
${other};""")
}
}
2. 函数方法的实现
实现非常非常简单,实现compute方法,使用预编译的正则表达式做一个匹配,返回结果即可
// 这里是项目自己添加的包
package org.apache.hive.tsql.udf;
import org.apache.hive.tsql.arg.Var;
import java.util.List;
import java.util.regex.Pattern;
public class IsNumericCalculator extends BaseCalculator {
public static Pattern pattern = Pattern.compile("^(-?\\d+)(\\.\\d+)?([Ee]?[+-]?\\d+)$");
@Override
public Var compute() throws Exception {
List<Var> argList = getAllArguments();
if (pattern.matcher(argList.get(0).toString()).matches()) {
return new Var("1", Var.DataType.INT);
}
return new Var("0", Var.DataType.INT);
}
}
3. 函数注册
在UDFFactory中注册自定义函数
registFunction("isnumeric".toUpperCase(), "org.apache.hive.tsql.udf.IsNumericCalculator");