diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionFolding.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionFolding.java index 2bdd96c5..482a9948 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionFolding.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/ExpressionFolding.java @@ -1,6 +1,7 @@ package liquidjava.rj_language.opt; import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Enum; import liquidjava.rj_language.ast.Expression; import liquidjava.rj_language.ast.GroupExpression; import liquidjava.rj_language.ast.Ite; @@ -62,6 +63,16 @@ private static ValDerivationNode foldBinary(ValDerivationNode node) { Expression left = leftNode.getValue(); Expression right = rightNode.getValue(); String op = binExp.getOperator(); + + if (left instanceof Enum en && en.getResolvedLiteral() != null) { + left = en.getResolvedLiteral().clone(); + leftNode = new ValDerivationNode(left, leftNode); + } + if (right instanceof Enum en && en.getResolvedLiteral() != null) { + right = en.getResolvedLiteral().clone(); + rightNode = new ValDerivationNode(right, rightNode); + } + binExp.setChild(0, left); binExp.setChild(1, right); @@ -146,6 +157,18 @@ else if (left instanceof LiteralBoolean && right instanceof LiteralBoolean) { return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op)); } + else if (left instanceof Enum leftEnum && right instanceof Enum rightEnum + && leftEnum.getTypeName().equals(rightEnum.getTypeName())) { + boolean equal = leftEnum.getConstName().equals(rightEnum.getConstName()); + Expression res = switch (op) { + case "==" -> new LiteralBoolean(equal); + case "!=" -> new LiteralBoolean(!equal); + default -> null; + }; + if (res != null) + return new ValDerivationNode(res, new BinaryDerivationNode(leftNode, rightNode, op)); + } + ValDerivationNode adjacentConstants = foldAdjacentIntegerConstants(leftNode, rightNode, op); if (adjacentConstants != null) return adjacentConstants; diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java index da1be65d..48b37e03 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java @@ -1,6 +1,7 @@ package liquidjava.rj_language.opt; import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Enum; import liquidjava.rj_language.ast.Expression; import liquidjava.rj_language.ast.FunctionInvocation; import liquidjava.rj_language.ast.UnaryExpression; @@ -28,7 +29,7 @@ public static ValDerivationNode propagate(Expression exp, ValDerivationNode prev Map expressionSubstitutions = new HashMap<>(); // var == expression for (Map.Entry entry : substitutions.entrySet()) { Expression value = entry.getValue(); - if (value.isLiteral() || value instanceof Var) { + if (value.isLiteral() || value instanceof Var || value instanceof Enum) { directSubstitutions.put(entry.getKey(), value); } else { expressionSubstitutions.put(entry.getKey(), value); diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java index 92ff8d27..f19f2f8a 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariableResolver.java @@ -6,6 +6,7 @@ import java.util.Set; import liquidjava.rj_language.ast.BinaryExpression; +import liquidjava.rj_language.ast.Enum; import liquidjava.rj_language.ast.Expression; import liquidjava.rj_language.ast.FunctionInvocation; import liquidjava.rj_language.ast.Var; @@ -56,9 +57,9 @@ private static void resolveRecursive(Expression exp, Map map String leftKey = substitutionKey(left); String rightKey = substitutionKey(right); - if (leftKey != null && right.isLiteral()) { + if (leftKey != null && isConstant(right)) { map.put(leftKey, right.clone()); - } else if (rightKey != null && left.isLiteral()) { + } else if (rightKey != null && isConstant(left)) { map.put(rightKey, left.clone()); } else if (left instanceof Var leftVar && right instanceof Var rightVar) { // to substitute internal variable with user-facing variable @@ -144,15 +145,15 @@ private static boolean hasUsage(Expression exp, String name, Expression value) { Expression left = binary.getFirstOperand(); Expression right = binary.getSecondOperand(); if (left instanceof Var v && v.getName().equals(name) && right.equals(value) - && (right.isLiteral() || (!(right instanceof Var) && canSubstitute(v, right)))) + && (isConstant(right) || (!(right instanceof Var) && canSubstitute(v, right)))) return false; if (left instanceof FunctionInvocation && left.toString().equals(name) && right.equals(value) - && (right.isLiteral() || (!(right instanceof Var) && !containsExpression(right, left)))) + && (isConstant(right) || (!(right instanceof Var) && !containsExpression(right, left)))) return false; - if (right instanceof Var v && v.getName().equals(name) && left.equals(value) && left.isLiteral()) + if (right instanceof Var v && v.getName().equals(name) && left.equals(value) && isConstant(left)) return false; if (right instanceof FunctionInvocation && right.toString().equals(name) && left.equals(value) - && left.isLiteral()) + && isConstant(left)) return false; } @@ -198,6 +199,10 @@ private static boolean canSubstitute(Var var, Expression value) { return !isReturnVar(var) && !isFreshVar(var) && !containsVariable(value, var.getName()); } + private static boolean isConstant(Expression exp) { + return exp.isLiteral() || exp instanceof Enum; + } + private static boolean containsVariable(Expression exp, String name) { if (exp instanceof Var var) return var.getName().equals(name); diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java index dc71c5df..f243835f 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java @@ -626,4 +626,36 @@ void testFunctionInvocationEqualitiesMixWithVariables() { assertEquals("3", result.getValue().toString()); } + + @Test + void testEnumConstantsPropagateIntoVariableEquality() { + Expression expression = parse("current == mode && mode == Mode.Photo"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + + assertEquals("current == Mode.Photo", result.getValue().toString()); + } + + @Test + void testEnumConstantsPropagateTransitively() { + Expression expression = parse("target == current && current == mode && mode == Mode.Photo"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + + assertEquals("target == Mode.Photo", result.getValue().toString()); + } + + @Test + void testEnumConstantsPropagateThroughFunctionInvocations() { + Expression expression = parse("modeOf(x) == Mode.Photo && current == modeOf(x)"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + + assertEquals("current == Mode.Photo", result.getValue().toString()); + } + + @Test + void testEnumConstantsPropagateIntoTernaryCondition() { + Expression expression = parse("mode == Mode.Photo && (mode == Mode.Video ? explicit(param) : start(param))"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + + assertEquals("start(param)", result.getValue().toString()); + } }