How to Build an Abstract Syntax Tree for Arithmetic Expressions with Custom Functions and Variables

Preface

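Lately I have been writing a series on design patterns, and the Visitor pattern was up next. The Visitor pattern is said to be used mainly in compiler-level components, so, to have a proper example for it, I decided to hand-write an abstract syntax tree first.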

Requirements

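  1. Given an input such as (2-1)*3+2+(3*(9-(5+2)*1)), the program computes the result automatically.
  2. The program supports custom functions, which can be nested, e.g. Max(8,Max(5,4))+plus100(max(3,9)), where max and plus100 are custom functions (function names are case-insensitive).
  3. The program supports custom variables whose values are passed in through a context, e.g. Max(8,Max(5,cc))+plus100(max(aa*2,bb)), where aa, bb and cc are custom variables.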

Thinking it through

BNF

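I started with the simplest case. Honestly, I had no idea how to approach it, so I searched online. The posts I found were either vague or simply told you to use an existing library, but the search did turn up a key term: BNF (Backus-Naur Form). Its formal definitions are, unsurprisingly, very formal. I preferred an answer on Zhihu to the question "What exactly is BNF (Backus-Naur Form)?". In one sentence: BNF describes a grammar as rules that refer to one another, so an expression can be parsed by recursively breaking it down into smaller parts.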
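For comparison, the sketch below codes a small BNF grammar directly as a recursive-descent parser: each rule becomes one method, and a rule that mentions another rule simply calls it. This is a different technique from the string-splitting tree built later in this post. The grammar, the class name RecursiveDescent, and the restriction to non-negative integer literals with integer division are assumptions made for the illustration.

```java
/**
 * Minimal recursive-descent evaluator for this (assumed) BNF grammar:
 *   expr   ::= term   { ("+" | "-") term }
 *   term   ::= factor { ("*" | "/") factor }
 *   factor ::= number | "(" expr ")"
 * Only non-negative integer literals are supported; "/" is integer division.
 */
public class RecursiveDescent {
    private final String src;
    private int pos;

    private RecursiveDescent(String src) {
        this.src = src.replace(" ", "");
    }

    public static long eval(String text) {
        RecursiveDescent p = new RecursiveDescent(text);
        long value = p.expr();
        if (p.pos != p.src.length()) {
            throw new IllegalArgumentException("unexpected character at " + p.pos);
        }
        return value;
    }

    // expr ::= term { ("+" | "-") term }; the loop keeps + and - left-associative
    private long expr() {
        long value = term();
        while (pos < src.length() && (peek() == '+' || peek() == '-')) {
            char op = src.charAt(pos++);
            long rhs = term();
            value = op == '+' ? value + rhs : value - rhs;
        }
        return value;
    }

    // term ::= factor { ("*" | "/") factor }
    private long term() {
        long value = factor();
        while (pos < src.length() && (peek() == '*' || peek() == '/')) {
            char op = src.charAt(pos++);
            long rhs = factor();
            value = op == '*' ? value * rhs : value / rhs;
        }
        return value;
    }

    // factor ::= number | "(" expr ")"
    private long factor() {
        if (pos < src.length() && peek() == '(') {
            pos++; // consume '('
            long value = expr();
            if (pos >= src.length() || peek() != ')') {
                throw new IllegalArgumentException("missing ')' at " + pos);
            }
            pos++; // consume ')'
            return value;
        }
        int start = pos;
        while (pos < src.length() && Character.isDigit(peek())) {
            pos++;
        }
        if (start == pos) {
            throw new IllegalArgumentException("number expected at " + pos);
        }
        return Long.parseLong(src.substring(start, pos));
    }

    private char peek() {
        return src.charAt(pos);
    }

    public static void main(String[] args) {
        System.out.println(eval("(2-1)*3+2+(3*(9-(5+2)*1))")); // 11
        System.out.println(eval("2-3-4-5"));                   // -10
    }
}
```

Because expr loops over term and term loops over factor, multiplication and division bind more tightly than addition and subtraction without any explicit priority table.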

Applying BNF

First, the evaluation rules for the four arithmetic operators:

  • Multiplication and division come before addition and subtraction.
  • An expression in parentheses is evaluated first.

Now back to the expression (2-1)*3+2+(3*(9-(5+2)*1)). How should it be split? The diagram below shows the splitting process.
[Figure: splitting the arithmetic expression]

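Why split from right to left instead of left to right? Consider 2-3-4-5. Splitting at the rightmost operator gives ((2-3)-4)-5 = -10, the correct result, while splitting at the leftmost one gives 2-(3-(4-5)) = -2. Subtraction and division are not associative, so the root must be the rightmost lowest-priority operator that is not inside brackets.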
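A quick way to see the difference is to build both trees for a chain of subtractions. The hypothetical class SplitDirection below handles only integers joined by '-', and splits either at the rightmost or at the leftmost operator.

```java
public class SplitDirection {

    // Builds the tree by splitting at the RIGHTMOST '-', i.e. ((2-3)-4)-5.
    static int splitRight(String text) {
        int i = text.lastIndexOf('-');
        if (i < 0) {
            return Integer.parseInt(text);
        }
        return splitRight(text.substring(0, i)) - Integer.parseInt(text.substring(i + 1));
    }

    // Builds the tree by splitting at the LEFTMOST '-', i.e. 2-(3-(4-5)).
    static int splitLeft(String text) {
        int i = text.indexOf('-');
        if (i < 0) {
            return Integer.parseInt(text);
        }
        return Integer.parseInt(text.substring(0, i)) - splitLeft(text.substring(i + 1));
    }

    public static void main(String[] args) {
        System.out.println(splitRight("2-3-4-5")); // -10, the correct result
        System.out.println(splitLeft("2-3-4-5"));  // -2, wrong: subtraction is not associative
    }
}
```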

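The diagram contains three kinds of nodes:

  • Light blue nodes are expression nodes (Expression), which can be split further.
  • Light purple nodes are operator nodes (Operator); an operator joins two nodes.
  • Green nodes are data nodes (DataNode), the leaves of the tree, which cannot be split further.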

Following the same approach, let's split Max(8,Max(5,4))+plus100(max(3,9)):
[Figure: splitting an expression with functions]
This split adds orange nodes: function nodes (FunctionNode). A function node can also be broken down further, but what it breaks down into are its arguments.
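Breaking a function node into its arguments uses the same bracket-depth counting as splitting at operators: only a comma at depth 0 separates two arguments. A minimal sketch, with a hypothetical splitArgs helper that receives the text between the call's outer parentheses:

```java
import java.util.ArrayList;
import java.util.List;

public class ArgSplitter {

    // Splits "8,Max(5,4)" into ["8", "Max(5,4)"]; commas inside nested brackets are ignored.
    static List<String> splitArgs(String inner) {
        List<String> args = new ArrayList<>();
        int depth = 0;
        int begin = 0;
        for (int i = 0; i < inner.length(); i++) {
            char c = inner.charAt(i);
            if (c == '(') {
                depth++;
            } else if (c == ')') {
                depth--;
            } else if (c == ',' && depth == 0) {
                args.add(inner.substring(begin, i).trim());
                begin = i + 1;
            }
        }
        if (!inner.trim().isEmpty()) {
            args.add(inner.substring(begin).trim()); // the last argument has no trailing comma
        }
        return args;
    }

    public static void main(String[] args) {
        System.out.println(splitArgs("8,Max(5,4)"));     // [8, Max(5,4)]
        System.out.println(splitArgs("aa*2,bb,f(1,2)")); // [aa*2, bb, f(1,2)]
    }
}
```

Scanning left to right keeps the arguments in their written order, which matters as soon as a function takes three or more arguments.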
Finally, let's split Max(8,Max(5,cc))+plus100(max(aa*2,bb)):
[Figure: splitting an expression with variables]
This time there are also blue nodes: variable nodes (StaticNode), whose values come from the context and which cannot be split further.

Recap

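Splitting the expressions above produced five kinds of nodes:

  1. Expression node (Expression): can be split further
  2. Operator node (Operator): joins two expressions
  3. Data node (DataNode): a leaf that cannot be split further
  4. Function node (FunctionNode): can be split further, into its arguments
  5. Variable node (StaticNode): a leaf whose value comes from the context passed in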
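Before the full implementation, a stripped-down sketch shows why the finished tree can evaluate itself recursively: each node only asks its children for their values. The names below (MiniTree, num, variable, op, call) are hypothetical and support only int values. There is no expression node, because in the article Expression only exists during parsing: its parse() replaces it with one of the other four kinds.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntBinaryOperator;
import java.util.function.ToIntFunction;

public class MiniTree {

    // Every node computes its own value from the variable context.
    interface Node {
        int getValue(Map<String, Integer> env);
    }

    static final ToIntFunction<int[]> MAX = v -> Math.max(v[0], v[1]);
    static final ToIntFunction<int[]> PLUS100 = v -> v[0] + 100;

    // Leaf: a numeric literal (DataNode in the article).
    static Node num(int value) {
        return env -> value;
    }

    // Leaf: a variable looked up by name (StaticNode in the article).
    static Node variable(String name) {
        return env -> {
            Integer value = env.get(name);
            if (value == null) {
                throw new IllegalArgumentException("unbound variable: " + name);
            }
            return value;
        };
    }

    // Inner node: an operator joining two subtrees (Operator in the article).
    static Node op(IntBinaryOperator f, Node left, Node right) {
        return env -> f.applyAsInt(left.getValue(env), right.getValue(env));
    }

    // Inner node: a function call whose children are its arguments (FunctionNode in the article).
    static Node call(ToIntFunction<int[]> f, Node... args) {
        return env -> {
            int[] values = new int[args.length];
            for (int i = 0; i < args.length; i++) {
                values[i] = args[i].getValue(env);
            }
            return f.applyAsInt(values);
        };
    }

    // Max(8,Max(5,cc))*2+plus100(max(aa*2,bb)), built by hand instead of parsed.
    static int evalExample(int aa, int bb, int cc) {
        Node tree = op(Integer::sum,
                op((a, b) -> a * b, call(MAX, num(8), call(MAX, num(5), variable("cc"))), num(2)),
                call(PLUS100, call(MAX, op((a, b) -> a * b, variable("aa"), num(2)), variable("bb"))));
        Map<String, Integer> env = new HashMap<>();
        env.put("aa", aa);
        env.put("bb", bb);
        env.put("cc", cc);
        return tree.getValue(env);
    }

    public static void main(String[] args) {
        System.out.println(evalExample(3, 9, 4)); // 125
    }
}
```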

Now let's implement it in code.

Implementation

The common Node interface

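/**
 * A node in the calculation tree
 */
public interface Node {

    /**
     * The source text of this node
     * @return the trimmed expression text
     */
    String getText();

    /**
     * The evaluated value of this node
     * @return the value, e.g. an Integer, Double or BigDecimal
     */
    Object getValue();

    /**
     * Parses this node
     * @return the node produced by parsing (an operator, data, variable or function node)
     */
    Node parse();
}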

Expression node (Expression)

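import java.util.Collection;
import java.util.List;
import java.util.Map;

/**
 * Expression node: raw text that parse() turns into an operator, data, variable or function node
 */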
public class Expression implements Node {
    Map<String, Object> env;
    Map<String, IFunction> functionMap;
    private String text;

    public Expression() {
    }

    @Override
    public Node parse() {
        Collection<List<Character>> sortList = Operator.getSortList();
        for (List<Character> characters : sortList) {
            Node operator = getNode(characters);
            if (operator != null) {
                return operator;
            }
        }
        if (text.indexOf('(') == -1 || isStartAndEndWidthBrackets(text)) {
            StaticNode staticNode = new StaticNode();
            staticNode.setText(text);
            staticNode.setEnv(env);
            return staticNode;
        }
        FunctionNode functionNode = new FunctionNode();
        functionNode.setFunctionMap(functionMap);
        functionNode.setEnv(env);
        functionNode.setText(text);
        return functionNode;
    }

    /**
     * Splits the text at the RIGHTMOST operator from opts that is not inside brackets,
     * so that chains such as 2-3-4-5 are evaluated left to right: ((2-3)-4)-5.
     * Returns a DataNode for a plain numeric literal, and null if no split is possible.
     */
    private Node getNode(List<Character> opts) {
        int length = text.length();
        int brackets = 0;
        boolean isDataNode = true;
        for (int i = length - 1; i >= 0; i--) {
            char charAt = text.charAt(i);
            // digits and the decimal point may form a numeric literal such as 3.5
            if (isDataNode && (charAt > '9' || charAt < '0') && charAt != '.') {
                isDataNode = false;
            }
            // scanning right to left, so ')' opens a bracket level and '(' closes it
            if (charAt == ')') {
                brackets++;
            }
            if (charAt == '(') {
                brackets--;
            }
            if (opts.contains(charAt) && brackets == 0) {
                Expression left = new Expression();
                left.setText(text.substring(0, i).trim());
                left.setFunctionMap(functionMap);
                left.setEnv(env);
                Operator operator = Operator.valueOf(charAt);
                operator.setLeft(left);
                Expression right = new Expression();
                right.setText(text.substring(i + 1).trim());
                right.setFunctionMap(functionMap);
                right.setEnv(env);
                operator.setRight(right);
                operator.parse();
                return operator;
            }
        }
        if (length > 0 && isDataNode) {
            DataNode expression = new DataNode();
            expression.setText(text);
            return expression;
        }
        return null;
    }

    @Override
    public String getText() {
        return text;
    }

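    public void setText(String text) {
        String trim = text.trim();
        // strip every redundant outer pair, so that ((1+2)) becomes 1+2
        while (isStartAndEndWidthBrackets(trim)) {
            trim = trim.substring(1, trim.length() - 1).trim();
        }
        this.text = trim;
    }

    private boolean isStartAndEndWidthBrackets(String text) {
        // guard against an empty operand, e.g. the missing left side of "-5"
        if (!text.isEmpty() && text.charAt(0) == '(') {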
            int length = text.length();
            int brackets = 0;
            for (int i = 0; i < length; i++) {
                char charAt = text.charAt(i);
                if (charAt == '(') {
                    brackets++;
                }
                if (charAt == ')') {
                    brackets--;
                    if (brackets == 0) {
                        return i == length - 1;
                    }
                }
            }
        }
        return false;
    }

    @Override
    public Object getValue() {
        throw new RuntimeException("not support!");
    }

    public void setEnv(Map<String, Object> env) {
        this.env = env;
    }

    public void setFunctionMap(Map<String, IFunction> functionMap) {
        this.functionMap = functionMap;
    }
}
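isStartAndEndWidthBrackets counts brackets instead of just looking at the first and last characters because an expression such as (1+2)*(3+4) starts with '(' and ends with ')' without being wrapped in a single pair. The standalone copy below (renamed isWrapped, a hypothetical name, with an empty-string guard) compares it with a naive check:

```java
public class BracketCheck {

    // True only if the '(' at index 0 is closed by the ')' at the very last index.
    static boolean isWrapped(String text) {
        if (text.isEmpty() || text.charAt(0) != '(') {
            return false;
        }
        int depth = 0;
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            if (c == '(') {
                depth++;
            } else if (c == ')') {
                depth--;
                if (depth == 0) {
                    // the first bracket closes here; it must be the last character
                    return i == text.length() - 1;
                }
            }
        }
        return false; // unbalanced: the opening bracket never closes
    }

    // Naive check for comparison: looks only at the two ends.
    static boolean naive(String text) {
        return text.startsWith("(") && text.endsWith(")");
    }

    public static void main(String[] args) {
        String s = "(1+2)*(3+4)";
        System.out.println(naive(s));     // true, but stripping would leave "1+2)*(3+4"
        System.out.println(isWrapped(s)); // false
        System.out.println(isWrapped("(3*(9-(5+2)*1))")); // true
    }
}
```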

Variable node (StaticNode)

/**
 * Variable node: its value is looked up by name in the context (env)
 */
public class StaticNode extends Expression{

    @Override
    public String getText() {
        return super.getText();
    }

    @Override
    public Object getValue() {
        if (env == null) {
            return null;
        }
        return env.get(getText());
    }
}

Data node (DataNode)

import java.math.BigDecimal;
import java.math.BigInteger;

/**
 * Numeric literal node
 */
public class DataNode extends Expression{

    @Override
    public String getText() {
        return super.getText();
    }

    @Override
    public Object getValue() {
        String text = getText();
        int i = text.indexOf('.');
        if (i == -1) {
            try {
                return Integer.parseInt(text);
            } catch (NumberFormatException e) {
                return new BigInteger(text);
            }
        }
        int precision = text.length() - i;
        if (precision <= 6) {
            try {
                return Float.parseFloat(text);
            } catch (NumberFormatException e) {
                return new BigDecimal(text);
            }
        }
        if (precision <= 15) {
            try {
                return Double.parseDouble(text);
            } catch (NumberFormatException e) {
                return new BigDecimal(text);
            }
        }
        return new BigDecimal(text);
    }
}
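DataNode picks a Java type for each literal: Integer, falling back to BigInteger on overflow, and for decimals Float, Double or BigDecimal depending on text.length() - text.indexOf('.'), which is the number of fractional digits plus one. Because the rule counts only fractional digits, a literal such as 12345.12345 becomes a Float, which keeps only about 7 significant digits. The standalone sketch below repeats the thresholds in a hypothetical classify method; it omits DataNode's Float/Double catch blocks, because parseFloat and parseDouble only throw on malformed text.

```java
import java.math.BigDecimal;
import java.math.BigInteger;

public class LiteralTypes {

    // Same thresholds as DataNode.getValue(), copied into a static helper.
    static Object classify(String text) {
        int i = text.indexOf('.');
        if (i == -1) {
            try {
                return Integer.parseInt(text);
            } catch (NumberFormatException e) {
                return new BigInteger(text); // too large for int
            }
        }
        int precision = text.length() - i; // fractional digits + 1 (the dot itself)
        if (precision <= 6) {
            return Float.parseFloat(text);
        }
        if (precision <= 15) {
            return Double.parseDouble(text);
        }
        return new BigDecimal(text);
    }

    public static void main(String[] args) {
        String[] samples = {"42", "12345678901", "1.5", "3.14159265", "0.1234567890123456", "12345.12345"};
        for (String s : samples) {
            Object v = classify(s);
            System.out.println(s + " -> " + v.getClass().getSimpleName() + " " + v);
        }
    }
}
```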

Function node (FunctionNode)

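import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

/**
 * Function node: a call such as Max(8,4), whose children are its arguments
 */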
public class FunctionNode extends Expression {

    private String funcName;
    private List<Node> params;

    @Override
    public String getText() {
        return super.getText();
    }

    @Override
    public void setText(String text) {
        super.setText(text);
        String txt = getText();
        int index = txt.indexOf('(');
        funcName = txt.substring(0, index).toUpperCase(Locale.ROOT);
        params = new ArrayList<>();
        String substring = txt.substring(index + 1, txt.length() - 1);
        int length = substring.length();
        int brackets = 0;
        int begin = 0;
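        // scan left to right: only a comma outside nested brackets separates two arguments,
        // which keeps the arguments in order and also works for three or more of them
        for (int i = 0; i < length; i++) {
            char charAt = substring.charAt(i);
            if (charAt == '(') {
                brackets++;
            }
            if (charAt == ')') {
                brackets--;
            }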
            if (charAt == ',') {
                if (brackets == 0) {
                    Expression expression = new Expression();
                    expression.setFunctionMap(functionMap);
                    expression.setText(substring.substring(begin, i));
                    expression.setEnv(env);
                    Node parse = expression.parse();
                    params.add(parse);
                    begin = i + 1;
                }
            }
        }
        if (begin != length) {
            Expression expression = new Expression();
            expression.setFunctionMap(functionMap);
            expression.setText(substring.substring(begin));
            expression.setEnv(env);
            Node parse = expression.parse();
            params.add(parse);
        }
    }

    @Override
    public Object getValue() {
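        IFunction iFunction = functionMap == null ? null : functionMap.get(funcName);
        if (iFunction == null) {
            throw new IllegalArgumentException("unknown function: " + funcName);
        }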
        int size = params.size();
        Object[] paramVals = new Object[size];
        for (int i = 0; i < size; i++) {
            paramVals[i] = params.get(i).getValue();
        }
        return iFunction.exec(paramVals, env);
    }
}

Operator node (Operator)

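For operators, ServiceLoader decouples the concrete addition, subtraction, multiplication and division classes from the abstract Operator, which makes it easy to add other operators later, such as modulo. As with any ServiceLoader service, each implementation must be listed by its fully qualified class name in a provider-configuration file under META-INF/services/ named after the fully qualified name of Operator, and each implementation needs a public no-argument constructor.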

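import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.TreeMap;

/**
 * Operator node: joins a left and a right subtree
 */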
public abstract class Operator implements Node {

    private static final Map<Character, Class<? extends Operator>> operatorMaps = new HashMap<>();
    private static final Collection<List<Character>> sortList;

    static {
        Map<Integer, List<Character>> operatorPriorityMap = new TreeMap<>();
        ServiceLoader<Operator> load = ServiceLoader.load(Operator.class);
        for (Operator operator : load) {
            operatorMaps.put(operator.operator(), operator.getClass());
            List<Character> list = operatorPriorityMap.computeIfAbsent(operator.priority(), k -> new ArrayList<>());
            list.add(operator.operator());
        }
        sortList = Collections.unmodifiableCollection(operatorPriorityMap.values());
    }

    private Node left;
    private Node right;
    private Node leftResult;
    private Node rightResult;

    public static Operator valueOf(char operator) {
        Class<? extends Operator> clazz = operatorMaps.get(operator);
        if (clazz == null) {
            throw new IllegalArgumentException("unsupported operator: " + operator);
        }
        try {
            // a fresh instance per use, because each Operator stores its own left/right children
            return clazz.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("cannot instantiate " + clazz.getName(), e);
        }
    }

    public static Collection<List<Character>> getSortList() {
        return sortList;
    }

    @Override
    public final Node parse() {
        leftResult = left.parse();
        rightResult = right.parse();
        return this;
    }

    final void setLeft(Node left) {
        this.left = left;
    }

    final void setRight(Node right) {
        this.right = right;
    }

    final Object getLeftResult() {
        return leftResult.getValue();
    }

    final Object getRightResult() {
        return rightResult.getValue();
    }

    @Override
    public String getText() {
        return String.valueOf(operator());
    }

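    /**
     * Priority: a larger value binds more tightly (+ and - use 0, * and / use 1)
     *
     * @return the priority of this operator
     */
    public abstract int priority();

    /**
     * The operator character
     *
     * @return e.g. '+'
     */
    public abstract char operator();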
}
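The static block groups the loaded operators by priority in a TreeMap, and TreeMap.values() iterates in ascending key order, so Expression.parse() tries the lowest-priority group (+ and -) before * and /. Splitting at the loosest-binding operator first pushes multiplication and division deeper into the tree, where they are evaluated first. A standalone sketch of just the grouping step, with the four operators and their priorities hard-coded (an assumption; the article loads them via ServiceLoader):

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

public class PriorityGroups {

    // Groups operator characters by priority; TreeMap keeps the keys sorted ascending.
    static Collection<List<Character>> group(Map<Character, Integer> priorities) {
        Map<Integer, List<Character>> byPriority = new TreeMap<>();
        for (Map.Entry<Character, Integer> e : priorities.entrySet()) {
            byPriority.computeIfAbsent(e.getValue(), k -> new ArrayList<>()).add(e.getKey());
        }
        return byPriority.values();
    }

    public static void main(String[] args) {
        Map<Character, Integer> ops = new LinkedHashMap<>();
        ops.put('*', 1);
        ops.put('+', 0);
        ops.put('/', 1);
        ops.put('-', 0);
        System.out.println(group(ops)); // [[+, -], [*, /]]
    }
}
```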

Addition node (Add)

import java.math.BigDecimal;
import java.math.BigInteger;

/**
 * Addition operator node
 */
public class Add extends Operator {

    @Override
    public char operator() {
        return '+';
    }

    @Override
    public Object getValue() {
        Object leftResult = getLeftResult();
        Object rightResult = getRightResult();
        if (leftResult instanceof Integer) {
            int left = (int) leftResult;
            if (rightResult instanceof Integer) {
                return left + (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left + (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left + (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                return ((BigInteger) rightResult).add(BigInteger.valueOf(left));
            }
            if (rightResult instanceof BigDecimal) {
                return ((BigDecimal) rightResult).add(BigDecimal.valueOf(left));
            }
        }
        if (leftResult instanceof Double) {
            double left = (Double) leftResult;
            if (rightResult instanceof Integer) {
                return left + (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left + (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left + (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return ((BigDecimal) rightResult).add(new BigDecimal(String.valueOf(leftResult)));
            }
        }
        if (leftResult instanceof Float) {
            float left = (float) leftResult;
            if (rightResult instanceof Integer) {
                return left + (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left + (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left + (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return ((BigDecimal) rightResult).add(new BigDecimal(String.valueOf(leftResult)));
            }
        }
        if (leftResult instanceof BigInteger) {
            if (rightResult instanceof Integer) {
                return ((BigInteger) leftResult).add(BigInteger.valueOf((Integer) rightResult));
            }
            if (rightResult instanceof Double) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof Float) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigInteger) {
                return ((BigInteger) rightResult).add((BigInteger) leftResult);
            }
            if (rightResult instanceof BigDecimal) {
                throw new RuntimeException("not support!");
            }
        }
        if (leftResult instanceof BigDecimal) {
            BigDecimal left = (BigDecimal) leftResult;
            if (rightResult instanceof Integer || rightResult instanceof Double || rightResult instanceof Float) {
                return left.add(new BigDecimal(String.valueOf(rightResult)));
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return ((BigDecimal) rightResult).add(left);
            }
        }
        throw new RuntimeException("not support!");
    }

    @Override
    public int priority() {
        return 0;
    }
}

Subtraction node (Minus)

import java.math.BigDecimal;
import java.math.BigInteger;

/**
 * Subtraction operator
 */
public class Minus extends Operator {

    @Override
    public char operator() {
        return '-';
    }

    @Override
    public Object getValue() {
        Object leftResult = getLeftResult();
        Object rightResult = getRightResult();
        if (leftResult instanceof Integer) {
            int left = (int) leftResult;
            if (rightResult instanceof Integer) {
                return left - (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left - (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left - (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                return BigInteger.valueOf(left).subtract(((BigInteger) rightResult));
            }
            if (rightResult instanceof BigDecimal) {
                return BigDecimal.valueOf(left).subtract(((BigDecimal) rightResult));
            }
        }
        if (leftResult instanceof Double) {
            double left = (Double) leftResult;
            if (rightResult instanceof Integer) {
                return left - (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left - (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left - (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return new BigDecimal(String.valueOf(leftResult)).subtract(((BigDecimal) rightResult));
            }
        }
        if (leftResult instanceof Float) {
            float left = (float) leftResult;
            if (rightResult instanceof Integer) {
                return left - (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left - (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left - (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return new BigDecimal(String.valueOf(leftResult)).subtract(((BigDecimal) rightResult));
            }
        }
        if (leftResult instanceof BigInteger) {
            if (rightResult instanceof Integer) {
                return ((BigInteger) leftResult).subtract(BigInteger.valueOf((Integer) rightResult));
            }
            if (rightResult instanceof Double || rightResult instanceof Float || rightResult instanceof BigDecimal) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigInteger) {
                return ((BigInteger) leftResult).subtract(((BigInteger) rightResult));
            }
        }
        if (leftResult instanceof BigDecimal) {
            BigDecimal left = (BigDecimal) leftResult;
            if (rightResult instanceof Integer || rightResult instanceof Double || rightResult instanceof Float) {
                return left.subtract(new BigDecimal(String.valueOf(rightResult)));
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return left.subtract(((BigDecimal) rightResult));
            }
        }
        throw new RuntimeException("not support!");
    }

    @Override
    public int priority() {
        return 0;
    }
}

Multiplication node (Multiply)

import java.math.BigDecimal;
import java.math.BigInteger;

/**
 * Multiplication operator
 */
public class Multiply extends Operator {

    @Override
    public char operator() {
        return '*';
    }

    @Override
    public Object getValue() {
        Object leftResult = getLeftResult();
        Object rightResult = getRightResult();
        if (leftResult instanceof Integer) {
            int left = (int) leftResult;
            if (rightResult instanceof Integer) {
                return left * (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left * (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left * (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                return ((BigInteger) rightResult).multiply(BigInteger.valueOf(left));
            }
            if (rightResult instanceof BigDecimal) {
                return ((BigDecimal) rightResult).multiply(BigDecimal.valueOf(left));
            }
        }
        if (leftResult instanceof Double) {
            double left = (Double) leftResult;
            if (rightResult instanceof Integer) {
                return left * (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left * (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left * (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return ((BigDecimal) rightResult).multiply(new BigDecimal(String.valueOf(leftResult)));
            }
        }
        if (leftResult instanceof Float) {
            float left = (float) leftResult;
            if (rightResult instanceof Integer) {
                return left * (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left * (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left * (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return ((BigDecimal) rightResult).multiply(new BigDecimal(String.valueOf(leftResult)));
            }
        }
        if (leftResult instanceof BigInteger) {
            if (rightResult instanceof Integer) {
                return ((BigInteger) leftResult).multiply(BigInteger.valueOf((Integer) rightResult));
            }
            if (rightResult instanceof Double || rightResult instanceof Float || rightResult instanceof BigDecimal) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigInteger) {
                return ((BigInteger) rightResult).multiply((BigInteger) leftResult);
            }
        }
        if (leftResult instanceof BigDecimal) {
            BigDecimal left = (BigDecimal) leftResult;
            if (rightResult instanceof Integer || rightResult instanceof Double || rightResult instanceof Float) {
                return left.multiply(new BigDecimal(String.valueOf(rightResult)));
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return ((BigDecimal) rightResult).multiply(left);
            }
        }
        throw new RuntimeException("not support!");
    }

    @Override
    public int priority() {
        return 1;
    }
}

Division node (Divide)

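import java.math.BigDecimal;
import java.math.BigInteger;

/**
 * Division operator. Integer / Integer is integer division and truncates toward zero, so 7/2 gives 3.
 */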
public class Divide extends Operator {

    @Override
    public char operator() {
        return '/';
    }

    @Override
    public Object getValue() {
        Object leftResult = getLeftResult();
        Object rightResult = getRightResult();
        if (leftResult instanceof Integer) {
            int left = (int) leftResult;
            if (rightResult instanceof Integer) {
                return left / (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left / (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left / (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                return (new BigInteger(String.valueOf(leftResult))).divide(((BigInteger) rightResult));
            }
            if (rightResult instanceof BigDecimal) {
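                // without a MathContext, BigDecimal.divide throws for non-terminating results such as 1/3
                return (new BigDecimal(String.valueOf(leftResult))).divide((BigDecimal) rightResult, java.math.MathContext.DECIMAL128);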
            }
        }
        if (leftResult instanceof Double) {
            double left = (Double) leftResult;
            if (rightResult instanceof Integer) {
                return left / (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left / (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left / (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return new BigDecimal(String.valueOf(leftResult)).divide((BigDecimal) rightResult, java.math.MathContext.DECIMAL128);
            }
        }
        if (leftResult instanceof Float) {
            float left = (float) leftResult;
            if (rightResult instanceof Integer) {
                return left / (Integer) rightResult;
            }
            if (rightResult instanceof Double) {
                return left / (Double) rightResult;
            }
            if (rightResult instanceof Float) {
                return left / (Float) rightResult;
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return (new BigDecimal(String.valueOf(leftResult))).divide((BigDecimal) rightResult, java.math.MathContext.DECIMAL128);
            }
        }
        if (leftResult instanceof BigInteger) {
            if (rightResult instanceof Integer) {
                return ((BigInteger) leftResult).divide(BigInteger.valueOf((Integer) rightResult));
            }
            if (rightResult instanceof Double || rightResult instanceof Float || rightResult instanceof BigDecimal) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigInteger) {
                return ((BigInteger) leftResult).divide((BigInteger) rightResult);
            }
        }
        if (leftResult instanceof BigDecimal) {
            BigDecimal left = (BigDecimal) leftResult;
            if (rightResult instanceof Integer || rightResult instanceof Double || rightResult instanceof Float) {
                return left.divide(new BigDecimal(String.valueOf(rightResult)), java.math.MathContext.DECIMAL128);
            }
            if (rightResult instanceof BigInteger) {
                throw new RuntimeException("not support!");
            }
            if (rightResult instanceof BigDecimal) {
                return left.divide((BigDecimal) rightResult, java.math.MathContext.DECIMAL128);
            }
        }
        throw new RuntimeException("not support!");
    }

    @Override
    public int priority() {
        return 1;
    }
}
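The four operator classes repeat the same five-by-five matrix of instanceof checks, so every new numeric type or operator multiplies the code. One alternative, sketched here, is to promote both operands to a common type first and then apply the operator once. The class NumericOps and its promotion order are hypothetical; in particular this sketch widens every Float to double, unlike the article's classes, which keep float arithmetic for Float operands.

```java
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;

public class NumericOps {

    // Promotion order assumed by this sketch:
    //   int op int -> int; any BigInteger mixed with an integer type -> BigInteger;
    //   any Float/Double (without big types) -> double;
    //   any BigDecimal, or BigInteger mixed with Float/Double -> BigDecimal.
    static Object apply(char op, Number a, Number b) {
        if (a instanceof BigDecimal || b instanceof BigDecimal
                || (isBigInt(a) && isFloating(b)) || (isFloating(a) && isBigInt(b))) {
            BigDecimal x = toBigDecimal(a);
            BigDecimal y = toBigDecimal(b);
            switch (op) {
                case '+': return x.add(y);
                case '-': return x.subtract(y);
                case '*': return x.multiply(y);
                // DECIMAL128 rounds non-terminating results such as 1/3 instead of throwing
                case '/': return x.divide(y, MathContext.DECIMAL128);
                default: throw new IllegalArgumentException("unknown operator " + op);
            }
        }
        if (isFloating(a) || isFloating(b)) {
            double x = a.doubleValue();
            double y = b.doubleValue();
            switch (op) {
                case '+': return x + y;
                case '-': return x - y;
                case '*': return x * y;
                case '/': return x / y;
                default: throw new IllegalArgumentException("unknown operator " + op);
            }
        }
        if (isBigInt(a) || isBigInt(b)) {
            BigInteger x = toBigInteger(a);
            BigInteger y = toBigInteger(b);
            switch (op) {
                case '+': return x.add(y);
                case '-': return x.subtract(y);
                case '*': return x.multiply(y);
                case '/': return x.divide(y);
                default: throw new IllegalArgumentException("unknown operator " + op);
            }
        }
        int x = a.intValue();
        int y = b.intValue();
        switch (op) {
            case '+': return x + y;
            case '-': return x - y;
            case '*': return x * y;
            case '/': return x / y; // integer division truncates, like the article's Divide
            default: throw new IllegalArgumentException("unknown operator " + op);
        }
    }

    private static boolean isFloating(Number n) {
        return n instanceof Float || n instanceof Double;
    }

    private static boolean isBigInt(Number n) {
        return n instanceof BigInteger;
    }

    private static BigDecimal toBigDecimal(Number n) {
        if (n instanceof BigDecimal) {
            return (BigDecimal) n;
        }
        if (n instanceof BigInteger) {
            return new BigDecimal((BigInteger) n);
        }
        return new BigDecimal(n.toString()); // Integer, Float or Double via its decimal text
    }

    private static BigInteger toBigInteger(Number n) {
        return n instanceof BigInteger ? (BigInteger) n : BigInteger.valueOf(n.longValue());
    }

    public static void main(String[] args) {
        System.out.println(apply('+', 1, 2));
        System.out.println(apply('+', new BigInteger("12345678901"), 1));
        System.out.println(apply('*', 1.5f, 2));
        System.out.println(apply('/', new BigDecimal("1"), 3));
    }
}
```

Adding a new operator then means one new case per branch instead of a new 25-cell matrix; the trade-off is that the promotion rules live in one place and must be chosen deliberately.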

Custom function extensions

Function interface (IFunction)

import java.util.List;
import java.util.Map;

/**
 * Interface for custom functions
 */
public interface IFunction {

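    /**
     * Executes the function
     * @param params the already-evaluated argument values
     * @param env the variable context (may be null)
     * @return the function result
     */
    Object exec(Object[] params, Map<String, Object> env);

    /**
     * Function name; it is matched case-insensitively
     * @return the name
     */
    String getName();

    /**
     * Declared parameters
     * @return the parameter descriptions
     */
    List<FunctionParam> getParams();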
}

Function parameter description (FunctionParam)

public class FunctionParam {
    private String name;
    private Class<?> type;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public Class<?> getType() {
        return type;
    }

    public void setType(Class<?> type) {
        this.type = type;
    }
}

Custom function (Max)

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Returns the larger of two integers
 */
public class Max implements IFunction{

    @Override
    public Object exec(Object[] params, Map<String, Object> env) {
        return Math.max((Integer) params[0], (Integer) params[1]);
    }

    @Override
    public String getName() {
        return this.getClass().getSimpleName();
    }

    @Override
    public List<FunctionParam> getParams() {
        List<FunctionParam> params = new ArrayList<>();
        FunctionParam param1 = new FunctionParam();
        param1.setName("left");
        param1.setType(Integer.class);
        FunctionParam param2 = new FunctionParam();
        param2.setName("right");
        param2.setType(Integer.class);
        params.add(param1);
        params.add(param2);
        return params;
    }
}

Custom function (Plus100)

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Adds 100 to an integer
 */
public class Plus100 implements IFunction{

    @Override
    public Object exec(Object[] params, Map<String, Object> env) {
        Object param = params[0];
        return (Integer) param + 100;
    }

    @Override
    public String getName() {
        return this.getClass().getSimpleName();
    }

    @Override
    public List<FunctionParam> getParams() {
        List<FunctionParam> params = new ArrayList<>();
        FunctionParam param1 = new FunctionParam();
        param1.setName("p1");
        param1.setType(Integer.class);
        params.add(param1);
        return params;
    }
}
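IFunction.getParams() is declared, but nothing in the calculator calls it. One possible use, sketched below, is to check the argument count and types before exec() runs, so that a call like Max(1) fails with a clear message instead of an ArrayIndexOutOfBoundsException. To stay self-contained, the block declares trimmed-down nested copies of IFunction, FunctionParam and Max; checkedExec is a hypothetical helper, not part of the article's code.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class ArgCheck {

    // Trimmed-down copies of the article's types, just enough for the sketch.
    static class FunctionParam {
        final String name;
        final Class<?> type;

        FunctionParam(String name, Class<?> type) {
            this.name = name;
            this.type = type;
        }
    }

    interface IFunction {
        Object exec(Object[] params, Map<String, Object> env);

        String getName();

        List<FunctionParam> getParams();
    }

    static class Max implements IFunction {
        @Override
        public Object exec(Object[] params, Map<String, Object> env) {
            return Math.max((Integer) params[0], (Integer) params[1]);
        }

        @Override
        public String getName() {
            return "Max";
        }

        @Override
        public List<FunctionParam> getParams() {
            return Arrays.asList(new FunctionParam("left", Integer.class), new FunctionParam("right", Integer.class));
        }
    }

    // Validates the evaluated arguments against getParams(), then calls exec().
    static Object checkedExec(IFunction f, Object[] args, Map<String, Object> env) {
        List<FunctionParam> declared = f.getParams();
        if (args.length != declared.size()) {
            throw new IllegalArgumentException(f.getName() + " expects " + declared.size()
                    + " argument(s) but got " + args.length);
        }
        for (int i = 0; i < args.length; i++) {
            FunctionParam p = declared.get(i);
            if (!p.type.isInstance(args[i])) {
                throw new IllegalArgumentException(f.getName() + ": parameter '" + p.name
                        + "' must be " + p.type.getSimpleName() + " but got " + args[i]);
            }
        }
        return f.exec(args, env);
    }

    public static void main(String[] args) {
        IFunction max = new Max();
        System.out.println(checkedExec(max, new Object[]{8, 5}, null)); // 8
        try {
            checkedExec(max, new Object[]{8}, null);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // Max expects 2 argument(s) but got 1
        }
    }
}
```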

Entry point

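Functions are collected with ServiceLoader as well, so every IFunction implementation, such as Max and Plus100, must be listed in a provider-configuration file under META-INF/services/ named after the fully qualified name of IFunction.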

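import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.ServiceLoader;

/**
 * Calculator
 * @author skyline
 */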
public class Calculation {

    private static final Map<String, IFunction> FUNCTION_MAP = new HashMap<>();

    static {
        ServiceLoader<IFunction> load = ServiceLoader.load(IFunction.class);
        for (IFunction iFunction : load) {
            FUNCTION_MAP.put(iFunction.getName().toUpperCase(Locale.ROOT), iFunction);
        }
    }

    public static Object exec(String exp) {
        return exec(exp, null);
    }

    public static Object exec(String exp, Map<String, Object> env) {
        Expression root = new Expression();
        root.setText(exp);
        root.setFunctionMap(FUNCTION_MAP);
        root.setEnv(env);
        Node parse = root.parse();
        return parse.getValue();
    }
}

Testing

Test cases

Let's run the test cases:

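import java.util.HashMap;
import java.util.Map;

// JUnit 4 and SLF4J are assumed from the @Test annotation and the logger calls
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.junit.Assert.assertEquals;

/**
 * @author skyline
 * @date 2022/1/18
 **/
public class CalculationTest {

    private static final Logger logger = LoggerFactory.getLogger(CalculationTest.class);

    @Test
    public void test1() {
        String exp = "(2-1)*3+2+(3*(9-(5+2)*1))";
        // 1*3 + 2 + 3*(9-7) = 3 + 2 + 6
        assertEquals(Integer.valueOf(11), exe(exp, null));
    }

    @Test
    public void test2() {
        String exp = "Max(8,Max(5,4))+plus100(max(3,9))";
        // 8 + (9 + 100)
        assertEquals(Integer.valueOf(117), exe(exp, null));
    }

    @Test
    public void test3() {
        String exp = "Max(8,Max(5,cc))*2+plus100(max(aa*2,bb))";
        Map<String, Object> env = new HashMap<>();
        env.put("cc", 4);
        env.put("aa", 3);
        env.put("bb", 9);
        // 8*2 + (max(6, 9) + 100) = 16 + 109
        assertEquals(Integer.valueOf(125), exe(exp, env));
    }

    private Object exe(String exp, Map<String, Object> env) {
        logger.info("exp is {}", exp);
        if (env != null) {
            logger.info("env is {}", env);
        }
        Object exec = Calculation.exec(exp, env);
        logger.info("my answer is {}", exec);
        return exec;
    }
}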

Test results

[Screenshot: output of the test run]

Summary

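I always thought this kind of thing was magical. I wanted to try it but didn't know where to start. Having finally had the chance, I found it is really just recursion over a tree and not nearly as complicated as I expected. Many things only make sense once you do them yourself; as the old saying goes, what you learn from books is shallow, and real understanding comes from practice.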
