From c6755e57c3382dfe2a3d4d7c07101b983ebee5be Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Tue, 15 Jul 2025 16:25:41 -0500 Subject: [PATCH 01/23] Add BDD-based rules engine trait This commit updates the smithy-rules-engine package to support binary decision diagrams (BDD) to more efficiently resolve endpoints. We create the BDD by converting the decision tree into a control flow graph (CFG), then compile the CFG to a BDD. The CFG canonicalizes conditions for better sharing (e.g., sorts commutative functions, expands simple string templates, etc), and strips all conditions from results and hash-conses them as well. Later, we'll migrate to emitting the BDD directly in order to shave off many conditions and results that can be simplified. Our decision-tree based rules engine requires deep branching logic to find results. When evaluating the path to a result based on given input, decision trees require descending into a branch, and if at any point a condition in the branch fails, you bail out and go back up to the next branch. This can cause pathological searches of a tree (e.g., 60+ repeated checks on things like isset and booleanEquals to resolve S3 endpoints). In fact, there are currently ~73,000 unique paths through the current decision tree for S3 rules. Using a BDD (a fully reduced one at least) guarantees that we only evaluate any given condition at most once, and only when that condition actually discriminates the result. This is achieved by recursively converting the CFG into BDD nodes using ITE (if-then-else) operations, choosing a variable ordering that honors dependencies between conditions and variable bindings. The BDD builder applies Shannon expansion during ITE operations and uses hash-consing to share common subgraphs. The "bdd" trait has most of the same information as the endpointRuleset trait, but doesn't include "rules". Instead it contains a base64 encoded "nodes" value that contains the zig-zag variable-length encoded node triples, one after the other (this is much more compact and efficient to decode than 1000+ JSON array nodes). The BDD implementation uses CUDD-style complement edges where negative node references represent logical NOT, further reducing BDD size. --- config/spotbugs/filter.xml | 6 + .../aws/language/functions/AwsArn.java | 63 +- .../aws/language/functions/AwsPartition.java | 5 + .../functions/IsVirtualHostableS3Bucket.java | 5 + .../aws/language/functions/ParseArn.java | 5 + .../smithy/rulesengine/language/Endpoint.java | 14 +- .../language/evaluation/RuleEvaluator.java | 16 + .../language/evaluation/value/ArrayValue.java | 9 + .../evaluation/value/BooleanValue.java | 5 + .../language/evaluation/value/EmptyValue.java | 5 + .../evaluation/value/EndpointValue.java | 5 + .../evaluation/value/IntegerValue.java | 5 + .../evaluation/value/RecordValue.java | 9 + .../evaluation/value/StringValue.java | 5 + .../language/evaluation/value/Value.java | 2 + .../language/syntax/Identifier.java | 2 +- .../language/syntax/ToCondition.java | 5 +- .../syntax/expressions/Expression.java | 27 +- .../syntax/expressions/Reference.java | 7 + .../language/syntax/expressions/Template.java | 3 +- .../expressions/functions/BooleanEquals.java | 15 + .../functions/FunctionDefinition.java | 12 + .../expressions/functions/FunctionNode.java | 2 +- .../syntax/expressions/functions/GetAttr.java | 5 + .../syntax/expressions/functions/IsSet.java | 5 + .../functions/IsValidHostLabel.java | 5 + .../functions/LibraryFunction.java | 92 +++ .../syntax/expressions/functions/Not.java | 5 + .../expressions/functions/ParseUrl.java | 5 + .../expressions/functions/StringEquals.java | 24 + .../expressions/functions/Substring.java | 26 +- .../expressions/functions/UriEncode.java | 5 + .../expressions/literal/RecordLiteral.java | 11 + .../expressions/literal/StringLiteral.java | 20 + .../expressions/literal/TupleLiteral.java | 11 + .../language/syntax/parameters/Parameter.java | 2 +- .../syntax/parameters/Parameters.java | 13 +- .../language/syntax/rule/Condition.java | 40 +- .../language/syntax/rule/EndpointRule.java | 10 +- .../language/syntax/rule/ErrorRule.java | 28 +- .../language/syntax/rule/NoMatchRule.java | 43 ++ .../language/syntax/rule/Rule.java | 57 +- .../language/syntax/rule/TreeRule.java | 13 +- .../rulesengine/logic/ConditionEvaluator.java | 22 + .../rulesengine/logic/ConditionInfo.java | 59 ++ .../rulesengine/logic/ConditionInfoImpl.java | 103 +++ .../rulesengine/logic/ConditionReference.java | 81 +++ .../logic/RuleBasedConditionEvaluator.java | 30 + .../smithy/rulesengine/logic/bdd/Bdd.java | 362 +++++++++++ .../rulesengine/logic/bdd/BddBuilder.java | 608 ++++++++++++++++++ .../rulesengine/logic/bdd/BddCompiler.java | 141 ++++ .../logic/bdd/BddEquivalenceChecker.java | 363 +++++++++++ .../rulesengine/logic/bdd/BddEvaluator.java | 86 +++ .../rulesengine/logic/bdd/BddNodeHelpers.java | 150 +++++ .../logic/bdd/ConditionDependencyGraph.java | 110 ++++ .../logic/bdd/ConditionOrderingStrategy.java | 43 ++ .../logic/bdd/DefaultOrderingStrategy.java | 104 +++ .../rulesengine/logic/bdd/NodeReversal.java | 96 +++ .../logic/bdd/OrderConstraints.java | 92 +++ .../logic/bdd/SiftingOptimization.java | 566 ++++++++++++++++ .../smithy/rulesengine/logic/cfg/Cfg.java | 245 +++++++ .../rulesengine/logic/cfg/CfgBuilder.java | 230 +++++++ .../smithy/rulesengine/logic/cfg/CfgNode.java | 13 + .../rulesengine/logic/cfg/ConditionData.java | 64 ++ .../rulesengine/logic/cfg/ConditionNode.java | 83 +++ .../rulesengine/logic/cfg/ResultNode.java | 65 ++ .../rulesengine/logic/cfg/SsaTransform.java | 488 ++++++++++++++ .../smithy/rulesengine/traits/BddTrait.java | 77 +++ .../validators/BddTraitValidator.java | 112 ++++ ...re.amazon.smithy.model.traits.TraitService | 1 + ...e.amazon.smithy.model.validation.Validator | 1 + .../META-INF/smithy/smithy.rules.smithy | 196 +++++- .../language/value/ToObjectTest.java | 87 +++ .../logic/ConditionInfoImplTest.java | 126 ++++ .../logic/ConditionReferenceTest.java | 138 ++++ .../RuleBasedConditionEvaluatorTest.java | 72 +++ .../smithy/rulesengine/logic/TestHelpers.java | 46 ++ .../rulesengine/logic/bdd/BddBuilderTest.java | 533 +++++++++++++++ .../logic/bdd/BddCompilerTest.java | 190 ++++++ .../logic/bdd/BddEvaluatorTest.java | 215 +++++++ .../smithy/rulesengine/logic/bdd/BddTest.java | 311 +++++++++ .../bdd/ConditionDependencyGraphTest.java | 120 ++++ .../bdd/DefaultOrderingStrategyTest.java | 221 +++++++ .../logic/bdd/NodeReversalTest.java | 196 ++++++ .../logic/bdd/OrderConstraintsTest.java | 152 +++++ .../logic/bdd/SiftingOptimizationTest.java | 134 ++++ .../rulesengine/logic/cfg/CfgBuilderTest.java | 280 ++++++++ .../smithy/rulesengine/logic/cfg/CfgTest.java | 267 ++++++++ .../logic/cfg/ConditionDataTest.java | 170 +++++ .../logic/cfg/VariableDisambiguatorTest.java | 236 +++++++ .../errorfiles/valid/substring.smithy | 4 +- .../language/invalid-rules/empty-rule.json5 | 2 +- .../rulesengine/language/minimal-ruleset.json | 1 - .../errorfiles/bdd/bdd-invalid-base64.errors | 1 + .../errorfiles/bdd/bdd-invalid-base64.smithy | 35 + .../bdd/bdd-invalid-node-data.errors | 1 + .../bdd/bdd-invalid-node-data.smithy | 51 ++ .../bdd/bdd-invalid-root-reference.errors | 1 + .../bdd/bdd-invalid-root-reference.smithy | 21 + .../bdd/bdd-node-count-mismatch.errors | 1 + .../bdd/bdd-node-count-mismatch.smithy | 51 ++ .../traits/errorfiles/bdd/bdd-valid.errors | 1 + .../traits/errorfiles/bdd/bdd-valid.smithy | 56 ++ .../endpoint-tests-without-ruleset.errors | 2 +- 104 files changed, 8888 insertions(+), 77 deletions(-) create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionEvaluator.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfo.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImpl.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluator.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluator.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeHelpers.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraints.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgNode.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionData.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionNode.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ResultNode.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/BddTrait.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/value/ToObjectTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImplTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionReferenceTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluatorTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluatorTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDataTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableDisambiguatorTest.java create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.smithy create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy diff --git a/config/spotbugs/filter.xml b/config/spotbugs/filter.xml index 6188a7e3059..a92ee96eade 100644 --- a/config/spotbugs/filter.xml +++ b/config/spotbugs/filter.xml @@ -218,4 +218,10 @@ + + + + + + diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsArn.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsArn.java index c52fe6c727b..f4d05b6503f 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsArn.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsArn.java @@ -4,7 +4,7 @@ */ package software.amazon.smithy.rulesengine.aws.language.functions; -import java.util.Arrays; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -39,28 +39,67 @@ private AwsArn(Builder builder) { * @return the optional ARN. */ public static Optional parse(String arn) { - String[] base = arn.split(":", 6); - if (base.length != 6) { + if (arn == null || arn.length() < 8 || !arn.startsWith("arn:")) { return Optional.empty(); } - // First section must be "arn". - if (!base[0].equals("arn")) { + + // find each of the first five ':' positions + int p0 = 3; // after "arn" + int p1 = arn.indexOf(':', p0 + 1); + if (p1 < 0) { + return Optional.empty(); + } + + int p2 = arn.indexOf(':', p1 + 1); + if (p2 < 0) { + return Optional.empty(); + } + + int p3 = arn.indexOf(':', p2 + 1); + if (p3 < 0) { return Optional.empty(); } - // Sections for partition, service, and resource type must not be empty. - if (base[1].isEmpty() || base[2].isEmpty() || base[5].isEmpty()) { + + int p4 = arn.indexOf(':', p3 + 1); + if (p4 < 0) { + return Optional.empty(); + } + + // extract and validate mandatory parts + String partition = arn.substring(p0 + 1, p1); + String service = arn.substring(p1 + 1, p2); + String region = arn.substring(p2 + 1, p3); + String accountId = arn.substring(p3 + 1, p4); + String resource = arn.substring(p4 + 1); + + if (partition.isEmpty() || service.isEmpty() || resource.isEmpty()) { return Optional.empty(); } return Optional.of(builder() - .partition(base[1]) - .service(base[2]) - .region(base[3]) - .accountId(base[4]) - .resource(Arrays.asList(base[5].split("[:/]", -1))) + .partition(partition) + .service(service) + .region(region) + .accountId(accountId) + .resource(splitResource(resource)) .build()); } + private static List splitResource(String resource) { + List result = new ArrayList<>(); + int start = 0; + int length = resource.length(); + for (int i = 0; i < length; i++) { + char c = resource.charAt(i); + if (c == ':' || c == '/') { + result.add(resource.substring(start, i)); + start = i + 1; + } + } + result.add(resource.substring(start)); + return result; + } + /** * Builder to create an {@link AwsArn} instance. * diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsPartition.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsPartition.java index 8a91ee62b98..7f1178a8d70 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsPartition.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsPartition.java @@ -184,6 +184,11 @@ public Value evaluate(List arguments) { public AwsPartition createFunction(FunctionNode functionNode) { return new AwsPartition(functionNode); } + + @Override + public int getCostHeuristic() { + return 6; + } } /** diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/IsVirtualHostableS3Bucket.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/IsVirtualHostableS3Bucket.java index d71b285d9d5..26e442bfbcd 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/IsVirtualHostableS3Bucket.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/IsVirtualHostableS3Bucket.java @@ -97,6 +97,11 @@ public Value evaluate(List arguments) { public IsVirtualHostableS3Bucket createFunction(FunctionNode functionNode) { return new IsVirtualHostableS3Bucket(functionNode); } + + @Override + public int getCostHeuristic() { + return 8; + } } /** diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/ParseArn.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/ParseArn.java index 1a4fa8c5283..e87ea48869e 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/ParseArn.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/ParseArn.java @@ -125,5 +125,10 @@ public Value evaluate(List arguments) { public ParseArn createFunction(FunctionNode functionNode) { return new ParseArn(functionNode); } + + @Override + public int getCostHeuristic() { + return 9; + } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/Endpoint.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/Endpoint.java index b1b23f7a997..f9b407fb196 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/Endpoint.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/Endpoint.java @@ -206,21 +206,21 @@ public int hashCode() { @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append("url: ").append(url).append("\n"); + sb.append("url: ").append(url); if (!headers.isEmpty()) { - sb.append("headers:\n"); + sb.append("\nheaders:"); for (Map.Entry> entry : headers.entrySet()) { - sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2)) - .append("\n"); + sb.append("\n"); + sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2)); } } if (!properties.isEmpty()) { - sb.append("properties:\n"); + sb.append("\nproperties:"); for (Map.Entry entry : properties.entrySet()) { - sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2)) - .append("\n"); + sb.append("\n"); + sb.append(StringUtils.indent(String.format("%s: %s", entry.getKey(), entry.getValue()), 2)); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index da81e346228..a913bb9a3d3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -19,6 +19,7 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.RuleValueVisitor; @@ -70,6 +71,21 @@ public Value evaluateRuleSet(EndpointRuleSet ruleset, Map par }); } + /** + * Configure the rule evaluator with the given parameters and parameter values for manual evaluation. + * + * @param parameters Parameters of the ruleset to evaluate. + * @param parameterArguments Parameter values to evaluate the ruleset against. + * @return the updated evaluator. + */ + public RuleEvaluator withParameters(Parameters parameters, Map parameterArguments) { + for (Parameter parameter : parameters) { + parameter.getDefault().ifPresent(value -> scope.insert(parameter.getName(), value)); + } + parameterArguments.forEach(scope::insert); + return this; + } + /** * Evaluates the given condition in the current scope. * diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/ArrayValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/ArrayValue.java index 40eeb6cc48a..64b52bad7ad 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/ArrayValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/ArrayValue.java @@ -101,4 +101,13 @@ public String toString() { } return "[" + String.join(", ", valueStrings) + "]"; } + + @Override + public Object toObject() { + List result = new ArrayList<>(values.size()); + for (Value value : values) { + result.add(value.toObject()); + } + return result; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java index 43b7b145223..63a632c0d6c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java @@ -68,4 +68,9 @@ public int hashCode() { public String toString() { return String.valueOf(value); } + + @Override + public Object toObject() { + return value; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java index 392214b450c..2409d9211d5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java @@ -35,4 +35,9 @@ public Node toNode() { public String toString() { return ""; } + + @Override + public Object toObject() { + return null; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EndpointValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EndpointValue.java index 324c71ab92f..b6867589589 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EndpointValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EndpointValue.java @@ -163,6 +163,11 @@ public String toString() { return sb.toString(); } + @Override + public Object toObject() { + return this; + } + /** * A builder used to create an {@link EndpointValue} class. */ diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/IntegerValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/IntegerValue.java index 99fe951ba5b..f8311556a91 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/IntegerValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/IntegerValue.java @@ -42,4 +42,9 @@ public IntegerValue expectIntegerValue() { public Node toNode() { return Node.from(value); } + + @Override + public Object toObject() { + return value; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/RecordValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/RecordValue.java index 20e7bd45a72..7a3b23abafe 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/RecordValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/RecordValue.java @@ -97,4 +97,13 @@ public int hashCode() { public String toString() { return value.toString(); } + + @Override + public Object toObject() { + Map result = new HashMap<>(value.size()); + for (Map.Entry entry : value.entrySet()) { + result.put(entry.getKey().toString(), entry.getValue().toObject()); + } + return result; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/StringValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/StringValue.java index 44f48d3cbbd..f1c659b08e3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/StringValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/StringValue.java @@ -66,4 +66,9 @@ public int hashCode() { public String toString() { return value; } + + @Override + public Object toObject() { + return value; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java index 7d719e1cd5b..b9caeb0da1e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java @@ -230,4 +230,6 @@ private RuntimeException throwTypeMismatch(String expectedType) { getType(), this)); } + + public abstract Object toObject(); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/Identifier.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/Identifier.java index 78dbaf4d2bb..cf0540cd775 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/Identifier.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/Identifier.java @@ -80,7 +80,7 @@ public boolean equals(Object obj) { @Override public int hashCode() { - return Objects.hash(name); + return name.hashCode(); } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/ToCondition.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/ToCondition.java index e70e946a35e..c69bfa44c3e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/ToCondition.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/ToCondition.java @@ -17,8 +17,11 @@ public interface ToCondition { * Convert this into a condition builder for compositional use. * * @return the condition builder. + * @throws UnsupportedOperationException if this cannot be converted to a condition. */ - Condition.Builder toConditionBuilder(); + default Condition.Builder toConditionBuilder() { + throw new UnsupportedOperationException("Cannot convert " + getClass().getName() + " to a condition"); + } /** * Convert this into a condition. diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java index 46fb4c5d5b6..b4157244604 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java @@ -6,8 +6,10 @@ import static software.amazon.smithy.rulesengine.language.error.RuleError.context; +import java.util.Collections; import java.util.Objects; import java.util.Optional; +import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.SourceException; import software.amazon.smithy.model.SourceLocation; @@ -24,7 +26,6 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -121,6 +122,16 @@ public static Reference getReference(Identifier name, FromSourceLocation context return new Reference(name, context); } + /** + * Constructs a {@link Reference} for the given {@link Identifier}. + * + * @param name the referenced identifier. + * @return the reference. + */ + public static Reference getReference(Identifier name) { + return getReference(name, SourceLocation.NONE); + } + /** * Constructs a {@link Literal} from the given {@link StringNode}. * @@ -131,6 +142,15 @@ public static Literal getLiteral(StringNode node) { return Literal.stringLiteral(new Template(node)); } + /** + * Get the set of variables that this condition references. + * + * @return variable references by name. + */ + public Set getReferences() { + return Collections.emptySet(); + } + /** * Invoke the {@link ExpressionVisitor} functions for this expression. * @@ -154,11 +174,6 @@ public Type type() { return cachedType; } - @Override - public Condition.Builder toConditionBuilder() { - return Condition.builder().fn(this); - } - @Override public Expression toExpression() { return this; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java index 1f8a757d38a..6db377dc8dc 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java @@ -6,7 +6,9 @@ import static software.amazon.smithy.rulesengine.language.error.RuleError.context; +import java.util.Collections; import java.util.Objects; +import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.node.Node; import software.amazon.smithy.model.node.ObjectNode; @@ -47,6 +49,11 @@ public String template() { return String.format("{%s}", name); } + @Override + public Set getReferences() { + return Collections.singleton(getName().toString()); + } + @Override public R accept(ExpressionVisitor visitor) { return visitor.visitRef(this); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Template.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Template.java index a634158cd2c..c8cb5931533 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Template.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Template.java @@ -22,7 +22,6 @@ import software.amazon.smithy.rulesengine.language.evaluation.TypeCheck; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.ToExpression; -import software.amazon.smithy.utils.SmithyBuilder; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -41,7 +40,7 @@ public final class Template implements FromSourceLocation, ToNode { private final String value; public Template(StringNode template) { - sourceLocation = SmithyBuilder.requiredState("source", template.getSourceLocation()); + sourceLocation = template.getSourceLocation(); value = template.getValue(); parts = context("when parsing template", template, () -> parseTemplate(template.getValue(), template)); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java index bfc615a111e..0eeaa8dbec8 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java @@ -56,6 +56,16 @@ public static BooleanEquals ofExpressions(ToExpression arg1, boolean arg2) { return ofExpressions(arg1, Expression.of(arg2)); } + @Override + public BooleanEquals canonicalize() { + List args = getArguments(); + if (shouldSwapArgs(args.get(0), args.get(1))) { + return BooleanEquals.ofExpressions(args.get(1), args.get(0)); + } + + return this; + } + @Override public R accept(ExpressionVisitor visitor) { return visitor.visitBoolEquals(functionNode.getArguments().get(0), functionNode.getArguments().get(1)); @@ -92,5 +102,10 @@ public Value evaluate(List arguments) { public BooleanEquals createFunction(FunctionNode functionNode) { return new BooleanEquals(functionNode); } + + @Override + public int getCostHeuristic() { + return 2; + } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java index e8878c85ed6..1573db78fa6 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java @@ -45,4 +45,16 @@ public interface FunctionDefinition { * @return the created LibraryFunction implementation. */ LibraryFunction createFunction(FunctionNode functionNode); + + /** + * Get the relative "cost" of the function as compared to the baseline of "isset" which equals 1. + * + *

If this function is considered more computationally expensive, then it has a value higher than 1. Otherwise, + * it has a value equal to 1. Defaults to "4" for unknown functions. + * + * @return the relative cost. + */ + default int getCostHeuristic() { + return 4; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionNode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionNode.java index c3b98048ccf..51a40498bae 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionNode.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionNode.java @@ -109,7 +109,7 @@ public static FunctionNode fromNode(ObjectNode function) { * * @return this function as an expression. */ - public Expression createFunction() { + public LibraryFunction createFunction() { return EndpointRuleSet.createFunctionFactory() .apply(this) .orElseThrow(() -> new RuleError(new SourceException( diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/GetAttr.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/GetAttr.java index 8e719afb25c..82049b7c9cd 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/GetAttr.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/GetAttr.java @@ -238,6 +238,11 @@ public Value evaluate(List arguments) { public GetAttr createFunction(FunctionNode functionNode) { return new GetAttr(functionNode); } + + @Override + public int getCostHeuristic() { + return 7; + } } public interface Part { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsSet.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsSet.java index ee867a5cec6..9d55683bbcb 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsSet.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsSet.java @@ -99,5 +99,10 @@ public Value evaluate(List arguments) { public IsSet createFunction(FunctionNode functionNode) { return new IsSet(functionNode); } + + @Override + public int getCostHeuristic() { + return 1; + } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsValidHostLabel.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsValidHostLabel.java index 4043f2da9aa..df45c277b97 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsValidHostLabel.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsValidHostLabel.java @@ -93,6 +93,11 @@ public Value evaluate(List arguments) { public IsValidHostLabel createFunction(FunctionNode functionNode) { return new IsValidHostLabel(functionNode); } + + @Override + public int getCostHeuristic() { + return 8; + } } /** diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java index afddaec181f..5d8f5166d7a 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java @@ -5,8 +5,10 @@ package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; import java.util.ArrayList; +import java.util.LinkedHashSet; import java.util.List; import java.util.Objects; +import java.util.Set; import software.amazon.smithy.model.SourceException; import software.amazon.smithy.model.SourceLocation; import software.amazon.smithy.model.node.Node; @@ -15,6 +17,10 @@ import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.utils.SmithyUnstableApi; import software.amazon.smithy.utils.StringUtils; @@ -41,6 +47,24 @@ public String getName() { return functionNode.getName(); } + @Override + public Set getReferences() { + Set references = new LinkedHashSet<>(); + for (Expression arg : getArguments()) { + references.addAll(arg.getReferences()); + } + return references; + } + + /** + * Get the function definition. + * + * @return function definition. + */ + public FunctionDefinition getFunctionDefinition() { + return definition; + } + /** * @return The arguments to this function */ @@ -48,6 +72,17 @@ public List getArguments() { return functionNode.getArguments(); } + /** + * Returns a canonical form of this function. + * + *

Default implementation returns this. Override for functions that need canonicalization. + * + * @return the canonical form of this function + */ + public LibraryFunction canonicalize() { + return this; + } + protected Expression expectOneArgument() { List argv = functionNode.getArguments(); if (argv.size() == 1) { @@ -56,6 +91,11 @@ protected Expression expectOneArgument() { throw new RuleError(new SourceException("expected 1 argument but found " + argv.size(), functionNode)); } + @Override + public Condition.Builder toConditionBuilder() { + return Condition.builder().fn(this); + } + @Override public SourceLocation getSourceLocation() { return functionNode.getSourceLocation(); @@ -153,4 +193,56 @@ public String toString() { } return getName() + "(" + String.join(", ", arguments) + ")"; } + + /** + * Determines if two arguments should be swapped for canonical ordering. + * Used by commutative functions to ensure consistent argument order. + * + * @param arg0 the first argument + * @param arg1 the second argument + * @return true if arguments should be swapped + */ + protected static boolean shouldSwapArgs(Expression arg0, Expression arg1) { + boolean arg0IsRef = isReference(arg0); + boolean arg1IsRef = isReference(arg1); + + // Always put References before literals to make things consistent + if (arg0IsRef != arg1IsRef) { + return !arg0IsRef; // Swap if arg0 is literal and arg1 is reference + } + + // Both same type, use string comparison for deterministic order + return arg0.toString().compareTo(arg1.toString()) > 0; + } + + /** + * Strips single-variable template wrappers if present. + * Converts "{varName}" to just varName reference. + * + * @param expr the expression to strip + * @return the stripped expression or original if not applicable + */ + static Expression stripSingleVariableTemplate(Expression expr) { + if (!(expr instanceof StringLiteral)) { + return expr; + } + + StringLiteral stringLit = (StringLiteral) expr; + List parts = stringLit.value().getParts(); + if (parts.size() == 1 && parts.get(0) instanceof Template.Dynamic) { + return ((Template.Dynamic) parts.get(0)).toExpression(); + } + + return expr; + } + + private static boolean isReference(Expression arg) { + if (arg instanceof Reference) { + return true; + } else if (arg instanceof StringLiteral) { + StringLiteral s = (StringLiteral) arg; + return !s.value().isStatic(); + } + return false; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Not.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Not.java index 9b6330a7224..473e7e9d118 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Not.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Not.java @@ -87,5 +87,10 @@ public Value evaluate(List arguments) { public Not createFunction(FunctionNode functionNode) { return new Not(functionNode); } + + @Override + public int getCostHeuristic() { + return 2; + } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/ParseUrl.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/ParseUrl.java index da9f3ca37f3..e5a0bdd2a12 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/ParseUrl.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/ParseUrl.java @@ -119,6 +119,11 @@ public Value evaluate(List arguments) { public ParseUrl createFunction(FunctionNode functionNode) { return new ParseUrl(functionNode); } + + @Override + public int getCostHeuristic() { + return 10; + } } /** diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java index 1c0f228d75d..63645232cf6 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java @@ -56,6 +56,25 @@ public static StringEquals ofExpressions(ToExpression arg1, String arg2) { return ofExpressions(arg1, Expression.of(arg2)); } + @Override + public StringEquals canonicalize() { + List args = getArguments(); + + // Strip single-variable templates + Expression arg0 = stripSingleVariableTemplate(args.get(0)); + Expression arg1 = stripSingleVariableTemplate(args.get(1)); + + // Check if we need to reorder for commutative canonicalization + if (shouldSwapArgs(arg0, arg1)) { + return StringEquals.ofExpressions(arg1, arg0); + } else if (arg0 != args.get(0) || arg1 != args.get(1)) { + // Templates were stripped but no reordering needed + return StringEquals.ofExpressions(arg0, arg1); + } + + return this; + } + @Override public R accept(ExpressionVisitor visitor) { return visitor.visitStringEquals(functionNode.getArguments().get(0), functionNode.getArguments().get(1)); @@ -92,5 +111,10 @@ public Value evaluate(List arguments) { public StringEquals createFunction(FunctionNode functionNode) { return new StringEquals(functionNode); } + + @Override + public int getCostHeuristic() { + return 3; + } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java index 8285b833799..391daf0a7b5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java @@ -105,6 +105,11 @@ public Value evaluate(List arguments) { public Substring createFunction(FunctionNode functionNode) { return new Substring(functionNode); } + + @Override + public int getCostHeuristic() { + return 5; + } } /** @@ -117,22 +122,23 @@ public Substring createFunction(FunctionNode functionNode) { * @return the substring value or null. */ public static String getSubstring(String value, int startIndex, int stopIndex, boolean reverse) { - if (startIndex >= stopIndex || value.length() < stopIndex) { + if (value == null) { return null; } - for (int i = 0; i < value.length(); i++) { - if (!(value.charAt(i) <= 127)) { + int len = value.length(); + if (startIndex < 0 || stopIndex > len || startIndex >= stopIndex) { + return null; + } + + int from = reverse ? len - stopIndex : startIndex; + int to = reverse ? len - startIndex : stopIndex; + for (int i = from; i < to; i++) { + if (value.charAt(i) > 127) { return null; } } - if (!reverse) { - return value.substring(startIndex, stopIndex); - } else { - int revStart = value.length() - stopIndex; - int revStop = value.length() - startIndex; - return value.substring(revStart, revStop); - } + return value.substring(from, to); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/UriEncode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/UriEncode.java index b4815364538..eca33d9de5c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/UriEncode.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/UriEncode.java @@ -99,5 +99,10 @@ public Value evaluate(List arguments) { public UriEncode createFunction(FunctionNode functionNode) { return new UriEncode(functionNode); } + + @Override + public int getCostHeuristic() { + return 8; + } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java index a3357a612d4..5d70dff6e0c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java @@ -6,9 +6,11 @@ import java.util.Collections; import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.node.Node; import software.amazon.smithy.model.node.ObjectNode; @@ -72,4 +74,13 @@ public Node toNode() { members.forEach((k, v) -> builder.withMember(k.toString(), v.toNode())); return builder.build(); } + + @Override + public Set getReferences() { + Set references = new LinkedHashSet<>(); + for (Literal value : members().values()) { + references.addAll(value.getReferences()); + } + return references; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java index 29dce7b61fe..8adb24f487b 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java @@ -4,8 +4,11 @@ */ package software.amazon.smithy.rulesengine.language.syntax.expressions.literal; +import java.util.Collections; +import java.util.LinkedHashSet; import java.util.Objects; import java.util.Optional; +import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.node.Node; import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; @@ -66,4 +69,21 @@ public String toString() { public Node toNode() { return value.toNode(); } + + @Override + public Set getReferences() { + Template template = value(); + if (template.isStatic()) { + return Collections.emptySet(); + } + + Set references = new LinkedHashSet<>(); + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + references.addAll(((Template.Dynamic) part).toExpression().getReferences()); + } + } + + return references; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java index 095b34bb182..867f5ec6e4c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java @@ -5,9 +5,11 @@ package software.amazon.smithy.rulesengine.language.syntax.expressions.literal; import java.util.ArrayList; +import java.util.LinkedHashSet; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.node.ArrayNode; import software.amazon.smithy.model.node.Node; @@ -78,4 +80,13 @@ public Node toNode() { } return builder.build(); } + + @Override + public Set getReferences() { + Set references = new LinkedHashSet<>(); + for (Literal member : members()) { + references.addAll(member.getReferences()); + } + return references; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java index eaff721790b..3dbe14dbdde 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java @@ -234,7 +234,7 @@ public Optional getDefault() { @Override public Condition.Builder toConditionBuilder() { - return Condition.builder().fn(toExpression()); + throw new UnsupportedOperationException("Cannot convert a Parameter to a Condition"); } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameters.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameters.java index 95feeedbf54..352c1d8d447 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameters.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameters.java @@ -4,6 +4,7 @@ */ package software.amazon.smithy.rulesengine.language.syntax.parameters; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -57,8 +58,7 @@ public static Builder builder() { public static Parameters fromNode(ObjectNode node) throws RuleError { Builder builder = new Builder(node); for (Map.Entry entry : node.getMembers().entrySet()) { - builder.addParameter(Parameter.fromNode(entry.getKey(), - RuleError.context("when parsing parameter", () -> entry.getValue().expectObjectNode()))); + builder.addParameter(Parameter.fromNode(entry.getKey(), entry.getValue().expectObjectNode())); } return builder.build(); } @@ -76,6 +76,15 @@ public void writeToScope(Scope scope) { } } + /** + * Convert the Parameters container to a list. + * + * @return the parameters list. + */ + public List toList() { + return Collections.unmodifiableList(parameters); + } + @Override public SourceLocation getSourceLocation() { return sourceLocation; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Condition.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Condition.java index 4f331600381..d98b1ee21d5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Condition.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Condition.java @@ -21,6 +21,7 @@ import software.amazon.smithy.rulesengine.language.syntax.SyntaxElement; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.utils.SmithyBuilder; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -31,8 +32,10 @@ @SmithyUnstableApi public final class Condition extends SyntaxElement implements TypeCheck, FromSourceLocation, ToNode { public static final String ASSIGN = "assign"; - private final Expression function; + private final LibraryFunction function; private final Identifier result; + private int hash; + private String toString; private Condition(Builder builder) { this.result = builder.result; @@ -86,10 +89,23 @@ public Optional getResult() { * * @return the function for this condition. */ - public Expression getFunction() { + public LibraryFunction getFunction() { return function; } + /** + * Returns a canonical form of this condition. + * + * @return the canonical condition + */ + public Condition canonicalize() { + LibraryFunction canonicalFn = function.canonicalize(); + if (canonicalFn == function) { + return this; + } + return toBuilder().fn(canonicalFn).build(); + } + @Override public Builder toConditionBuilder() { return toBuilder(); @@ -152,26 +168,32 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(function, result); + int h = hash; + if (h == 0) { + h = Objects.hash(function, result); + hash = h; + } + return h; } @Override public String toString() { - StringBuilder sb = new StringBuilder(); - if (result != null) { - sb.append(result).append(" = "); + String s = this.toString; + if (s == null) { + s = result != null ? (result + " = " + function) : function.toString(); + toString = s; } - return sb.append(function).toString(); + return s; } /** * A builder used to create a {@link Condition} class. */ public static class Builder implements SmithyBuilder { - private Expression fn; + private LibraryFunction fn; private Identifier result; - public Builder fn(Expression fn) { + public Builder fn(LibraryFunction fn) { this.fn = fn; return this; } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java index f20c483fdf0..fddbe5f38b7 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/EndpointRule.java @@ -20,6 +20,7 @@ @SmithyUnstableApi public final class EndpointRule extends Rule { private final Endpoint endpoint; + private int hash; EndpointRule(Rule.Builder builder, Endpoint endpoint) { super(builder); @@ -46,7 +47,7 @@ protected Type typecheckValue(Scope scope) { } @Override - void withValueNode(ObjectNode.Builder builder) { + protected void withValueNode(ObjectNode.Builder builder) { builder.withMember("endpoint", endpoint).withMember(TYPE, ENDPOINT); } @@ -67,7 +68,12 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(super.hashCode(), endpoint); + int result = hash; + if (result == 0) { + result = Objects.hash(super.hashCode(), endpoint); + hash = result; + } + return hash; } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/ErrorRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/ErrorRule.java index 03a60388f49..9dfce50faf6 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/ErrorRule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/ErrorRule.java @@ -6,6 +6,7 @@ import static software.amazon.smithy.rulesengine.language.error.RuleError.context; +import java.util.Objects; import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; @@ -19,6 +20,7 @@ @SmithyUnstableApi public final class ErrorRule extends Rule { private final Expression error; + private int hash; public ErrorRule(Rule.Builder builder, Expression error) { super(builder); @@ -45,7 +47,7 @@ protected Type typecheckValue(Scope scope) { } @Override - void withValueNode(ObjectNode.Builder builder) { + protected void withValueNode(ObjectNode.Builder builder) { builder.withMember("error", error.toNode()).withMember(TYPE, ERROR); } @@ -54,4 +56,28 @@ public String toString() { return super.toString() + StringUtils.indent(String.format("error(%s)", error), 2); } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } else if (!super.equals(object)) { + return false; + } else { + ErrorRule errorRule = (ErrorRule) object; + return Objects.equals(error, errorRule.error); + } + } + + @Override + public int hashCode() { + int result = hash; + if (result == 0) { + result = Objects.hash(super.hashCode(), error); + hash = result; + } + return hash; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java new file mode 100644 index 00000000000..d7c76f7feec --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java @@ -0,0 +1,43 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.rule; + +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * Sentinel rule for "no match" results. + */ +@SmithyUnstableApi +public final class NoMatchRule extends Rule { + + public static final NoMatchRule INSTANCE = new NoMatchRule(); + + private NoMatchRule() { + super(Rule.builder()); + } + + @Override + public T accept(RuleValueVisitor visitor) { + throw new UnsupportedOperationException("NO_MATCH is a sentinel"); + } + + @Override + protected Type typecheckValue(Scope scope) { + throw new UnsupportedOperationException("NO_MATCH is a sentinel"); + } + + @Override + protected void withValueNode(ObjectNode.Builder builder) { + // nothing + } + + @Override + public String toString() { + return "NO_MATCH"; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java index 5a66d029e09..ddcf4a12c6c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java @@ -84,7 +84,10 @@ public static Rule fromNode(Node node) { Builder builder = new Builder(node); objectNode.getStringMember(DOCUMENTATION, builder::description); - builder.conditions(objectNode.expectArrayMember(CONDITIONS).getElementsAs(Condition::fromNode)); + + objectNode.getArrayMember(CONDITIONS).ifPresent(conds -> { + builder.conditions(conds.getElementsAs(Condition::fromNode)); + }); String type = objectNode.expectStringMember(TYPE).getValue(); switch (type) { @@ -131,9 +134,27 @@ public Optional getDocumentation() { */ public abstract T accept(RuleValueVisitor visitor); + /** + * Get a new Rule of the same type that has the same values, but no conditions. + * + * @return the rule without conditions. + * @throws UnsupportedOperationException if it is a TreeRule or Condition rule. + */ + public Rule withoutConditions() { + if (getConditions().isEmpty()) { + return this; + } else if (this instanceof ErrorRule) { + return new ErrorRule(ErrorRule.builder(this), ((ErrorRule) this).getError()); + } else if (this instanceof EndpointRule) { + return new EndpointRule(EndpointRule.builder(this), ((EndpointRule) this).getEndpoint()); + } else { + throw new UnsupportedOperationException("Cannot remove conditions from " + this); + } + } + protected abstract Type typecheckValue(Scope scope); - abstract void withValueNode(ObjectNode.Builder builder); + protected abstract void withValueNode(ObjectNode.Builder builder); @Override public Type typeCheck(Scope scope) { @@ -150,11 +171,13 @@ public Type typeCheck(Scope scope) { public Node toNode() { ObjectNode.Builder builder = ObjectNode.builder(); - ArrayNode.Builder conditionsBuilder = ArrayNode.builder(); - for (Condition condition : conditions) { - conditionsBuilder.withValue(condition.toNode()); + if (!conditions.isEmpty()) { + ArrayNode.Builder conditionsBuilder = ArrayNode.builder(); + for (Condition condition : conditions) { + conditionsBuilder.withValue(condition.toNode()); + } + builder.withMember(CONDITIONS, conditionsBuilder.build()); } - builder.withMember(CONDITIONS, conditionsBuilder.build()); if (documentation != null) { builder.withMember(DOCUMENTATION, documentation); @@ -219,7 +242,7 @@ public Builder conditions(ToCondition... conditions) { return this; } - public Builder conditions(List conditions) { + public Builder conditions(List conditions) { this.conditions.addAll(conditions); return this; } @@ -229,20 +252,24 @@ public Builder condition(ToCondition condition) { return this; } - public Rule endpoint(Endpoint endpoint) { - return this.onBuild.apply(new EndpointRule(this, endpoint)); + public EndpointRule endpoint(Endpoint endpoint) { + return (EndpointRule) this.onBuild.apply(new EndpointRule(this, endpoint)); + } + + public ErrorRule error(Node error) { + return error(Expression.fromNode(error)); } - public Rule error(Node error) { - return this.onBuild.apply(new ErrorRule(this, Expression.fromNode(error))); + public ErrorRule error(String error) { + return error(Literal.of(error)); } - public Rule error(String error) { - return this.onBuild.apply(new ErrorRule(this, Literal.of(error))); + public ErrorRule error(Expression error) { + return (ErrorRule) this.onBuild.apply(new ErrorRule(this, error)); } - public Rule treeRule(Rule... rules) { - return this.treeRule(Arrays.asList(rules)); + public TreeRule treeRule(Rule... rules) { + return (TreeRule) this.treeRule(Arrays.asList(rules)); } @SafeVarargs diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/TreeRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/TreeRule.java index d6a817cd3b4..2afe672cde6 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/TreeRule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/TreeRule.java @@ -21,6 +21,7 @@ @SmithyUnstableApi public final class TreeRule extends Rule { private final List rules; + private int hash; TreeRule(Builder builder, List rules) { super(builder); @@ -54,7 +55,7 @@ protected Type typecheckValue(Scope scope) { } @Override - void withValueNode(ObjectNode.Builder builder) { + protected void withValueNode(ObjectNode.Builder builder) { ArrayNode.Builder rulesBuilder = ArrayNode.builder().sourceLocation(getSourceLocation()); for (Rule rule : rules) { rulesBuilder.withValue(rule.toNode()); @@ -70,4 +71,14 @@ public String toString() { } return super.toString() + StringUtils.indent(String.join("\n", ruleStrings), 2); } + + @Override + public int hashCode() { + int result = hash; + if (result == 0) { + result = super.hashCode(); + hash = result; + } + return result; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionEvaluator.java new file mode 100644 index 00000000000..dd4c1d06832 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionEvaluator.java @@ -0,0 +1,22 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +/** + * Evaluates a single condition using a condition index. + * + *

This functional interface provides maximum flexibility for condition evaluation implementations. Implementations + * are responsible maintaining their own internal state as methods are called (e.g., tracking variables). + */ +@FunctionalInterface +public interface ConditionEvaluator { + /** + * Evaluates the condition at the given index. + * + * @param conditionIndex the index of the condition to evaluate + * @return true if the condition is satisfied, false otherwise + */ + boolean test(int conditionIndex); +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfo.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfo.java new file mode 100644 index 00000000000..9b8d653356e --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfo.java @@ -0,0 +1,59 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import java.util.Collections; +import java.util.Set; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +/** + * Information about a condition. + */ +public interface ConditionInfo { + /** + * Create a new ConditionInfo from the given condition. + * + * @param condition Condition to compute. + * @return the created ConditionInfo. + */ + static ConditionInfo from(Condition condition) { + return new ConditionInfoImpl(condition); + } + + /** + * Get the underlying condition. + * + * @return condition. + */ + Condition getCondition(); + + /** + * Get the complexity of the condition. + * + * @return the complexity. + */ + default int getComplexity() { + return 1; + } + + /** + * Get the references used by the condition. + * + * @return the references. + */ + default Set getReferences() { + return Collections.emptySet(); + } + + /** + * Get the name of the variable this condition defines, if any, or null. + * + * @return the defined variable name or null. + */ + default String getReturnVariable() { + return getCondition().getResult().map(Identifier::toString).orElse(null); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImpl.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImpl.java new file mode 100644 index 00000000000..d493ae1c72a --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImpl.java @@ -0,0 +1,103 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import java.util.Set; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +/** + * Default implementation of {@link ConditionInfo} that computes condition metadata. + */ +final class ConditionInfoImpl implements ConditionInfo { + + private final Condition condition; + private final int complexity; + private final Set references; + + ConditionInfoImpl(Condition condition) { + this.condition = condition; + this.complexity = calculateComplexity(condition.getFunction()); + this.references = condition.getFunction().getReferences(); + } + + @Override + public Condition getCondition() { + return condition; + } + + @Override + public int getComplexity() { + return complexity; + } + + @Override + public Set getReferences() { + return references; + } + + @Override + public String getReturnVariable() { + return condition.getResult().map(Identifier::toString).orElse(null); + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } else { + return condition.equals(((ConditionInfoImpl) object).condition); + } + } + + @Override + public int hashCode() { + return condition.hashCode(); + } + + @Override + public String toString() { + return condition.toString(); + } + + private static int calculateComplexity(Expression e) { + // Base complexity for this node + int complexity = 1; + + if (e instanceof StringLiteral) { + Template template = ((StringLiteral) e).value(); + if (!template.isStatic()) { + if (template.getParts().size() > 1) { + // Single dynamic part is cheap, but multiple parts are expensive + complexity += 8; + } + for (Template.Part part : template.getParts()) { + // Add complexity from dynamic parts + if (part instanceof Template.Dynamic) { + Template.Dynamic dynamic = (Template.Dynamic) part; + complexity += calculateComplexity(dynamic.toExpression()); + } + } + } + } else if (e instanceof GetAttr) { + complexity += calculateComplexity(((GetAttr) e).getTarget()) + 2; + } else if (e instanceof LibraryFunction) { + LibraryFunction l = (LibraryFunction) e; + complexity += l.getFunctionDefinition().getCostHeuristic(); + for (Expression arg : l.getArguments()) { + complexity += calculateComplexity(arg); + } + } + + return complexity; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java new file mode 100644 index 00000000000..94f1a34da75 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java @@ -0,0 +1,81 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import java.util.Set; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +/** + * A reference to a condition and whether it is negated. + */ +public final class ConditionReference implements ConditionInfo { + + private final ConditionInfo delegate; + private final boolean negated; + + public ConditionReference(ConditionInfo delegate, boolean negated) { + this.delegate = delegate; + this.negated = negated; + } + + /** + * Returns true if this condition is negated (e.g., wrapped in not). + * + * @return true if negated. + */ + public boolean isNegated() { + return negated; + } + + /** + * Create a negated version of this reference. + * + * @return returns the negated reference. + */ + public ConditionReference negate() { + return new ConditionReference(delegate, !negated); + } + + @Override + public Condition getCondition() { + return delegate.getCondition(); + } + + @Override + public int getComplexity() { + return delegate.getComplexity(); + } + + @Override + public Set getReferences() { + return delegate.getReferences(); + } + + @Override + public String getReturnVariable() { + return delegate.getReturnVariable(); + } + + @Override + public String toString() { + return (negated ? "!" : "") + delegate.toString(); + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } + ConditionReference that = (ConditionReference) object; + return negated == that.negated && delegate.equals(that.delegate); + } + + @Override + public int hashCode() { + return delegate.hashCode() ^ (negated ? 0x80000000 : 0); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluator.java new file mode 100644 index 00000000000..72473046761 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluator.java @@ -0,0 +1,30 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import software.amazon.smithy.rulesengine.language.evaluation.RuleEvaluator; +import software.amazon.smithy.rulesengine.language.evaluation.value.BooleanValue; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +/** + * Evaluates rules using a rules engine evaluator. + */ +public final class RuleBasedConditionEvaluator implements ConditionEvaluator { + private final RuleEvaluator evaluator; + private final Condition[] conditions; + + public RuleBasedConditionEvaluator(RuleEvaluator evaluator, Condition[] conditions) { + this.evaluator = evaluator; + this.conditions = conditions; + } + + @Override + public boolean test(int conditionIndex) { + Condition condition = conditions[conditionIndex]; + Value value = evaluator.evaluateCondition(condition); + return !value.isEmpty() && (!(value instanceof BooleanValue) || ((BooleanValue) value).getValue()); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java new file mode 100644 index 00000000000..ef7fd0005d5 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java @@ -0,0 +1,362 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.node.ToNode; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; + +/** + * Binary Decision Diagram (BDD) with complement edges for efficient endpoint rule evaluation. + * + *

A BDD provides a compact representation of decision logic where each condition is evaluated at most once along + * any path. Complement edges (negative references) enable further size reduction through node sharing. + * + *

Reference Encoding: + *

    + *
  • {@code 0}: Invalid/unused reference (never appears in valid BDDs)
  • + *
  • {@code 1}: TRUE terminal; represents boolean true, treated as "no match" in endpoint resolution
  • + *
  • {@code -1}: FALSE terminal; represents boolean false, treated as "no match" in endpoint resolution
  • + *
  • {@code 2, 3, ...}: Node references (points to nodes array at index ref-1)
  • + *
  • {@code -2, -3, ...}: Complement node references (logical NOT of the referenced node)
  • + *
  • {@code 100_000_000+}: Result terminals (100_000_000 + resultIndex)
  • + *
+ * + *

Result terminals are encoded as special references starting at 100_000_000 (RESULT_OFFSET). + * When evaluating the BDD, any reference >= 100_000_000 represents a result terminal that + * indexes into the results array (resultIndex = ref - 100_000_000). These are not stored + * as nodes in the nodes array. + * + *

Node Format: {@code [variable, high, low]} + *

    + *
  • {@code variable}: Condition index (0 to conditionCount-1)
  • + *
  • {@code high}: Reference to follow when condition evaluates to true
  • + *
  • {@code low}: Reference to follow when condition evaluates to false
  • + *
+ */ +public final class Bdd implements ToNode { + /** + * Result reference encoding. + * + *

Results start at 100M to avoid collision with node references. + */ + public static final int RESULT_OFFSET = 100_000_000; + + private final Parameters parameters; + private final List conditions; + private final List results; + private final int[][] nodes; + private final int rootRef; + + /** + * Builds a BDD from an endpoint ruleset. + * + * @param ruleSet the ruleset to convert + * @return the constructed BDD + */ + public static Bdd from(EndpointRuleSet ruleSet) { + return from(Cfg.from(ruleSet)); + } + + /** + * Builds a BDD from a control flow graph. + * + * @param cfg the control flow graph + * @return the constructed BDD + */ + public static Bdd from(Cfg cfg) { + return from(cfg, new BddBuilder(), ConditionOrderingStrategy.defaultOrdering()); + } + + static Bdd from(Cfg cfg, BddBuilder bddBuilder, ConditionOrderingStrategy orderingStrategy) { + return new BddCompiler(cfg, orderingStrategy, bddBuilder).compile(); + } + + public Bdd(Parameters params, List conditions, List results, int[][] nodes, int rootRef) { + this.parameters = Objects.requireNonNull(params, "params is null"); + this.conditions = conditions; + this.results = results; + this.nodes = nodes; + this.rootRef = rootRef; + + if (rootRef < 0 && rootRef != -1) { + throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef); + } + } + + /** + * Gets the ordered list of conditions. + * + * @return list of conditions in evaluation order + */ + public List getConditions() { + return conditions; + } + + /** + * Gets the number of conditions. + * + * @return condition count + */ + public int getConditionCount() { + return conditions.size(); + } + + /** + * Gets the ordered list of results. + * + * @return list of results (null represents no match) + */ + public List getResults() { + return results; + } + + /** + * Gets the BDD nodes. + * + * @return array of node triples + */ + public int[][] getNodes() { + return nodes; + } + + /** + * Gets the root node reference. + * + * @return root reference + */ + public int getRootRef() { + return rootRef; + } + + /** + * Get the input parameters of the ruleset. + * + * @return input parameters. + */ + public Parameters getParameters() { + return parameters; + } + + /** + * Applies a transformation to the BDD and return a new BDD. + * + * @param transformer Optimization to apply. + * @return the optimized BDD. + */ + public Bdd transform(Function transformer) { + return transformer.apply(this); + } + + /** + * Checks if a reference points to a node (not a terminal or result). + * + * @param ref the reference to check + * @return true if this is a node reference + */ + public static boolean isNodeReference(int ref) { + if (ref == 0 || isTerminal(ref)) { + return false; + } + return Math.abs(ref) < RESULT_OFFSET; + } + + /** + * Checks if a reference points to a result. + * + * @param ref the reference to check + * @return true if this is a result reference + */ + public static boolean isResultReference(int ref) { + return ref >= RESULT_OFFSET; + } + + /** + * Checks if a reference is a terminal (TRUE or FALSE). + * + * @param ref the reference to check + * @return true if this is a terminal reference + */ + public static boolean isTerminal(int ref) { + return ref == 1 || ref == -1; + } + + /** + * Checks if a reference is complemented (negative). + * + * @param ref the reference to check + * @return true if the reference is complemented + */ + public static boolean isComplemented(int ref) { + // -1 is FALSE terminal, not a complement + return ref < 0 && ref != -1; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } else if (!(obj instanceof Bdd)) { + return false; + } + Bdd other = (Bdd) obj; + return rootRef == other.rootRef + && conditions.equals(other.conditions) + && results.equals(other.results) + && nodesEqual(nodes, other.nodes) + && Objects.equals(parameters, other.parameters); + } + + private static boolean nodesEqual(int[][] a, int[][] b) { + if (a.length != b.length) { + return false; + } + for (int i = 0; i < a.length; i++) { + if (!Arrays.equals(a[i], b[i])) { + return false; + } + } + return true; + } + + @Override + public int hashCode() { + int hash = 31 * rootRef + nodes.length; + // Sample up to 16 nodes distributed across the BDD + int step = Math.max(1, nodes.length / 16); + for (int i = 0; i < nodes.length; i += step) { + int[] node = nodes[i]; + hash = 31 * hash + node[0]; + hash = 31 * hash + node[1]; + hash = 31 * hash + node[2]; + } + return hash; + } + + @Override + public String toString() { + return toString(new StringBuilder()).toString(); + } + + /** + * Appends a string representation to the given StringBuilder. + * + * @param sb the StringBuilder to append to + * @return the given string builder. + */ + public StringBuilder toString(StringBuilder sb) { + // Calculate max width needed for first column identifiers + int maxConditionIdx = conditions.size() - 1; + int maxResultIdx = results.size() - 1; + + // Width needed for "C" + maxConditionIdx or "R" + maxResultIdx + int conditionWidth = maxConditionIdx >= 0 ? String.valueOf(maxConditionIdx).length() + 1 : 2; + int resultWidth = maxResultIdx >= 0 ? String.valueOf(maxResultIdx).length() + 1 : 2; + int varWidth = Math.max(conditionWidth, resultWidth); + + sb.append("Bdd{\n"); + + // Conditions + sb.append(" conditions (").append(getConditionCount()).append("):\n"); + for (int i = 0; i < conditions.size(); i++) { + sb.append(String.format(" %" + varWidth + "s: %s%n", "C" + i, conditions.get(i))); + } + + // Results + sb.append(" results (").append(results.size()).append("):\n"); + for (int i = 0; i < results.size(); i++) { + sb.append(String.format(" %" + varWidth + "s: ", "R" + i)); + appendResult(sb, results.get(i)); + sb.append("\n"); + } + + // Root + sb.append(" root: ").append(formatReference(rootRef)).append("\n"); + + // Nodes + sb.append(" nodes (").append(nodes.length).append("):\n"); + + // Calculate width needed for node indices + int indexWidth = String.valueOf(nodes.length - 1).length(); + + for (int i = 0; i < nodes.length; i++) { + sb.append(String.format(" %" + indexWidth + "d: ", i)); + if (i == 0) { + sb.append("terminal"); + } else { + int[] node = nodes[i]; + int varIdx = node[0]; + sb.append("["); + + // Use the calculated width for variable/result references + if (varIdx < conditions.size()) { + sb.append(String.format("%" + varWidth + "s", "C" + varIdx)); + } else { + sb.append(String.format("%" + varWidth + "s", "R" + (varIdx - conditions.size()))); + } + + // Format the references with consistent spacing + sb.append(", ") + .append(String.format("%6s", formatReference(node[1]))) + .append(", ") + .append(String.format("%6s", formatReference(node[2]))) + .append("]"); + } + sb.append("\n"); + } + + sb.append("}"); + return sb; + } + + private void appendResult(StringBuilder sb, Rule result) { + if (result == null) { + sb.append("(no match)"); + } else if (result instanceof EndpointRule) { + sb.append("Endpoint: ").append(((EndpointRule) result).getEndpoint().getUrl()); + } else if (result instanceof ErrorRule) { + sb.append("Error: ").append(((ErrorRule) result).getError()); + } else { + sb.append(result.getClass().getSimpleName()); + } + } + + private String formatReference(int ref) { + if (ref == 0) { + return "INVALID"; + } else if (ref == 1) { + return "TRUE"; + } else if (ref == -1) { + return "FALSE"; + } else if (ref >= Bdd.RESULT_OFFSET) { + // This is a result reference + int resultIdx = ref - Bdd.RESULT_OFFSET; + return "R" + resultIdx; + } else if (ref < 0) { + return "!" + (Math.abs(ref) - 1); + } else { + return String.valueOf(ref - 1); + } + } + + public static Bdd fromNode(Node node) { + return BddNodeHelpers.fromNode(node); + } + + @Override + public Node toNode() { + return BddNodeHelpers.toNode(this); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java new file mode 100644 index 00000000000..adfd2bd2e27 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java @@ -0,0 +1,608 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Binary Decision Diagram (BDD) builder with complement edges and multi-terminal support. + * + *

This implementation uses CUDD-style complement edges where negative references + * represent logical negation. The engine supports both boolean operations and + * multi-terminal decision diagrams (MTBDDs) for endpoint resolution. + * + *

Reference encoding: + *

    + *
  • 0: Invalid/unused reference (never appears in valid BDDs)
  • + *
  • 1: TRUE terminal
  • + *
  • -1: FALSE terminal
  • + *
  • 2, 3, 4, ...: BDD nodes (use index + 1)
  • + *
  • -2, -3, -4, ...: Complement of BDD nodes
  • + *
  • Bdd.RESULT_OFFSET+: Result terminals (100_000_000 + resultIndex)
  • + *
+ * + *

Node storage format: [variableIndex, highRef, lowRef] + * where variableIndex identifies the condition being tested: + *

    + *
  • -1: terminal node marker (only used for index 0)
  • + *
  • 0 to conditionCount-1: condition indices
  • + *
+ */ +final class BddBuilder { + + // Terminal constants + private static final int TRUE_REF = 1; + private static final int FALSE_REF = -1; + + // ITE operation cache for memoization + private final Map iteCache; + // Node storage: index 0 is reserved for the terminal node + private List nodes; + // Unique table for node deduplication + private Map uniqueTable; + // Track the boundary between conditions and results + private int conditionCount = -1; + + private final TripleKey mutableKey = new TripleKey(0, 0, 0); + + /** + * Creates a new BDD engine. + */ + public BddBuilder() { + this.nodes = new ArrayList<>(); + this.uniqueTable = new HashMap<>(); + this.iteCache = new HashMap<>(); + // Initialize with terminal node at index 0 + nodes.add(new int[] {-1, TRUE_REF, FALSE_REF}); + } + + /** + * Sets the number of conditions. Must be called before creating result nodes. + * + * @param count the number of conditions + */ + public void setConditionCount(int count) { + if (conditionCount != -1) { + throw new IllegalStateException("Condition count already set"); + } + this.conditionCount = count; + } + + /** + * Returns the TRUE terminal reference. + * + * @return TRUE reference (always 1) + */ + public int makeTrue() { + return TRUE_REF; + } + + /** + * Returns the FALSE terminal reference. + * + * @return FALSE reference (always -1) + */ + public int makeFalse() { + return FALSE_REF; + } + + /** + * Creates a result terminal reference. + * + * @param resultIndex the result index (must be non-negative) + * @return reference to the result terminal (RESULT_OFFSET + resultIndex) + * @throws IllegalArgumentException if resultIndex is negative + * @throws IllegalStateException if condition count not set + */ + public int makeResult(int resultIndex) { + if (conditionCount == -1) { + throw new IllegalStateException("Must set condition count before creating results"); + } else if (resultIndex < 0) { + throw new IllegalArgumentException("Result index must be non-negative: " + resultIndex); + } else { + return Bdd.RESULT_OFFSET + resultIndex; + } + } + + /** + * Creates or retrieves a BDD node for the given variable and branches. + * + *

Applies BDD reduction rules: + *

    + *
  • Eliminates redundant tests where both branches are identical
  • + *
  • Ensures complement edges appear only on the low branch
  • + *
  • Reuses existing nodes via the unique table
  • + *
+ * + * @param var the variable index + * @param high the reference for when variable is true + * @param low the reference for when variable is false + * @return reference to the BDD node + */ + public int makeNode(int var, int high, int low) { + // Reduction rule: if both branches are identical, skip this test + if (high == low) { + return high; + } + + // Complement edge canonicalization: ensure complement only on low branch. + // Don't apply this to result nodes or when branches contain results + boolean flip = false; + if (!isResultVariable(var) && !isResult(high) && !isResult(low) && isComplement(low)) { + high = negate(high); + low = negate(low); + flip = true; + } + + // Check if this node already exists + mutableKey.update(var, high, low); + Integer existing = uniqueTable.get(mutableKey); + + if (existing != null) { + int ref = toReference(existing); + return flip ? negate(ref) : ref; + } + + // Create new node + return insertNode(var, high, low, flip, nodes, uniqueTable); + } + + private int insertNode(int var, int high, int low, boolean flip, List nodes, Map tbl) { + int idx = nodes.size(); + nodes.add(new int[] {var, high, low}); + tbl.put(new TripleKey(var, high, low), idx); + int ref = toReference(idx); + return flip ? negate(ref) : ref; + } + + /** + * Negates a BDD reference (logical NOT). + * + * @param ref the reference to negate + * @return the negated reference + * @throws IllegalArgumentException if ref is a result terminal + */ + public int negate(int ref) { + if (isResult(ref)) { + throw new IllegalArgumentException("Cannot negate result terminal: " + ref); + } + return -ref; + } + + /** + * Checks if a reference has a complement edge. + * + * @param ref the reference to check + * @return true if complemented (negative) + */ + public boolean isComplement(int ref) { + return ref < 0; + } + + /** + * Checks if a reference is a boolean terminal (TRUE or FALSE). + * + * @param ref the reference to check + * @return true if boolean terminal + */ + public boolean isTerminal(int ref) { + return Math.abs(ref) == 1; + } + + /** + * Checks if a reference is a result terminal. + * + * @param ref the reference to check + * @return true if result terminal + */ + public boolean isResult(int ref) { + if (isTerminal(ref) || conditionCount == -1) { + return false; + } + return ref >= Bdd.RESULT_OFFSET; + } + + /** + * Checks if a variable index represents a result. + * + * @param varIdx the variable index + * @return true if result + */ + private boolean isResultVariable(int varIdx) { + return conditionCount != -1 && varIdx >= conditionCount; + } + + /** + * Checks if a reference is any kind of terminal. + * + * @param ref the reference to check + * @return true if any terminal type + */ + private boolean isAnyTerminal(int ref) { + return isTerminal(ref) || isResult(ref); + } + + /** + * Gets the variable index for a BDD node. + * + * @param ref the BDD reference + * @return the variable index, or -1 for terminals + */ + public int getVariable(int ref) { + if (isTerminal(ref)) { + return -1; + } else if (isResult(ref)) { + // For results, return the virtual variable index (conditionCount + resultIndex) + return conditionCount + (ref - Bdd.RESULT_OFFSET); + } else { + return nodes.get(Math.abs(ref) - 1)[0]; + } + } + + /** + * Computes the cofactor of a BDD with respect to a variable assignment. + * + * @param bdd the BDD to restrict + * @param varIndex the variable to fix + * @param value the value to assign (true or false) + * @return the restricted BDD + */ + public int cofactor(int bdd, int varIndex, boolean value) { + // Terminals and results are unaffected by cofactoring + if (isAnyTerminal(bdd)) { + return bdd; + } + + boolean complemented = isComplement(bdd); + int nodeIndex = toNodeIndex(bdd); + int[] node = nodes.get(nodeIndex); + int nodeVar = node[0]; + + if (nodeVar == varIndex) { + // This node tests our variable, so take the appropriate branch + int child = value ? node[1] : node[2]; + // Only negate if child is not a result + return (complemented && !isResult(child)) ? negate(child) : child; + } else if (nodeVar > varIndex) { + // Variable doesn't appear in this BDD (due to ordering) + return bdd; + } else { + // Variable appears deeper, so recurse on both branches + int high = cofactor(node[1], varIndex, value); + int low = cofactor(node[2], varIndex, value); + int result = makeNode(nodeVar, high, low); + return (complemented && !isResult(result)) ? negate(result) : result; + } + } + + /** + * Computes the logical AND of two BDDs. + * + * @param f first operand + * @param g second operand + * @return f AND g + * @throws IllegalArgumentException if operands are result terminals + */ + public int and(int f, int g) { + validateBooleanOperands(f, g, "AND"); + return ite(f, g, makeFalse()); + } + + /** + * Computes the logical OR of two BDDs. + * + * @param f first operand + * @param g second operand + * @return f OR g + * @throws IllegalArgumentException if operands are result terminals + */ + public int or(int f, int g) { + validateBooleanOperands(f, g, "OR"); + return ite(f, makeTrue(), g); + } + + /** + * Computes if-then-else (ITE) operation: "if f then g else h". + * + *

This is the fundamental BDD operation from which all others are derived. + * Includes optimizations for special cases and complement edges. + * + * @param f the condition (must be boolean) + * @param g the "then" branch + * @param h the "else" branch + * @return the resulting BDD + * @throws IllegalArgumentException if f is a result terminal + */ + public int ite(int f, int g, int h) { + // Normalize: if condition is complemented, swap branches + if (isComplement(f)) { + f = negate(f); + int tmp = g; + g = h; + h = tmp; + } + + // Terminal cases and validation. + if (isResult(f)) { + throw new IllegalArgumentException("Condition f must be boolean, not a result terminal"); + } else if (f == TRUE_REF) { + return g; + } else if (f == FALSE_REF) { + return h; + } else if (g == h) { + return g; + } + + // Boolean-specific optimizations (don't apply to result terminals) + if (!(isResult(g) || isResult(h))) { + // Standard Boolean identities + if (g == TRUE_REF && h == FALSE_REF) { + return f; + } else if (g == FALSE_REF && h == TRUE_REF) { + return negate(f); + } else if (isComplement(g) && isComplement(h) && !isResult(negate(g)) && !isResult(negate(h))) { + // Factor out common complement only if the negated values aren't results + return negate(ite(f, negate(g), negate(h))); + } else if (g == f) { // Simplifications when f appears in branches + return or(f, h); + } else if (h == f) { + return and(f, g); + } else if (g == negate(f)) { + return and(negate(f), h); + } else if (h == negate(f)) { + return or(negate(f), g); + } + } + + // Check cache using the mutable key. + Integer cached = iteCache.get(mutableKey.update(f, g, h)); + if (cached != null) { + return cached; + } + + // Create the actual key, and reserve cache slot to handle recursive calls + TripleKey key = new TripleKey(f, g, h); + iteCache.put(key, FALSE_REF); + + // Shannon expansion: find the top variable + int v = getTopVariable(f, g, h); + + // Compute cofactors + int f0 = cofactor(f, v, false); + int f1 = cofactor(f, v, true); + int g0 = cofactor(g, v, false); + int g1 = cofactor(g, v, true); + int h0 = cofactor(h, v, false); + int h1 = cofactor(h, v, true); + + // Recursive ITE on cofactors + int r0 = ite(f0, g0, h0); + int r1 = ite(f1, g1, h1); + + // Build result node + int result = makeNode(v, r1, r0); + + // Update cache with actual result + iteCache.put(key, result); + return result; + } + + /** + * Reduces the BDD by eliminating redundant nodes. + * + * @param rootRef the root of the BDD to reduce + * @return the reduced BDD root + */ + public int reduce(int rootRef) { + // Quick exit for terminals/results + if (isTerminal(rootRef) || isResult(rootRef)) { + return rootRef; + } + + boolean rootComp = isComplement(rootRef); + int absRoot = rootComp ? negate(rootRef) : rootRef; + + // Prep new storage + int N = nodes.size(); + List newNodes = new ArrayList<>(N); + Map newUnique = new HashMap<>(N * 2); + newNodes.add(new int[] {-1, TRUE_REF, FALSE_REF}); + + // Mapping array + int[] oldToNew = new int[N]; + Arrays.fill(oldToNew, -1); + + // Recurse + int newRoot = reduceRec(absRoot, oldToNew, newNodes, newUnique); + + // Swap in + this.nodes = newNodes; + this.uniqueTable = newUnique; + clearCaches(); + + return rootComp ? negate(newRoot) : newRoot; + } + + private int reduceRec(int ref, int[] oldToNew, List newNodes, Map newUnique) { + // Handle terminals and results first + if (isTerminal(ref)) { + return ref; + } + + // Handle result references (not stored as nodes) + if (isResult(ref)) { + return ref; + } + + // Peel complement + boolean comp = isComplement(ref); + int abs = comp ? negate(ref) : ref; + int idx = toNodeIndex(abs); + + // Already processed? + int mapped = oldToNew[idx]; + if (mapped != -1) { + return comp ? negate(mapped) : mapped; + } + + // Process children + int[] nd = nodes.get(idx); + int var = nd[0]; + int hiNew = reduceRec(nd[1], oldToNew, newNodes, newUnique); + int loNew = reduceRec(nd[2], oldToNew, newNodes, newUnique); + + // Apply reduction rule + int resultAbs; + if (hiNew == loNew) { + resultAbs = hiNew; + } else { + resultAbs = makeNodeInNew(var, hiNew, loNew, newNodes, newUnique); + } + + oldToNew[idx] = resultAbs; + return comp ? negate(resultAbs) : resultAbs; + } + + private int makeNodeInNew(int var, int hi, int lo, List newNodes, Map newUnique) { + if (hi == lo) { + return hi; + } + + // Canonicalize complement edges (but not for result nodes) + boolean comp = false; + if (!isResultVariable(var) && !isResult(hi) && !isResult(lo) && isComplement(lo)) { + hi = negate(hi); + lo = negate(lo); + comp = true; + } + + // Check if node already exists in new structure + Integer existing = newUnique.get(mutableKey.update(var, hi, lo)); + if (existing != null) { + int ref = toReference(existing); + return comp ? negate(ref) : ref; + } else { + // Create new node + return insertNode(var, hi, lo, comp, newNodes, newUnique); + } + } + + /** + * Finds the topmost variable among three BDDs. + */ + private int getTopVariable(int f, int g, int h) { + int varF = getVariable(f); + int varG = getVariable(g); + int varH = getVariable(h); + + // Filter out -1 (terminal marker) and find minimum + int min = Integer.MAX_VALUE; + if (varF >= 0 && varF < min) { + min = varF; + } + if (varG >= 0 && varG < min) { + min = varG; + } + if (varH >= 0 && varH < min) { + min = varH; + } + + return min == Integer.MAX_VALUE ? -1 : min; + } + + /** + * Clears all operation caches. + */ + public void clearCaches() { + iteCache.clear(); + } + + /** + * Clear out the state of the builder, but reuse the existing arrays, maps, etc. + * + * @return this builder + */ + public BddBuilder reset() { + clearCaches(); + uniqueTable.clear(); + nodes.clear(); + nodes.add(new int[] {-1, TRUE_REF, FALSE_REF}); + conditionCount = -1; + return this; + } + + /** + * Returns a defensive copy of the node table. + * + * @return list of node arrays + */ + public List getNodes() { + List copy = new ArrayList<>(nodes.size()); + for (int[] node : nodes) { + copy.add(node.clone()); + } + return copy; + } + + /** + * Get the array of nodes. + * + * @return array of nodes. + */ + public int[][] getNodesArray() { + return nodes.toArray(new int[0][]); + } + + private void validateBooleanOperands(int f, int g, String operation) { + if (isResult(f) || isResult(g)) { + throw new IllegalArgumentException("Cannot perform " + operation + " on result terminals"); + } + } + + private int toNodeIndex(int ref) { + return Math.abs(ref) - 1; + } + + private int toReference(int nodeIndex) { + return nodeIndex + 1; + } + + private static final class TripleKey { + private int a, b, c, hash; + + private TripleKey(int a, int b, int c) { + update(a, b, c); + } + + TripleKey update(int a, int b, int c) { + this.a = a; + this.b = b; + this.c = c; + int i = (a * 31 + b) * 31 + c; + this.hash = (i ^ (i >>> 16)); + return this; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (!(o instanceof TripleKey)) { + return false; + } + TripleKey k = (TripleKey) o; + return a == k.a && b == k.b && c == k.c; + } + + @Override + public int hashCode() { + return hash; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java new file mode 100644 index 00000000000..1d47e379d1f --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java @@ -0,0 +1,141 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; +import software.amazon.smithy.rulesengine.logic.ConditionReference; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionData; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; +import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; + +/** + * BDD compiler that builds an unreduced BDD from CFG. + */ +final class BddCompiler { + private static final Logger LOGGER = Logger.getLogger(BddCompiler.class.getName()); + + private final Cfg cfg; + private final BddBuilder bddBuilder; + private final ConditionOrderingStrategy orderingStrategy; + + // Condition ordering + private List orderedConditions; + private Map conditionToIndex; + + // Result indexing + private final Map ruleToIndex = new HashMap<>(); + private final List indexedResults = new ArrayList<>(); + private int nextResultIndex = 0; + private int noMatchIndex = -1; + + // Simple cache to avoid recomputing identical subgraphs + private final Map nodeCache = new HashMap<>(); + + BddCompiler(Cfg cfg, ConditionOrderingStrategy orderingStrategy, BddBuilder bddBuilder) { + this.cfg = Objects.requireNonNull(cfg, "CFG cannot be null"); + this.orderingStrategy = Objects.requireNonNull(orderingStrategy, "Ordering strategy cannot be null"); + this.bddBuilder = Objects.requireNonNull(bddBuilder, "BDD builder cannot be null"); + } + + Bdd compile() { + long start = System.currentTimeMillis(); + extractAndOrderConditions(); + + // Set the condition count in the builder + bddBuilder.setConditionCount(orderedConditions.size()); + + // Create the "no match" terminal + noMatchIndex = getOrCreateResultIndex(NoMatchRule.INSTANCE); + int rootRef = convertCfgToBdd(cfg.getRoot()); + rootRef = bddBuilder.reduce(rootRef); + Parameters parameters = cfg.getRuleSet().getParameters(); + Bdd bdd = new Bdd(parameters, orderedConditions, indexedResults, bddBuilder.getNodesArray(), rootRef); + + long elapsed = System.currentTimeMillis() - start; + LOGGER.fine(String.format( + "BDD compilation complete: %d conditions, %d results, %d BDD nodes in %dms", + orderedConditions.size(), + indexedResults.size(), + bddBuilder.getNodes().size() - 1, + elapsed)); + + return bdd; + } + + private int convertCfgToBdd(CfgNode cfgNode) { + Integer cached = nodeCache.get(cfgNode); + if (cached != null) { + return cached; + } + + int result; + if (cfgNode == null) { + result = bddBuilder.makeResult(noMatchIndex); + + } else if (cfgNode instanceof ResultNode) { + Rule rule = ((ResultNode) cfgNode).getResult(); + result = bddBuilder.makeResult(getOrCreateResultIndex(rule)); + + } else { + ConditionNode cn = (ConditionNode) cfgNode; + ConditionReference ref = cn.getCondition(); + int varIdx = conditionToIndex.get(ref.getCondition()); + + // Recursively build the two branches + int hi = convertCfgToBdd(cn.getTrueBranch()); + int lo = convertCfgToBdd(cn.getFalseBranch()); + + // If the original rule said "not condition", swap branches + if (ref.isNegated()) { + int tmp = hi; + hi = lo; + lo = tmp; + } + + // Build the pure boolean test for variable varIdx + int test = bddBuilder.makeNode(varIdx, bddBuilder.makeTrue(), bddBuilder.makeFalse()); + + // Combine with ITE (reduces and merges) + result = bddBuilder.ite(test, hi, lo); + } + + nodeCache.put(cfgNode, result); + return result; + } + + private int getOrCreateResultIndex(Rule rule) { + return ruleToIndex.computeIfAbsent(rule, r -> { + int idx = nextResultIndex++; + indexedResults.add(r); + return idx; + }); + } + + private void extractAndOrderConditions() { + // Extract conditions from CFG and order them. + ConditionData data = cfg.getConditionData(); + Map infos = data.getConditionInfos(); + orderedConditions = orderingStrategy.orderConditions(data.getConditions(), infos); + + // Build index map + conditionToIndex = new LinkedHashMap<>(); + for (int i = 0; i < orderedConditions.size(); i++) { + conditionToIndex.put(orderedConditions.get(i), i); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java new file mode 100644 index 00000000000..0cdf6028d6a --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java @@ -0,0 +1,363 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; +import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; + +/** + * Verifies functional equivalence between a CFG and its BDD representation. + * + *

This verifier uses structural equivalence checking to ensure that both representations produce the same result. + * When the BDD has fewer than 20 conditions, the checking is exhaustive. When there are more, random samples are + * checked up the earlier of max samples being reached, or the max duration being reached. + */ +public final class BddEquivalenceChecker { + + private static final Logger LOGGER = Logger.getLogger(BddEquivalenceChecker.class.getName()); + + private static final int EXHAUSTIVE_THRESHOLD = 20; + private static final int DEFAULT_MAX_SAMPLES = 1_000_000; + private static final Duration DEFAULT_TIMEOUT = Duration.ofMinutes(1); + + private final Cfg cfg; + private final Bdd bdd; + private final List parameters; + private final Map conditionToIndex = new HashMap<>(); + + private int maxSamples = DEFAULT_MAX_SAMPLES; + private Duration timeout = DEFAULT_TIMEOUT; + + private int testsRun = 0; + private long startTime; + + public static BddEquivalenceChecker of(Cfg cfg, Bdd bdd) { + return new BddEquivalenceChecker(cfg, bdd); + } + + private BddEquivalenceChecker(Cfg cfg, Bdd bdd) { + this.cfg = cfg; + this.bdd = bdd; + this.parameters = new ArrayList<>(cfg.getRuleSet().getParameters().toList()); + + for (int i = 0; i < bdd.getConditions().size(); i++) { + conditionToIndex.put(bdd.getConditions().get(i), i); + } + } + + /** + * Sets the maximum number of samples to test for large condition sets. + * + *

Defaults to a max of 1M samples. Set to {@code <= 0} to disable the max. + * + * @param maxSamples the maximum number of samples + * @return this verifier for method chaining + */ + public BddEquivalenceChecker setMaxSamples(int maxSamples) { + if (maxSamples < 1) { + maxSamples = Integer.MAX_VALUE; + } + this.maxSamples = maxSamples; + return this; + } + + /** + * Sets the maximum amount of time to take for the verification (runs until timeout or max samples met). + * + *

Defaults to a 1-minute timeout if not overridden. + * + * @param timeout the timeout duration + * @return this verifier for method chaining + */ + public BddEquivalenceChecker setMaxDuration(Duration timeout) { + this.timeout = timeout; + return this; + } + + /** + * Verifies that the BDD produces identical results to the CFG. + * + * @throws VerificationException if any discrepancy is found + */ + public void verify() { + startTime = System.currentTimeMillis(); + verifyResults(); + testsRun = 0; + + LOGGER.info(() -> String.format("Verifying BDD with %d conditions (max samples: %d, timeout: %s)", + bdd.getConditionCount(), + maxSamples, + timeout)); + + if (bdd.getConditionCount() <= EXHAUSTIVE_THRESHOLD) { + verifyExhaustive(); + } else { + verifyWithLimits(); + } + + LOGGER.info(String.format("BDD verification passed: %d tests in %s", testsRun, getElapsedDuration())); + } + + private void verifyResults() { + Set cfgResults = new HashSet<>(); + for (CfgNode node : cfg) { + if (node instanceof ResultNode) { + Rule result = ((ResultNode) node).getResult(); + if (result != null) { + cfgResults.add(result); + } + } + } + + // Remove the NoMatchRule that's added by default. It's not in the CFG. + Set bddResults = new HashSet<>(bdd.getResults()); + bddResults.removeIf(v -> v == NoMatchRule.INSTANCE); + + if (!cfgResults.equals(bddResults)) { + Set inCfgOnly = new HashSet<>(cfgResults); + inCfgOnly.removeAll(bddResults); + Set inBddOnly = new HashSet<>(bddResults); + inBddOnly.removeAll(cfgResults); + throw new IllegalStateException(String.format( + "Result mismatch: CFG has %d results, BDD has %d results (excluding NoMatchRule).%n" + + "In CFG only: %s%n" + + "In BDD only: %s", + cfgResults.size(), + bddResults.size(), + inCfgOnly, + inBddOnly)); + } + } + + /** + * Exhaustively tests all possible condition combinations. + */ + private void verifyExhaustive() { + long totalCombinations = 1L << bdd.getConditionCount(); + LOGGER.info(() -> "Running exhaustive verification with " + totalCombinations + " combinations"); + for (long mask = 0; mask < totalCombinations; mask++) { + verifyCase(mask); + if (hasEitherLimitBeenExceeded()) { + LOGGER.info(String.format("Exhaustive verification stopped after %d tests " + + "(limit: %d samples or %s timeout)", testsRun, maxSamples, timeout)); + break; + } + } + } + + /** + * Verifies with configured limits (samples and timeout). + * Continues until EITHER limit is reached: maxSamples reached OR timeout exceeded. + */ + private void verifyWithLimits() { + LOGGER.info(() -> String.format("Running limited verification (will stop at %d samples OR %s timeout)", + maxSamples, + timeout)); + verifyCriticalCases(); + + while (!hasEitherLimitBeenExceeded()) { + long mask = randomMask(); + verifyCase(mask); + if (testsRun % 10000 == 0 && testsRun > 0) { + LOGGER.fine(() -> String.format("Progress: %d tests run, %s elapsed", testsRun, getElapsedDuration())); + } + } + + LOGGER.info(() -> String.format("Verification complete: %d tests run in %s", testsRun, getElapsedDuration())); + } + + /** + * Tests critical edge cases that are likely to expose bugs. + */ + private void verifyCriticalCases() { + LOGGER.fine("Testing critical edge cases"); + + // All conditions false + verifyCase(0); + + // All conditions true + verifyCase((1L << bdd.getConditionCount()) - 1); + + // Each condition true individually + for (int i = 0; i < bdd.getConditionCount() && !hasEitherLimitBeenExceeded(); i++) { + verifyCase(1L << i); + } + + // Each condition false individually (all others true) + long allTrue = (1L << bdd.getConditionCount()) - 1; + for (int i = 0; i < bdd.getConditionCount() && !hasEitherLimitBeenExceeded(); i++) { + verifyCase(allTrue ^ (1L << i)); + } + + // Alternating patterns + if (!hasEitherLimitBeenExceeded()) { + verifyCase(0x5555555555555555L & ((1L << bdd.getConditionCount()) - 1)); + } + + if (!hasEitherLimitBeenExceeded()) { + verifyCase(0xAAAAAAAAAAAAAAAAL & ((1L << bdd.getConditionCount()) - 1)); + } + } + + private boolean hasEitherLimitBeenExceeded() { + return testsRun >= maxSamples || isTimedOut(); + } + + private boolean isTimedOut() { + return getElapsedDuration().compareTo(timeout) >= 0; + } + + private Duration getElapsedDuration() { + return Duration.ofMillis(System.currentTimeMillis() - startTime); + } + + private void verifyCase(long mask) { + testsRun++; + + // Create evaluators that will return fixed values for conditions + FixedMaskEvaluator maskEvaluator = new FixedMaskEvaluator(mask); + Rule cfgResult = evaluateCfgWithMask(maskEvaluator); + Rule bddResult = evaluateBdd(mask); + + if (!resultsEqual(cfgResult, bddResult)) { + StringBuilder errorMsg = new StringBuilder(); + errorMsg.append("BDD verification mismatch found!\n"); + errorMsg.append("Test case #").append(testsRun).append("\n"); + errorMsg.append("Condition mask: ").append(Long.toBinaryString(mask)).append("\n"); + errorMsg.append("\nCondition details:\n"); + for (int i = 0; i < bdd.getConditions().size(); i++) { + Condition condition = bdd.getConditions().get(i); + boolean value = (mask & (1L << i)) != 0; + errorMsg.append(" Condition ") + .append(i) + .append(" [") + .append(value) + .append("]: ") + .append(condition) + .append("\n"); + } + errorMsg.append("\nResults:\n"); + errorMsg.append(" CFG result: ").append(describeResult(cfgResult)).append("\n"); + errorMsg.append(" BDD result: ").append(describeResult(bddResult)); + throw new VerificationException(errorMsg.toString()); + } + } + + private Rule evaluateCfgWithMask(ConditionEvaluator maskEvaluator) { + CfgNode result = evaluateCfgNode(cfg.getRoot(), conditionToIndex, maskEvaluator); + if (result instanceof ResultNode) { + return ((ResultNode) result).getResult(); + } + + return null; + } + + // Recursively evaluates a CFG node. + private CfgNode evaluateCfgNode( + CfgNode node, + Map conditionToIndex, + ConditionEvaluator maskEvaluator + ) { + if (node instanceof ResultNode) { + return node; + } + + if (node instanceof ConditionNode) { + ConditionNode condNode = (ConditionNode) node; + Condition condition = condNode.getCondition().getCondition(); + + Integer index = conditionToIndex.get(condition); + if (index == null) { + throw new IllegalStateException("Condition not found in BDD: " + condition); + } + + boolean conditionResult = maskEvaluator.test(index); + + // Handle negation if the condition reference is negated + if (condNode.getCondition().isNegated()) { + conditionResult = !conditionResult; + } + + // Follow the appropriate branch + if (conditionResult) { + return evaluateCfgNode(condNode.getTrueBranch(), conditionToIndex, maskEvaluator); + } else { + return evaluateCfgNode(condNode.getFalseBranch(), conditionToIndex, maskEvaluator); + } + } + + throw new IllegalStateException("Unknown CFG node type: " + node); + } + + private Rule evaluateBdd(long mask) { + FixedMaskEvaluator evaluator = new FixedMaskEvaluator(mask); + BddEvaluator bddEvaluator = BddEvaluator.from(bdd); + int resultIndex = bddEvaluator.evaluate(evaluator); + return resultIndex < 0 ? null : bdd.getResults().get(resultIndex); + } + + private boolean resultsEqual(Rule r1, Rule r2) { + if (r1 == r2) { + return true; + } else if (r1 == null || r2 == null) { + return false; + } else { + return r1.withoutConditions().equals(r2.withoutConditions()); + } + } + + // Generates a random bit mask for sampling. + private long randomMask() { + long mask = 0; + for (int i = 0; i < bdd.getConditionCount(); i++) { + if (Math.random() < 0.5) { + mask |= (1L << i); + } + } + return mask; + } + + private String describeResult(Rule rule) { + return rule == null ? "null (no match)" : rule.toString(); + } + + // A condition evaluator that returns values based on a fixed bit mask. + private static class FixedMaskEvaluator implements ConditionEvaluator { + private final long mask; + + FixedMaskEvaluator(long mask) { + this.mask = mask; + } + + @Override + public boolean test(int conditionIndex) { + return (mask & (1L << conditionIndex)) != 0; + } + } + + /** + * Exception thrown when verification fails. + */ + public static class VerificationException extends RuntimeException { + public VerificationException(String message) { + super(message); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluator.java new file mode 100644 index 00000000000..a9aead5155a --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluator.java @@ -0,0 +1,86 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; + +/** + * Simple BDD evaluator that works directly with BDD nodes. + */ +public final class BddEvaluator { + + private final int[][] nodes; + private final int rootRef; + private final int conditionCount; + + private BddEvaluator(int[][] nodes, int rootRef, int conditionCount) { + this.nodes = nodes; + this.rootRef = rootRef; + this.conditionCount = conditionCount; + } + + /** + * Create evaluator from a Bdd object. + * + * @param bdd the BDD + * @return the evaluator + */ + public static BddEvaluator from(Bdd bdd) { + return from(bdd.getNodes(), bdd.getRootRef(), bdd.getConditionCount()); + } + + /** + * Create evaluator from BDD data. + * + * @param nodes BDD nodes array + * @param rootRef root reference + * @param conditionCount number of conditions + * @return the evaluator + */ + public static BddEvaluator from(int[][] nodes, int rootRef, int conditionCount) { + return new BddEvaluator(nodes, rootRef, conditionCount); + } + + /** + * Evaluates the BDD. + * + * @param evaluator the condition evaluator + * @return the result index, or -1 for no match + */ + public int evaluate(ConditionEvaluator evaluator) { + int resultOff = Bdd.RESULT_OFFSET; + int ref = this.rootRef; + + while (true) { + int abs = Math.abs(ref); + // stop once we hit a terminal (+/-1) or a result node (|ref| >= resultOff) + if (abs <= 1 || abs >= resultOff) { + break; + } + + int[] node = this.nodes[abs - 1]; + int varIdx = node[0]; + int hi = node[1]; + int lo = node[2]; + + // swap branches for a complemented pointer + if (ref < 0) { + int tmp = hi; + hi = lo; + lo = tmp; + } + + ref = evaluator.test(varIdx) ? hi : lo; + } + + // +/-1 means no match. + if (ref == 1 || ref == -1) { + return -1; + } + + int resultIdx = ref - resultOff; + return resultIdx == 0 ? -1 : resultIdx; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeHelpers.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeHelpers.java new file mode 100644 index 00000000000..d3cd0ed153f --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeHelpers.java @@ -0,0 +1,150 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Set; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.utils.SetUtils; + +final class BddNodeHelpers { + private static final int[] TERMINAL_NODE = new int[] {-1, 1, -1}; + private static final Set ALLOWED_PROPERTIES = SetUtils.of( + "parameters", + "conditions", + "results", + "root", + "nodes", + "nodeCount"); + + private BddNodeHelpers() {} + + static Node toNode(Bdd bdd) { + ObjectNode.Builder builder = ObjectNode.builder(); + + List conditions = new ArrayList<>(); + for (Condition c : bdd.getConditions()) { + conditions.add(c.toNode()); + } + + List results = new ArrayList<>(); + if (!(bdd.getResults().get(0) instanceof NoMatchRule)) { + throw new IllegalArgumentException("BDD must always have a NoMatchRule as the first result"); + } + for (int i = 1; i < bdd.getResults().size(); i++) { + Rule result = bdd.getResults().get(i); + if (result instanceof NoMatchRule) { + throw new IllegalArgumentException("NoMatch rules can only appear at rule index 0. Found at index" + i); + } + results.add(bdd.getResults().get(i).toNode()); + } + + return builder + .withMember("parameters", bdd.getParameters().toNode()) + .withMember("conditions", Node.fromNodes(conditions)) + .withMember("results", Node.fromNodes(results)) + .withMember("root", bdd.getRootRef()) + .withMember("nodes", encodeNodes(bdd)) + .withMember("nodeCount", bdd.getNodes().length) + .build(); + } + + static Bdd fromNode(Node node) { + ObjectNode obj = node.expectObjectNode(); + obj.warnIfAdditionalProperties(ALLOWED_PROPERTIES); + Parameters params = Parameters.fromNode(obj.expectObjectMember("parameters")); + List conditions = obj.expectArrayMember("conditions").getElementsAs(Condition::fromNode); + + // Read the results and prepend NoMatchRule at index 0 + List serializedResults = obj.expectArrayMember("results").getElementsAs(Rule::fromNode); + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); // Always add no-match at index 0 + results.addAll(serializedResults); + + String nodesBase64 = obj.expectStringMember("nodes").getValue(); + int nodeCount = obj.expectNumberMember("nodeCount").getValue().intValue(); + int[][] nodes = decodeNodes(nodesBase64, nodeCount); + int rootRef = obj.expectNumberMember("root").getValue().intValue(); + return new Bdd(params, conditions, results, nodes, rootRef); + } + + static String encodeNodes(Bdd bdd) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos)) { + int[][] nodes = bdd.getNodes(); + for (int[] node : nodes) { + writeVarInt(dos, node[0]); + writeVarInt(dos, node[1]); + writeVarInt(dos, node[2]); + } + dos.flush(); + return Base64.getEncoder().encodeToString(baos.toByteArray()); + } catch (IOException e) { + throw new RuntimeException("Failed to encode BDD nodes", e); + } + } + + static int[][] decodeNodes(String base64, int nodeCount) { + if (base64.isEmpty() || nodeCount == 0) { + return new int[][] {TERMINAL_NODE}; + } + + byte[] data = Base64.getDecoder().decode(base64); + int[][] nodes = new int[nodeCount][]; + + try (ByteArrayInputStream bais = new ByteArrayInputStream(data); + DataInputStream dis = new DataInputStream(bais)) { + for (int i = 0; i < nodeCount; i++) { + int varIdx = readVarInt(dis); + int high = readVarInt(dis); + int low = readVarInt(dis); + nodes[i] = new int[] {varIdx, high, low}; + } + if (bais.available() > 0) { + throw new IllegalArgumentException("Extra data found after decoding " + nodeCount + + " nodes. " + bais.available() + " bytes remaining."); + } + return nodes; + } catch (IOException e) { + throw new RuntimeException("Failed to decode BDD nodes", e); + } + } + + // Zig-zag + varint encode of a signed int + private static void writeVarInt(DataOutputStream dos, int value) throws IOException { + int zz = (value << 1) ^ (value >> 31); + while ((zz & ~0x7F) != 0) { + dos.writeByte((zz & 0x7F) | 0x80); + zz >>>= 7; + } + dos.writeByte(zz); + } + + // Decode a signed int from varint + zig-zag. + private static int readVarInt(DataInputStream dis) throws IOException { + int shift = 0, result = 0; + while (true) { + byte b = dis.readByte(); + result |= (b & 0x7F) << shift; + if ((b & 0x80) == 0) + break; + shift += 7; + } + // reverse zig-zag + return (result >>> 1) ^ -(result & 1); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java new file mode 100644 index 00000000000..dcbba6596ad --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java @@ -0,0 +1,110 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; + +/** + * Immutable graph of dependencies between conditions. + * + *

This class performs the expensive AST analysis once to extract: + *

    + *
  • Variable definitions - which conditions define which variables
  • + *
  • Variable usage - which conditions use which variables
  • + *
  • Raw dependencies - which conditions must come before others
  • + *
+ */ +final class ConditionDependencyGraph { + private final Map conditionInfos; + private final Map> dependencies; + private final Map> variableDefiners; + private final Map> isSetConditions; + + /** + * Creates a dependency graph by analyzing the given conditions. + * + * @param conditionInfos metadata about each condition + */ + public ConditionDependencyGraph(Map conditionInfos) { + this.conditionInfos = Collections.unmodifiableMap(new LinkedHashMap<>(conditionInfos)); + this.variableDefiners = new LinkedHashMap<>(); + this.isSetConditions = new LinkedHashMap<>(); + + // Categorize all conditions + for (Map.Entry entry : conditionInfos.entrySet()) { + Condition cond = entry.getKey(); + ConditionInfo info = entry.getValue(); + + // Track variable definition + String definedVar = info.getReturnVariable(); + if (definedVar != null) { + variableDefiners.computeIfAbsent(definedVar, k -> new LinkedHashSet<>()).add(cond); + } + + // Track isSet conditions + if (isIsset(cond)) { + for (String var : info.getReferences()) { + isSetConditions.computeIfAbsent(var, k -> new LinkedHashSet<>()).add(cond); + } + } + } + + // Compute dependencies + Map> deps = new LinkedHashMap<>(); + for (Map.Entry entry : conditionInfos.entrySet()) { + Condition cond = entry.getKey(); + ConditionInfo info = entry.getValue(); + + Set condDeps = new LinkedHashSet<>(); + + for (String usedVar : info.getReferences()) { + // Must come after any condition that defines this variable + condDeps.addAll(variableDefiners.getOrDefault(usedVar, Collections.emptySet())); + + // Non-isSet conditions must come after isSet checks for undefined variables + if (!isIsset(cond)) { + condDeps.addAll(isSetConditions.getOrDefault(usedVar, Collections.emptySet())); + } + } + + condDeps.remove(cond); // Remove self-dependencies + if (!condDeps.isEmpty()) { + deps.put(cond, Collections.unmodifiableSet(condDeps)); + } + } + + this.dependencies = Collections.unmodifiableMap(deps); + } + + /** + * Gets the dependencies for a condition. + * + * @param condition the condition to query + * @return set of conditions that must come before it (never null) + */ + public Set getDependencies(Condition condition) { + return dependencies.getOrDefault(condition, Collections.emptySet()); + } + + /** + * Gets the number of conditions in this dependency graph. + * + * @return the number of conditions + */ + public int size() { + return conditionInfos.size(); + } + + private static boolean isIsset(Condition cond) { + return cond.getFunction().getFunctionDefinition() == IsSet.getDefinition(); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java new file mode 100644 index 00000000000..c0cd4fe900d --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java @@ -0,0 +1,43 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.List; +import java.util.Map; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; + +/** + * Strategy interface for ordering conditions in a BDD. + */ +@FunctionalInterface +interface ConditionOrderingStrategy { + /** + * Orders the given conditions for BDD construction. + * + * @param conditions array of conditions to order + * @param conditionInfos metadata about each condition + * @return ordered list of conditions + */ + List orderConditions(Condition[] conditions, Map conditionInfos); + + /** + * Default ordering strategy that uses the existing ConditionOrderer. + * + * @return return the default ordering strategy. + */ + static ConditionOrderingStrategy defaultOrdering() { + return DefaultOrderingStrategy::orderConditions; + } + + /** + * Fixed ordering strategy that uses a pre-determined order. + * + * @return a fixed ordering strategy. + */ + static ConditionOrderingStrategy fixed(List ordering) { + return (conditions, infos) -> ordering; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java new file mode 100644 index 00000000000..cd174a81b54 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java @@ -0,0 +1,104 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; + +/** + * Orders conditions for BDD construction while respecting variable dependencies. + * + *

The ordering ensures that: + *

    + *
  • Variables are defined before they are used
  • + *
  • isSet checks come before value checks for the same variable
  • + *
  • Simpler conditions are evaluated first (when dependencies allow)
  • + *
+ */ +final class DefaultOrderingStrategy { + + private DefaultOrderingStrategy() {} + + static List orderConditions(Condition[] conditions, Map conditionInfos) { + return sort(conditions, new ConditionDependencyGraph(conditionInfos), conditionInfos); + } + + private static List sort( + Condition[] conditions, + ConditionDependencyGraph deps, + Map infos + ) { + List result = new ArrayList<>(); + Set visited = new HashSet<>(); + Set visiting = new HashSet<>(); + + // Sort conditions by priority + List queue = new ArrayList<>(); + Collections.addAll(queue, conditions); + + queue.sort(Comparator + // fewer deps first + .comparingInt((Condition c) -> deps.getDependencies(c).size()) + // isSet() before everything else + .thenComparingInt(c -> c.getFunction().getFunctionDefinition() == IsSet.getDefinition() ? 0 : 1) + // variable-defining conditions first + .thenComparingInt(c -> infos.get(c).getReturnVariable() != null ? 0 : 1) + // fewer references first + .thenComparingInt(c -> infos.get(c).getReferences().size()) + // stable tie-breaker + .thenComparing(Condition::toString)); + + // Visit in priority order + for (Condition cond : queue) { + if (!visited.contains(cond)) { + visit(cond, deps, visited, visiting, result, infos); + } + } + + return result; + } + + private static void visit( + Condition cond, + ConditionDependencyGraph depGraph, + Set visited, + Set visiting, + List result, + Map infos + ) { + if (visiting.contains(cond)) { + throw new IllegalStateException("Circular dependency detected involving: " + cond); + } + + if (visited.contains(cond)) { + return; + } + + visiting.add(cond); + + // Visit dependencies first + Set deps = depGraph.getDependencies(cond); + if (!deps.isEmpty()) { + List sortedDeps = new ArrayList<>(deps); + sortedDeps.sort(Comparator.comparingInt(c -> infos.get(c).getComplexity())); + + for (Condition dep : sortedDeps) { + visit(dep, depGraph, visited, visiting, result, infos); + } + } + + visiting.remove(cond); + visited.add(cond); + result.add(cond); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java new file mode 100644 index 00000000000..1131e0d3173 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java @@ -0,0 +1,96 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.function.Function; +import java.util.logging.Logger; + +/** + * Reverses the node ordering in a BDD from bottom-up to top-down for better cache locality. + * + *

This transformation reverses the node array (except the terminal at index 0) + * and updates all references throughout the BDD to maintain correctness. + */ +public final class NodeReversal implements Function { + + private static final Logger LOGGER = Logger.getLogger(NodeReversal.class.getName()); + + @Override + public Bdd apply(Bdd bdd) { + LOGGER.info("Starting BDD node reversal optimization"); + int[][] nodes = bdd.getNodes(); + int nodeCount = nodes.length; + + if (nodeCount <= 2) { + return bdd; + } + + // Create the index mapping: old index -> new index + // Index 0 (terminal) stays at 0 + int[] oldToNew = new int[nodeCount]; + oldToNew[0] = 0; + + // Reverse indices for non-terminal nodes + for (int oldIdx = 1; oldIdx < nodeCount; oldIdx++) { + int newIdx = nodeCount - oldIdx; + oldToNew[oldIdx] = newIdx; + } + + // Create new node array with reversed order + int[][] newNodes = new int[nodeCount][]; + newNodes[0] = nodes[0].clone(); // Terminal stays at index 0 + + // Add nodes in reverse order, updating their references + int newIdx = 1; + for (int oldIdx = nodeCount - 1; oldIdx >= 1; oldIdx--) { + int[] oldNode = nodes[oldIdx]; + newNodes[newIdx++] = new int[] { + oldNode[0], // variable index stays the same + remapReference(oldNode[1], oldToNew), // remap high reference + remapReference(oldNode[2], oldToNew) // remap low reference + }; + } + + // Remap the root reference + int newRoot = remapReference(bdd.getRootRef(), oldToNew); + + LOGGER.info("BDD node reversal complete"); + + return new Bdd(bdd.getParameters(), bdd.getConditions(), bdd.getResults(), newNodes, newRoot); + } + + /** + * Remaps a reference through the index mapping. + * + * @param ref the reference to remap + * @param oldToNew the index mapping array + * @return the remapped reference + */ + private int remapReference(int ref, int[] oldToNew) { + // Handle special cases + if (ref == 0) { + return 0; // Invalid reference stays invalid + } else if (ref == 1 || ref == -1) { + return ref; // TRUE/FALSE terminals unchanged + } else if (ref >= Bdd.RESULT_OFFSET) { + return ref; // Result references are not remapped + } + + // Handle regular node references (with possible complement) + boolean isComplemented = ref < 0; + int absRef = isComplemented ? -ref : ref; + + // Convert from reference to index (1-based to 0-based) + int oldIdx = absRef - 1; + + if (oldIdx >= oldToNew.length) { + throw new IllegalStateException("Invalid reference: " + ref); + } + + int newIdx = oldToNew[oldIdx]; + int newRef = newIdx + 1; + return isComplemented ? -newRef : newRef; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraints.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraints.java new file mode 100644 index 00000000000..5f2ca2df997 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraints.java @@ -0,0 +1,92 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +/** + * Order-specific constraints derived from a dependency graph. + * + *

This class efficiently computes position-based constraints for a specific ordering of conditions, using the + * pre-computed dependency graph. It can be created cheaply for each new ordering during optimization. + */ +final class OrderConstraints { + private final Condition[] conditions; + private final Map conditionToIndex; + private final int[] minValidPosition; + private final int[] maxValidPosition; + + /** + * Creates order constraints for a specific ordering. + * + * @param graph the pre-computed dependency graph + * @param conditions the conditions in their specific order + */ + public OrderConstraints(ConditionDependencyGraph graph, List conditions) { + int n = conditions.size(); + if (n != graph.size()) { + throw new IllegalArgumentException( + "Condition count (" + n + ") doesn't match dependency graph size (" + graph.size() + ")"); + } + + this.conditions = conditions.toArray(new Condition[0]); + this.conditionToIndex = new HashMap<>(n * 2); + this.minValidPosition = new int[n]; + this.maxValidPosition = new int[n]; + + // Build index mapping + for (int i = 0; i < n; i++) { + conditionToIndex.put(this.conditions[i], i); + } + + // Build dependencies and compute valid positions in one pass + for (int i = 0; i < n; i++) { + maxValidPosition[i] = n - 1; // Initialize max position + for (Condition dep : graph.getDependencies(this.conditions[i])) { + Integer depIndex = conditionToIndex.get(dep); + if (depIndex != null) { + // This condition must come after its dependency + minValidPosition[i] = Math.max(minValidPosition[i], depIndex + 1); + // The dependency must come before this condition + maxValidPosition[depIndex] = Math.min(maxValidPosition[depIndex], i - 1); + } + } + } + } + + /** + * Checks if moving a condition from one position to another would violate dependencies. + * + * @param from current position + * @param to target position + * @return true if the move is valid + */ + public boolean canMove(int from, int to) { + return from == to || (to >= minValidPosition[from] && to <= maxValidPosition[from]); + } + + /** + * Gets the minimum valid position for a condition. + * + * @param conditionIndex the condition index + * @return the minimum position where this condition can be placed + */ + public int getMinValidPosition(int conditionIndex) { + return minValidPosition[conditionIndex]; + } + + /** + * Gets the maximum valid position for a condition. + * + * @param conditionIndex the condition index + * @return the maximum position where this condition can be placed + */ + public int getMaxValidPosition(int conditionIndex) { + return maxValidPosition[conditionIndex]; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java new file mode 100644 index 00000000000..8e4eb850512 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -0,0 +1,566 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ForkJoinPool; +import java.util.function.Function; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; +import software.amazon.smithy.utils.SmithyBuilder; + +/** + * BDD optimization using tiered parallel position evaluation with dependency-aware constraints. + * + *

This algorithm improves BDD size through a multi-stage approach: + *

    + *
  • Coarse optimization for large BDDs (fast reduction)
  • + *
  • Medium optimization for moderate BDDs (balanced approach)
  • + *
  • Granular optimization for small BDDs (maximum quality)
  • + *
+ * + *

Each stage runs until reaching its target size or maximum passes. + */ +public final class SiftingOptimization implements Function { + private static final Logger LOGGER = Logger.getLogger(SiftingOptimization.class.getName()); + + // Default thresholds and passes for each optimization level + private static final int DEFAULT_COARSE_MIN_NODES = 50_000; + private static final int DEFAULT_COARSE_MAX_PASSES = 5; + private static final int DEFAULT_MEDIUM_MIN_NODES = 10_000; + private static final int DEFAULT_MEDIUM_MAX_PASSES = 5; + private static final int DEFAULT_GRANULAR_MAX_NODES = 10_000; + private static final int DEFAULT_GRANULAR_MAX_PASSES = 8; + + // When a variable has fewer than this many valid positions, try them all. + private static final int EXHAUSTIVE_THRESHOLD = 20; + + // Thread-local BDD builders to avoid allocation overhead + private final ThreadLocal threadBuilder = ThreadLocal.withInitial(BddBuilder::new); + + private final Cfg cfg; + private final ConditionDependencyGraph dependencyGraph; + private final Map conditionInfos; + + // Tiered optimization settings + private final int coarseMinNodes; + private final int coarseMaxPasses; + private final int mediumMinNodes; + private final int mediumMaxPasses; + private final int granularMaxNodes; + private final int granularMaxPasses; + + // Internal effort levels for the tiered optimization stages. + private enum OptimizationEffort { + COARSE(12, 4, 0), + MEDIUM(2, 18, 5), + GRANULAR(1, 30, 6); + + final int sampleRate; + final int maxPositions; + final int nearbyRadius; + + OptimizationEffort(int sampleRate, int maxPositions, int nearbyRadius) { + this.sampleRate = sampleRate; + this.maxPositions = maxPositions; + this.nearbyRadius = nearbyRadius; + } + } + + private SiftingOptimization(Builder builder) { + this.cfg = SmithyBuilder.requiredState("cfg", builder.cfg); + this.coarseMinNodes = builder.coarseMinNodes; + this.coarseMaxPasses = builder.coarseMaxPasses; + this.mediumMinNodes = builder.mediumMinNodes; + this.mediumMaxPasses = builder.mediumMaxPasses; + this.granularMaxNodes = builder.granularMaxNodes; + this.granularMaxPasses = builder.granularMaxPasses; + + // Extract condition infos from CFG + this.conditionInfos = new LinkedHashMap<>(); + for (CfgNode node : cfg) { + if (node instanceof ConditionNode) { + ConditionInfo info = ((ConditionNode) node).getCondition(); + conditionInfos.put(info.getCondition(), info); + } + } + + this.dependencyGraph = new ConditionDependencyGraph(conditionInfos); + } + + /** + * Creates a new builder for SiftingOptimization. + * + * @return a new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + @Override + public Bdd apply(Bdd bdd) { + try { + return doApply(bdd); + } finally { + threadBuilder.remove(); + } + } + + private Bdd doApply(Bdd bdd) { + LOGGER.info("Starting BDD sifting optimization"); + long startTime = System.currentTimeMillis(); + + // Pre-spin the ForkJoinPool for better first-pass performance + ForkJoinPool.commonPool().submit(() -> {}).join(); + OptimizationState state = initializeOptimization(bdd); + LOGGER.info(String.format("Initial reordering: %d -> %d nodes", state.initialSize, state.currentSize)); + + state = runCoarseStage(state); + state = runMediumStage(state); + state = runGranularStage(state); + + double totalTimeInSeconds = (System.currentTimeMillis() - startTime) / 1000.0; + if (state.bestSize < state.initialSize) { + LOGGER.info(String.format("Optimization complete: %d -> %d nodes (%.1f%% total reduction) in %fs", + state.initialSize, + state.bestSize, + (1.0 - (double) state.bestSize / state.initialSize) * 100, + totalTimeInSeconds)); + } else { + LOGGER.info(String.format("No improvements found in %fs", totalTimeInSeconds)); + } + + return state.bestBdd; + } + + private OptimizationState initializeOptimization(Bdd bdd) { + // Start with an intelligent initial ordering + List initialOrder = DefaultOrderingStrategy.orderConditions( + bdd.getConditions().toArray(new Condition[0]), + conditionInfos); + + // Sanity check that ordering didn't lose conditions + if (initialOrder.size() != bdd.getConditions().size()) { + throw new IllegalStateException("Initial ordering changed condition count: " + + bdd.getConditions().size() + " -> " + initialOrder.size()); + } + + Condition[] order = initialOrder.toArray(new Condition[0]); + List orderView = Arrays.asList(order); + + // Build initial BDD with better ordering + Bdd currentBest = Bdd.from(cfg, new BddBuilder(), ConditionOrderingStrategy.fixed(orderView)); + int currentSize = currentBest.getNodes().length - 1; // -1 for terminal + int initialSize = bdd.getNodes().length - 1; + + return new OptimizationState(order, orderView, currentBest, currentSize, initialSize); + } + + private OptimizationState runCoarseStage(OptimizationState state) { + if (state.currentSize <= coarseMinNodes) { + return state; + } + return runOptimizationStage(state, "Coarse", OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); + } + + private OptimizationState runMediumStage(OptimizationState state) { + if (state.currentSize <= mediumMinNodes) { + return state; + } + return runOptimizationStage(state, "Medium", OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); + } + + private OptimizationState runGranularStage(OptimizationState state) { + if (state.currentSize > granularMaxNodes) { + LOGGER.info(String.format("Skipping granular stage - BDD too large (%d nodes > %d threshold)", + state.currentSize, + granularMaxNodes)); + return state; + } + + // Run with no minimums + state = runOptimizationStage(state, "Granular", OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); + + // Also perform adjacent swaps in granular stage + OptimizationResult swapResult = performAdjacentSwaps(state.order, state.orderView, state.currentSize); + if (swapResult.improved) { + LOGGER.info(String.format("Adjacent swaps: %d -> %d nodes", state.currentSize, swapResult.size)); + return state.withResult(swapResult.bdd, swapResult.size); + } + + return state; + } + + private OptimizationState runOptimizationStage( + OptimizationState state, + String stageName, + OptimizationEffort effort, + int targetNodeCount, + int maxPasses, + double minReductionPercent + ) { + + LOGGER.info(String.format("Stage: %s optimization (%d nodes%s)", + stageName, + state.currentSize, + targetNodeCount > 0 ? String.format(", target < %d", targetNodeCount) : "")); + + OptimizationState currentState = state; + + for (int pass = 1; pass <= maxPasses; pass++) { + // Stop if we've reached the target + if (targetNodeCount > 0 && currentState.currentSize <= targetNodeCount) { + break; + } + + int passStartSize = currentState.currentSize; + OptimizationResult result = runOptimizationPass( + currentState.order, + currentState.orderView, + currentState.currentSize, + effort); + + if (!result.improved) { + LOGGER.fine(String.format("%s pass %d found no improvements", stageName, pass)); + break; + } else { + currentState = currentState.withResult(result.bdd, result.size); + double reduction = (1.0 - (double) result.size / passStartSize) * 100; + LOGGER.fine(String.format("%s pass %d: %d -> %d nodes (%.1f%% reduction)", + stageName, + pass, + passStartSize, + result.size, + reduction)); + // Check for diminishing returns + if (minReductionPercent > 0 && reduction < minReductionPercent) { + LOGGER.fine(String.format("%s optimization yielding diminishing returns", stageName)); + break; + } + } + } + + return currentState; + } + + private OptimizationResult runOptimizationPass( + Condition[] order, + List orderView, + int currentSize, + OptimizationEffort effort + ) { + + int improvements = 0; + OrderConstraints constraints = new OrderConstraints(dependencyGraph, orderView); + Bdd bestBdd = null; + int bestSize = currentSize; + + // Sample variables based on effort level + for (int varIdx = 0; varIdx < order.length; varIdx += effort.sampleRate) { + List positions = getPositions(varIdx, constraints, effort); + if (positions.isEmpty()) { + continue; + } else if (positions.size() > effort.maxPositions) { + positions = positions.subList(0, effort.maxPositions); + } + + // Find best position + PositionCount best = findBestPosition(positions, order, bestSize, varIdx); + + if (best == null || best.count >= bestSize) { + continue; + } + + // Move to best position and build BDD once + move(order, varIdx, best.position); + Bdd newBdd = Bdd.from(cfg, new BddBuilder(), ConditionOrderingStrategy.fixed(orderView)); + int newSize = newBdd.getNodes().length - 1; + + if (newSize < bestSize) { + bestBdd = newBdd; + bestSize = newSize; + improvements++; + + // Update constraints after successful move + constraints = new OrderConstraints(dependencyGraph, orderView); + } + } + + return new OptimizationResult(bestBdd, bestSize, improvements > 0); + } + + private OptimizationResult performAdjacentSwaps(Condition[] order, List orderView, int currentSize) { + OrderConstraints constraints = new OrderConstraints(dependencyGraph, orderView); + Bdd bestBdd = null; + int bestSize = currentSize; + boolean improved = false; + + for (int i = 0; i < order.length - 1; i++) { + if (constraints.canMove(i, i + 1)) { + move(order, i, i + 1); + int swappedSize = countNodes(orderView); + if (swappedSize < bestSize) { + bestBdd = Bdd.from(cfg, new BddBuilder(), ConditionOrderingStrategy.fixed(orderView)); + bestSize = swappedSize; + improved = true; + } else { + // Swap back if no improvement + move(order, i + 1, i); + } + } + } + + return new OptimizationResult(bestBdd, bestSize, improved); + } + + private PositionCount findBestPosition( + List positions, + final Condition[] currentOrder, + final int currentSize, + final int varIdx + ) { + return positions.parallelStream() + .map(pos -> { + Condition[] threadOrder = currentOrder.clone(); + move(threadOrder, varIdx, pos); + int nodeCount = countNodes(Arrays.asList(threadOrder)); + return new PositionCount(pos, nodeCount); + }) + .filter(pc -> pc.count < currentSize) + .min(Comparator.comparingInt(pc -> pc.count)) + .orElse(null); + } + + private List getPositions(int varIdx, OrderConstraints constraints, OptimizationEffort effort) { + int min = constraints.getMinValidPosition(varIdx); + int max = constraints.getMaxValidPosition(varIdx); + int range = max - min; + return range <= EXHAUSTIVE_THRESHOLD + ? getExhaustivePositions(varIdx, min, max, constraints) + : getStrategicPositions(varIdx, min, max, range, constraints, effort); + } + + private List getExhaustivePositions(int varIdx, int min, int max, OrderConstraints constraints) { + List positions = new ArrayList<>(max - min); + for (int p = min; p < max; p++) { + if (p != varIdx && constraints.canMove(varIdx, p)) { + positions.add(p); + } + } + return positions; + } + + private List getStrategicPositions( + int varIdx, + int min, + int max, + int range, + OrderConstraints constraints, + OptimizationEffort effort + ) { + List positions = new ArrayList<>(effort.maxPositions); + + // Boundaries (these are most likely to be optimal) + if (min != varIdx && constraints.canMove(varIdx, min)) { + positions.add(min); + } + if (max - 1 != varIdx && constraints.canMove(varIdx, max - 1)) { + positions.add(max - 1); + } + + // Nearby positions (only if effort includes nearbyRadius) + if (effort.nearbyRadius > 0) { + for (int offset = -effort.nearbyRadius; offset <= effort.nearbyRadius; offset++) { + if (offset != 0) { + int p = varIdx + offset; + if (p >= min && p < max && !positions.contains(p) && constraints.canMove(varIdx, p)) { + positions.add(p); + } + } + } + } + + // Adaptive sampling: fewer samples for smaller ranges + int maxSamples = Math.min(15, effort.maxPositions / 2); + int samples = Math.min(maxSamples, Math.max(2, range / 4)); + int step = Math.max(1, range / samples); + + for (int p = min + step; p < max - step; p += step) { + if (p != varIdx && !positions.contains(p) && constraints.canMove(varIdx, p)) { + positions.add(p); + } + } + + return positions; + } + + /** + * Moves an element in an array from one position to another. + */ + private static void move(Condition[] arr, int from, int to) { + if (from == to) { + return; + } + + Condition moving = arr[from]; + if (from < to) { + // Moving right: shift elements left + System.arraycopy(arr, from + 1, arr, from, to - from); + } else { + // Moving left: shift elements right + System.arraycopy(arr, to, arr, to + 1, from - to); + } + arr[to] = moving; + } + + /** + * Counts nodes for a given ordering without keeping the BDD. + */ + private int countNodes(List ordering) { + BddBuilder builder = threadBuilder.get().reset(); + return Bdd.from(cfg, builder, ConditionOrderingStrategy.fixed(ordering)).getNodes().length - 1; + } + + // Position and its node count + private static final class PositionCount { + final int position; + final int count; + + PositionCount(int position, int count) { + this.position = position; + this.count = count; + } + } + + // Result of an optimization pass + private static final class OptimizationResult { + final Bdd bdd; + final int size; + final boolean improved; + + OptimizationResult(Bdd bdd, int size, boolean improved) { + this.bdd = bdd; + this.size = size; + this.improved = improved; + } + } + + // State tracking during optimization + private static final class OptimizationState { + final Condition[] order; + final List orderView; + final Bdd bestBdd; + final int currentSize; + final int bestSize; + final int initialSize; + + OptimizationState( + Condition[] order, + List orderView, + Bdd bestBdd, + int currentSize, + int initialSize + ) { + this.order = order; + this.orderView = orderView; + this.bestBdd = bestBdd; + this.currentSize = currentSize; + this.bestSize = currentSize; + this.initialSize = initialSize; + } + + OptimizationState withResult(Bdd newBdd, int newSize) { + return new OptimizationState(order, orderView, newBdd, newSize, initialSize); + } + } + + /** + * Builder for SiftingOptimization. + */ + public static final class Builder implements SmithyBuilder { + private Cfg cfg; + private int coarseMinNodes = DEFAULT_COARSE_MIN_NODES; + private int coarseMaxPasses = DEFAULT_COARSE_MAX_PASSES; + private int mediumMinNodes = DEFAULT_MEDIUM_MIN_NODES; + private int mediumMaxPasses = DEFAULT_MEDIUM_MAX_PASSES; + private int granularMaxNodes = DEFAULT_GRANULAR_MAX_NODES; + private int granularMaxPasses = DEFAULT_GRANULAR_MAX_PASSES; + + private Builder() {} + + /** + * Sets the required control flow graph to optimize. + * + * @param cfg the control flow graph + * @return this builder + */ + public Builder cfg(Cfg cfg) { + this.cfg = cfg; + return this; + } + + /** + * Sets the coarse optimization parameters. + * + *

Coarse optimization runs until the BDD has fewer than minNodeCount nodes + * or maxPasses have been completed. + * + * @param minNodeCount the target size to stop coarse optimization (default: 50,000) + * @param maxPasses the maximum number of coarse passes (default: 3) + * @return this builder + */ + public Builder coarseEffort(int minNodeCount, int maxPasses) { + this.coarseMinNodes = minNodeCount; + this.coarseMaxPasses = maxPasses; + return this; + } + + /** + * Sets the medium optimization parameters. + * + *

Medium optimization runs until the BDD has fewer than minNodeCount nodes + * or maxPasses have been completed. + * + * @param minNodeCount the target size to stop medium optimization (default: 10,000) + * @param maxPasses the maximum number of medium passes (default: 4) + * @return this builder + */ + public Builder mediumEffort(int minNodeCount, int maxPasses) { + this.mediumMinNodes = minNodeCount; + this.mediumMaxPasses = maxPasses; + return this; + } + + /** + * Sets the granular optimization parameters. + * + *

Granular optimization only runs if the BDD has fewer than maxNodeCount nodes, + * and runs for at most maxPasses. + * + * @param maxNodeCount the maximum size to attempt granular optimization (default: 3,000) + * @param maxPasses the maximum number of granular passes (default: 2) + * @return this builder + */ + public Builder granularEffort(int maxNodeCount, int maxPasses) { + this.granularMaxNodes = maxNodeCount; + this.granularMaxPasses = maxPasses; + return this; + } + + @Override + public SiftingOptimization build() { + return new SiftingOptimization(this); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java new file mode 100644 index 00000000000..4d3df58ccfe --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java @@ -0,0 +1,245 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.utils.SmithyBuilder; + +/** + * A Control Flow Graph (CFG) representation of endpoint rule decision logic. + * + *

The CFG transforms the hierarchical decision tree structure into an optimized + * representation with node deduplication to prevent exponential growth. + * + *

The CFG consists of: + *

    + *
  • A root node representing the entry point of the decision logic
  • + *
  • A DAG structure where condition nodes are shared when they have identical subtrees
  • + *
+ */ +public final class Cfg implements Iterable { + + private final EndpointRuleSet ruleSet; + private final CfgNode root; + private ConditionData data; + + Cfg(EndpointRuleSet ruleSet, CfgNode root) { + this.ruleSet = ruleSet; + this.root = SmithyBuilder.requiredState("root", root); + } + + /** + * Create a CFG from the given ruleset. + * + * @param ruleSet Rules to convert to CFG. + * @return the CFG result. + */ + public static Cfg from(EndpointRuleSet ruleSet) { + CfgBuilder builder = new CfgBuilder(ruleSet); + CfgNode terminal = ResultNode.terminal(); + Map processedRules = new HashMap<>(); + CfgNode root = convertRulesToChain(builder.ruleSet.getRules(), terminal, builder, processedRules); + return builder.build(root); + } + + /** + * Get the condition data of the CFG. + * + * @return the lazily created and cached prepared condition data. + */ + public ConditionData getConditionData() { + ConditionData result = data; + if (result == null) { + result = ConditionData.from(this); + data = result; + } + return result; + } + + public EndpointRuleSet getRuleSet() { + return ruleSet; + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } else { + return root.equals(((Cfg) object).root); + } + } + + @Override + public int hashCode() { + return root.hashCode(); + } + + /** + * Returns the root node of the control flow graph. + * + * @return the root node + */ + public CfgNode getRoot() { + return root; + } + + @Override + public Iterator iterator() { + return new Iterator() { + private final Deque stack = new ArrayDeque<>(); + private final Set visited = new HashSet<>(); + private CfgNode next; + + { + if (root != null) { + stack.push(root); + } + advance(); + } + + private void advance() { + next = null; + while (!stack.isEmpty()) { + CfgNode node = stack.pop(); + if (visited.add(node)) { + // Push children before returning this node + if (node instanceof ConditionNode) { + ConditionNode cond = (ConditionNode) node; + stack.push(cond.getFalseBranch()); + stack.push(cond.getTrueBranch()); + } + next = node; + return; + } + } + } + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public CfgNode next() { + if (next == null) { + throw new NoSuchElementException(); + } + CfgNode result = next; + advance(); + return result; + } + }; + } + + // Converts a list of rules into a conditional chain. Each rule's false branch goes to the next rule. + private static CfgNode convertRulesToChain( + List rules, + CfgNode fallthrough, + CfgBuilder builder, + Map processedRules + ) { + // Make a reversed view of the rules list + List reversed = new ArrayList<>(rules); + Collections.reverse(reversed); + CfgNode next = fallthrough; + for (Rule rule : reversed) { + next = convertRule(rule, next, builder, processedRules); + } + return next; + } + + /** + * Converts a single rule to CFG nodes. + * + * @param rule the rule to convert + * @param fallthrough what to do if this rule doesn't match + * @param builder the CFG builder + * @param processedRules cache for processed rules + * @return the entry point for this rule + */ + private static CfgNode convertRule( + Rule rule, + CfgNode fallthrough, + CfgBuilder builder, + Map processedRules + ) { + RuleKey key = new RuleKey(rule, fallthrough); + CfgNode existing = processedRules.get(key); + if (existing != null) { + return existing; + } + + CfgNode body; + if (rule instanceof EndpointRule || rule instanceof ErrorRule) { + body = builder.createResult(rule); + } else if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + // Recursively convert nested rules with same fallthrough + body = convertRulesToChain(treeRule.getRules(), fallthrough, builder, processedRules); + } else { + throw new IllegalArgumentException("Unknown rule type: " + rule.getClass()); + } + + // Build conditions from last to first + CfgNode current = body; + for (int i = rule.getConditions().size() - 1; i >= 0; i--) { + Condition cond = rule.getConditions().get(i); + // For chained conditions (AND semantics), if one fails, we go to the fallthrough + current = builder.createCondition(cond, current, fallthrough); + } + + // Cache the result for this (rule, fallthrough) combination + processedRules.put(key, current); + + return current; + } + + private static final class RuleKey { + private final Rule rule; + private final CfgNode fallthrough; + private final int hashCode; + + RuleKey(Rule rule, CfgNode fallthrough) { + this.rule = rule; + this.fallthrough = fallthrough; + // Use identity hash for fallthrough since it's a node reference + this.hashCode = System.identityHashCode(rule) * 31 + System.identityHashCode(fallthrough); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (!(o instanceof RuleKey)) { + return false; + } + RuleKey that = (RuleKey) o; + return rule == that.rule && fallthrough == that.fallthrough; + } + + @Override + public int hashCode() { + return hashCode; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java new file mode 100644 index 00000000000..3275c060f44 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -0,0 +1,230 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; +import software.amazon.smithy.rulesengine.logic.ConditionReference; + +/** + * Builder for constructing Control Flow Graphs with node deduplication. + * + *

The builder performs simple hash-consing during top-down construction, deduplicating nodes when the same node + * would be created multiple times in the same context. + */ +public final class CfgBuilder { + final EndpointRuleSet ruleSet; + + // Simple hash-consing for nodes created in the same context + private final Map nodeCache = new HashMap<>(); + + // Condition and result canonicalization + private final Map conditionToInfo = new HashMap<>(); + private final Map conditionToReference = new HashMap<>(); + private final Map resultCache = new HashMap<>(); + private final Map resultNodeCache = new HashMap<>(); + + public CfgBuilder(EndpointRuleSet ruleSet) { + // Disambiguate conditions and references so variable names are globally unique. + this.ruleSet = SsaTransform.transform(ruleSet); + } + + /** + * Build the CFG with the given root node. + * + * @param root Root node to use for the built CFG. + * @return the built CFG. + */ + public Cfg build(CfgNode root) { + return new Cfg(ruleSet, Objects.requireNonNull(root)); + } + + /** + * Creates a condition node, reusing existing nodes when possible. + * + * @param condition the condition to evaluate + * @param trueBranch the node to evaluate when the condition is true + * @param falseBranch the node to evaluate when the condition is false + * @return a condition node (possibly cached) + */ + public CfgNode createCondition(Condition condition, CfgNode trueBranch, CfgNode falseBranch) { + return createCondition(createConditionReference(condition), trueBranch, falseBranch); + } + + /** + * Creates a condition node, reusing existing nodes when possible. + * + * @param condRef the condition reference to evaluate + * @param trueBranch the node to evaluate when the condition is true + * @param falseBranch the node to evaluate when the condition is false + * @return a condition node (possibly cached) + */ + public CfgNode createCondition(ConditionReference condRef, CfgNode trueBranch, CfgNode falseBranch) { + NodeSignature signature = new NodeSignature(condRef, trueBranch, falseBranch); + return nodeCache.computeIfAbsent(signature, key -> new ConditionNode(condRef, trueBranch, falseBranch)); + } + + /** + * Creates a result node representing a terminal rule evaluation. + * + * @param rule the result rule (endpoint or error) + * @return a result node (cached if identical rule already seen) + */ + public CfgNode createResult(Rule rule) { + Rule canonical = rule.withoutConditions(); + Rule interned = resultCache.computeIfAbsent(canonical, k -> k); + return resultNodeCache.computeIfAbsent(interned, ResultNode::new); + } + + /** + * Creates a canonical condition reference, handling negation and deduplication. + */ + public ConditionReference createConditionReference(Condition condition) { + // Check cache first + ConditionReference cached = conditionToReference.get(condition); + if (cached != null) { + return cached; + } + + // Check if it's a negation + boolean negated = false; + Condition canonical = condition; + + if (isNegationWrapper(condition)) { + negated = true; + canonical = unwrapNegation(condition); + + // Check if we already have the non-negated version + ConditionReference existing = conditionToReference.get(canonical); + if (existing != null) { + // Reuse the existing ConditionInfo, just negate the reference + ConditionReference negatedReference = existing.negate(); + conditionToReference.put(condition, negatedReference); + return negatedReference; + } + } + + // Canonicalize for commutative operations + canonical = canonical.canonicalize(); + + // Canonicalize boolean equals + Condition beforeBooleanCanon = canonical; + canonical = canonicalizeBooleanEquals(canonical); + + if (!canonical.equals(beforeBooleanCanon)) { + negated = !negated; + } + + // Get or create the ConditionInfo + ConditionInfo info = conditionToInfo.computeIfAbsent(canonical, ConditionInfo::from); + + // Create the reference (possibly negated) + ConditionReference reference = new ConditionReference(info, negated); + + // Cache the reference under the original key + conditionToReference.put(condition, reference); + + // Also cache under the canonical form if different + if (!negated && !condition.equals(canonical)) { + conditionToReference.put(canonical, reference); + } + + return reference; + } + + private Condition canonicalizeBooleanEquals(Condition condition) { + if (!(condition.getFunction() instanceof BooleanEquals)) { + return condition; + } + + List args = condition.getFunction().getArguments(); + + // After commutative canonicalization, if there's a reference, it's in position 0 + if (args.get(0) instanceof Reference && args.get(1) instanceof Literal) { + Reference ref = (Reference) args.get(0); + Boolean literalValue = ((Literal) args.get(1)).asBooleanLiteral().orElse(null); + + if (literalValue != null && !literalValue && ruleSet != null) { + String varName = ref.getName().toString(); + Optional param = ruleSet.getParameters().get(Identifier.of(varName)); + if (param.isPresent() && param.get().getDefault().isPresent()) { + // Convert booleanEquals(var, false) to booleanEquals(var, true) + return condition.toBuilder().fn(BooleanEquals.ofExpressions(ref, true)).build(); + } + } + } + + return condition; + } + + private static boolean isNegationWrapper(Condition condition) { + if (!(condition.getFunction() instanceof Not)) { + return false; + } else if (condition.getResult().isPresent()) { + return false; + } else { + return condition.getFunction().getArguments().get(0) instanceof LibraryFunction; + } + } + + private static Condition unwrapNegation(Condition negatedCondition) { + return negatedCondition.toBuilder() + .fn((LibraryFunction) negatedCondition.getFunction().getArguments().get(0)) + .build(); + } + + // Signature for node deduplication during construction. + private static final class NodeSignature { + private final ConditionReference condition; + private final CfgNode trueBranch; + private final CfgNode falseBranch; + private final int hashCode; + + NodeSignature(ConditionReference condition, CfgNode trueBranch, CfgNode falseBranch) { + this.condition = condition; + this.trueBranch = trueBranch; + this.falseBranch = falseBranch; + // Use identity hash for branches. + this.hashCode = Objects.hash( + condition, + System.identityHashCode(trueBranch), + System.identityHashCode(falseBranch)); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (!(o instanceof NodeSignature)) { + return false; + } + NodeSignature that = (NodeSignature) o; + // Reference equality for children + return Objects.equals(condition, that.condition) + && trueBranch == that.trueBranch + && falseBranch == that.falseBranch; + } + + @Override + public int hashCode() { + return hashCode; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgNode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgNode.java new file mode 100644 index 00000000000..230a7d4483b --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgNode.java @@ -0,0 +1,13 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +/** + * Abstract base class for all nodes in a Control Flow Graph (CFG). + */ +public abstract class CfgNode { + // Package-private "sealed" class. + CfgNode() {} +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionData.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionData.java new file mode 100644 index 00000000000..4a8ee1e7a3b --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionData.java @@ -0,0 +1,64 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; + +/** + * Extracts and indexes condition data from a CFG. + */ +public final class ConditionData { + private final Condition[] conditions; + private final Map conditionToIndex; + private final Map conditionInfos; + + private ConditionData(Condition[] conditions, Map index, Map infos) { + this.conditions = conditions; + this.conditionToIndex = index; + this.conditionInfos = infos; + } + + /** + * Extracts and indexes all conditions from a CFG. + * + * @param cfg the control flow graph to process + * @return ConditionData containing indexed conditions + */ + public static ConditionData from(Cfg cfg) { + List conditionList = new ArrayList<>(); + Map indexMap = new LinkedHashMap<>(); + Map infoMap = new HashMap<>(); + + for (CfgNode node : cfg) { + if (node instanceof ConditionNode) { + ConditionNode condNode = (ConditionNode) node; + ConditionInfo info = condNode.getCondition(); + Condition condition = info.getCondition(); + + if (!indexMap.containsKey(condition)) { + indexMap.put(condition, conditionList.size()); + conditionList.add(condition); + infoMap.put(condition, info); + } + } + } + + return new ConditionData(conditionList.toArray(new Condition[0]), indexMap, infoMap); + } + + public Condition[] getConditions() { + return conditions; + } + + public Map getConditionInfos() { + return conditionInfos; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionNode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionNode.java new file mode 100644 index 00000000000..c6eb29b5fe0 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionNode.java @@ -0,0 +1,83 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.Objects; +import software.amazon.smithy.rulesengine.logic.ConditionReference; + +/** + * A CFG node that evaluates a condition and branches based on the result. + */ +public final class ConditionNode extends CfgNode { + + private final ConditionReference condition; + private final CfgNode trueBranch; + private final CfgNode falseBranch; + private final int hash; + + /** + * Creates a new condition node. + * + * @param condition condition reference (can be negated) + * @param trueBranch node to evaluate if the condition is true + * @param falseBranch node to evaluate if the condition is false + */ + public ConditionNode(ConditionReference condition, CfgNode trueBranch, CfgNode falseBranch) { + this.condition = Objects.requireNonNull(condition); + this.trueBranch = Objects.requireNonNull(trueBranch, "trueBranch must not be null"); + this.falseBranch = Objects.requireNonNull(falseBranch, "falseBranch must not be null"); + this.hash = Objects.hash(condition, trueBranch, falseBranch); + } + + /** + * Returns the condition reference for this node. + * + * @return the condition reference + */ + public ConditionReference getCondition() { + return condition; + } + + /** + * Returns the node to evaluate if the condition is true. + * + * @return the true branch node + */ + public CfgNode getTrueBranch() { + return trueBranch; + } + + /** + * Returns the node to evaluate if the condition is false. + * + * @return the false branch node + */ + public CfgNode getFalseBranch() { + return falseBranch; + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } + ConditionNode o = (ConditionNode) object; + return condition.equals(o.condition) && trueBranch.equals(o.trueBranch) && falseBranch.equals(o.falseBranch); + } + + @Override + public int hashCode() { + return hash; + } + + @Override + public String toString() { + return "ConditionNode{condition=" + condition + + ", trueBranch=" + System.identityHashCode(trueBranch) + + ", falseBranch=" + System.identityHashCode(falseBranch) + '}'; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ResultNode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ResultNode.java new file mode 100644 index 00000000000..509e9ee6873 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ResultNode.java @@ -0,0 +1,65 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.Objects; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; + +/** + * A terminal CFG node that represents a final result, either an endpoint or error. + */ +public final class ResultNode extends CfgNode { + private final Rule result; + private final int hash; + private static final ResultNode TERMINAL = new ResultNode(); + + public ResultNode(Rule result) { + this.result = result; + this.hash = result == null ? 11 : result.hashCode(); + } + + private ResultNode() { + this(null); + } + + /** + * Returns a terminal node representing no match. + * + * @return the terminal result node + */ + public static ResultNode terminal() { + return TERMINAL; + } + + /** + * Get the underlying result. + * + * @return the result value. + */ + public Rule getResult() { + return result; + } + + @Override + public boolean equals(Object object) { + if (this == object) { + return true; + } else if (object == null || getClass() != object.getClass()) { + return false; + } else { + return Objects.equals(result, (((ResultNode) object)).result); + } + } + + @Override + public int hashCode() { + return hash; + } + + @Override + public String toString() { + return "ResultNode{hash=" + hash + ", result=" + result + '}'; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java new file mode 100644 index 00000000000..08956b65f93 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java @@ -0,0 +1,488 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.IdentityHashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +/** + * Transforms a decision tree into Static Single Assignment (SSA) form. + * + *

This transformation ensures that each variable is assigned exactly once by renaming variables when they are + * reassigned in different parts of the tree. For example, if variable "x" is assigned in multiple branches, they + * become "x", "x_1", "x_2", etc. Without this transform, the BDD compilation would confuse divergent paths that have + * the same variable name. + * + *

Note that this transform is only applied when the reassignment is done using different + * arguments than previously seen assignments of the same variable name. + * + *

TODO: This transform does not yet introduce phi nodes at control flow merge points. + * We need to add an OR function to the rules engine to do that. + */ +final class SsaTransform { + + // Stack of scopes, each mapping original variable names to their current SSA names + private final Deque> scopeStack = new ArrayDeque<>(); + + // Cache of already rewritten conditions to avoid redundant work + private final Map rewrittenConditions = new IdentityHashMap<>(); + + // Cache of already rewritten rules + private final Map rewrittenRules = new IdentityHashMap<>(); + + // Set of input parameter names that should never be rewritten + private final Set inputParams; + + // Map from variable name -> expression -> SSA name + // Pre-computed to ensure consistent naming across the tree + private final Map> variableExpressionMappings; + + private SsaTransform(Set inputParams, Map> variableExpressionMappings) { + // Start with an empty global scope + scopeStack.push(new HashMap<>()); + this.inputParams = inputParams; + this.variableExpressionMappings = variableExpressionMappings; + } + + static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + Set inputParameters = extractInputParameters(ruleSet); + + // Collect all variable bindings and create unique names for each unique expression + Map> variableBindings = collectVariableBindings(ruleSet.getRules()); + Map> variableExpressionMappings = createExpressionMappings(variableBindings); + + // Rewrite with the pre-computed mappings + SsaTransform ssaTransform = new SsaTransform(inputParameters, variableExpressionMappings); + List rewrittenRules = new ArrayList<>(ruleSet.getRules().size()); + for (Rule original : ruleSet.getRules()) { + rewrittenRules.add(ssaTransform.processRule(original)); + } + + return EndpointRuleSet.builder() + .parameters(ruleSet.getParameters()) + .rules(rewrittenRules) + .version(ruleSet.getVersion()) + .build(); + } + + // Collect a set of input parameter names. We use these to know that expressions that only work with an input + // parameter and use the same arguments can be kept as-is rather than need a cloned and renamed assignment. + private static Set extractInputParameters(EndpointRuleSet ruleSet) { + Set inputParameters = new HashSet<>(); + for (Parameter param : ruleSet.getParameters()) { + inputParameters.add(param.getName().toString()); + } + return inputParameters; + } + + private static Map> collectVariableBindings(List rules) { + Map> variableBindings = new HashMap<>(); + collectBindingsFromRules(rules, variableBindings); + return variableBindings; + } + + private static void collectBindingsFromRules(List rules, Map> variableBindings) { + for (Rule rule : rules) { + collectBindingsFromRule(rule, variableBindings); + } + } + + private static void collectBindingsFromRule(Rule rule, Map> variableBindings) { + for (Condition condition : rule.getConditions()) { + if (condition.getResult().isPresent()) { + String varName = condition.getResult().get().toString(); + String expression = condition.getFunction().toString(); + variableBindings.computeIfAbsent(varName, k -> new HashSet<>()).add(expression); + } + } + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + collectBindingsFromRules(treeRule.getRules(), variableBindings); + } + } + + /** + * Creates a mapping from variable name -> expression -> SSA name. + * Variables assigned multiple times get unique SSA names (x, x_1, x_2, etc). + */ + private static Map> createExpressionMappings(Map> bindings) { + Map> result = new HashMap<>(); + for (Map.Entry> entry : bindings.entrySet()) { + String varName = entry.getKey(); + Set expressions = entry.getValue(); + result.put(varName, createMappingForVariable(varName, expressions)); + } + + return result; + } + + private static Map createMappingForVariable(String varName, Set expressions) { + Map mapping = new HashMap<>(); + + if (expressions.size() == 1) { + // Only one expression for this variable, so no SSA renaming needed + String expression = expressions.iterator().next(); + mapping.put(expression, varName); + } else { + // Multiple expressions, so create unique SSA names + List sortedExpressions = new ArrayList<>(expressions); + sortedExpressions.sort(String::compareTo); // Ensure deterministic ordering + + for (int i = 0; i < sortedExpressions.size(); i++) { + String expression = sortedExpressions.get(i); + String uniqueName = (i == 0) ? varName : varName + "_" + i; + mapping.put(expression, uniqueName); + } + } + + return mapping; + } + + private Rule processRule(Rule rule) { + enterScope(); + Rule rewrittenRule = rewriteRule(rule); + exitScope(); + return rewrittenRule; + } + + // Enters a new scope, inheriting all variable mappings from the parent scope + private void enterScope() { + scopeStack.push(new HashMap<>(scopeStack.peek())); + } + + // Exits the current scope, reverting to the parent scope's variable mappings + private void exitScope() { + if (scopeStack.size() <= 1) { + throw new IllegalStateException("Cannot exit global scope"); + } + scopeStack.pop(); + } + + // Rewrites a condition's bindings and references to use SSA names + private Condition rewriteCondition(Condition condition) { + boolean hasBinding = condition.getResult().isPresent(); + + // Check cache for non-binding conditions + if (!hasBinding) { + Condition cached = rewrittenConditions.get(condition); + if (cached != null) { + return cached; + } + } + + LibraryFunction fn = condition.getFunction(); + Set rewritableRefs = filterOutInputParameters(fn.getReferences()); + + // Determine if this binding needs an SSA name + String uniqueBindingName = null; + boolean needsUniqueBinding = false; + if (hasBinding) { + String varName = condition.getResult().get().toString(); + Map expressionMap = variableExpressionMappings.get(varName); + if (expressionMap != null) { + uniqueBindingName = expressionMap.get(fn.toString()); + needsUniqueBinding = uniqueBindingName != null && !uniqueBindingName.equals(varName); + } + } + + // Early return if no rewriting needed + if (!needsRewriting(rewritableRefs) && !needsUniqueBinding) { + if (!hasBinding) { + rewrittenConditions.put(condition, condition); + } + return condition; + } + + // Rewrite the expression + LibraryFunction rewrittenExpr = (LibraryFunction) rewriteExpression(fn); + boolean exprChanged = rewrittenExpr != fn; + + // Build the rewritten condition + Condition rewritten; + if (hasBinding && uniqueBindingName != null) { + // Update scope with the SSA name + scopeStack.peek().put(condition.getResult().get().toString(), uniqueBindingName); + + if (needsUniqueBinding || exprChanged) { + rewritten = condition.toBuilder().fn(rewrittenExpr).result(Identifier.of(uniqueBindingName)).build(); + } else { + rewritten = condition; + } + } else if (exprChanged) { + rewritten = condition.toBuilder().fn(rewrittenExpr).build(); + } else { + rewritten = condition; + } + + // Cache non-binding conditions + if (!hasBinding) { + rewrittenConditions.put(condition, rewritten); + } + + return rewritten; + } + + private Set filterOutInputParameters(Set references) { + if (references.isEmpty() || inputParams.isEmpty()) { + return references; + } + Set filtered = new HashSet<>(references); + filtered.removeAll(inputParams); + return filtered; + } + + // Check if any references in scope need to be rewritten to SSA names + private boolean needsRewriting(Set references) { + if (references.isEmpty()) { + return false; + } + + Map currentScope = scopeStack.peek(); + for (String ref : references) { + String mapped = currentScope.get(ref); + if (mapped != null && !mapped.equals(ref)) { + return true; + } + } + return false; + } + + private boolean needsRewriting(Expression expression) { + return needsRewriting(filterOutInputParameters(expression.getReferences())); + } + + // Rewrites a rule's conditions to use SSA names + private Rule rewriteRule(Rule rule) { + Rule cached = rewrittenRules.get(rule); + if (cached != null) { + return cached; + } + + List rewrittenConditions = rewriteConditions(rule.getConditions()); + boolean conditionsChanged = !rewrittenConditions.equals(rule.getConditions()); + + Rule result; + if (rule instanceof EndpointRule) { + result = rewriteEndpointRule((EndpointRule) rule, rewrittenConditions, conditionsChanged); + } else if (rule instanceof ErrorRule) { + result = rewriteErrorRule((ErrorRule) rule, rewrittenConditions, conditionsChanged); + } else if (rule instanceof TreeRule) { + result = rewriteTreeRule((TreeRule) rule, rewrittenConditions, conditionsChanged); + } else if (conditionsChanged) { + throw new UnsupportedOperationException("Cannot change rule: " + rule); + } else { + result = rule; + } + + rewrittenRules.put(rule, result); + return result; + } + + private List rewriteConditions(List conditions) { + List rewritten = new ArrayList<>(conditions.size()); + for (Condition condition : conditions) { + rewritten.add(rewriteCondition(condition)); + } + return rewritten; + } + + private Rule rewriteEndpointRule( + EndpointRule rule, + List rewrittenConditions, + boolean conditionsChanged + ) { + Endpoint endpoint = rule.getEndpoint(); + + // Rewrite endpoint components to use SSA names + Expression rewrittenUrl = rewriteExpression(endpoint.getUrl()); + Map> rewrittenHeaders = rewriteHeaders(endpoint.getHeaders()); + Map rewrittenProperties = rewriteProperties(endpoint.getProperties()); + + boolean endpointChanged = rewrittenUrl != endpoint.getUrl() + || !rewrittenHeaders.equals(endpoint.getHeaders()) + || !rewrittenProperties.equals(endpoint.getProperties()); + + if (conditionsChanged || endpointChanged) { + return EndpointRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(rewrittenConditions) + .endpoint(Endpoint.builder() + .url(rewrittenUrl) + .headers(rewrittenHeaders) + .properties(rewrittenProperties) + .build()); + } + + return rule; + } + + private Rule rewriteErrorRule(ErrorRule rule, List rewrittenConditions, boolean conditionsChanged) { + Expression rewrittenError = rewriteExpression(rule.getError()); + + if (conditionsChanged || rewrittenError != rule.getError()) { + return ErrorRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(rewrittenConditions) + .error(rewrittenError); + } + + return rule; + } + + private Rule rewriteTreeRule(TreeRule rule, List rewrittenConditions, boolean conditionsChanged) { + List rewrittenNestedRules = new ArrayList<>(); + boolean nestedChanged = false; + + for (Rule nestedRule : rule.getRules()) { + enterScope(); + Rule rewritten = rewriteRule(nestedRule); + rewrittenNestedRules.add(rewritten); + if (rewritten != nestedRule) { + nestedChanged = true; + } + exitScope(); + } + + if (conditionsChanged || nestedChanged) { + return TreeRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(rewrittenConditions) + .treeRule(rewrittenNestedRules); + } + + return rule; + } + + private Map> rewriteHeaders(Map> headers) { + Map> rewritten = new LinkedHashMap<>(); + for (Map.Entry> entry : headers.entrySet()) { + List rewrittenValues = new ArrayList<>(); + for (Expression expr : entry.getValue()) { + rewrittenValues.add(rewriteExpression(expr)); + } + rewritten.put(entry.getKey(), rewrittenValues); + } + return rewritten; + } + + private Map rewriteProperties(Map properties) { + Map rewritten = new LinkedHashMap<>(); + for (Map.Entry entry : properties.entrySet()) { + Expression rewrittenExpr = rewriteExpression(entry.getValue()); + if (!(rewrittenExpr instanceof Literal)) { + throw new IllegalStateException("Property value must be a literal"); + } + rewritten.put(entry.getKey(), (Literal) rewrittenExpr); + } + return rewritten; + } + + // Recursively rewrites an expression to use SSA names + private Expression rewriteExpression(Expression expression) { + if (!needsRewriting(expression)) { + return expression; + } + + if (expression instanceof StringLiteral) { + return rewriteStringLiteral((StringLiteral) expression); + } else if (expression instanceof TupleLiteral) { + return rewriteTupleLiteral((TupleLiteral) expression); + } else if (expression instanceof RecordLiteral) { + return rewriteRecordLiteral((RecordLiteral) expression); + } else if (expression instanceof Reference) { + return rewriteReference((Reference) expression); + } else if (expression instanceof LibraryFunction) { + return rewriteLibraryFunction((LibraryFunction) expression); + } + + return expression; + } + + private Expression rewriteStringLiteral(StringLiteral str) { + Template template = str.value(); + if (template.isStatic()) { + return str; + } + + StringBuilder templateBuilder = new StringBuilder(); + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + Template.Dynamic dynamic = (Template.Dynamic) part; + Expression rewritten = rewriteExpression(dynamic.toExpression()); + templateBuilder.append('{').append(rewritten).append('}'); + } else { + templateBuilder.append(((Template.Literal) part).getValue()); + } + } + return Literal.stringLiteral(Template.fromString(templateBuilder.toString())); + } + + private Expression rewriteTupleLiteral(TupleLiteral tuple) { + List rewrittenMembers = new ArrayList<>(); + for (Literal member : tuple.members()) { + rewrittenMembers.add((Literal) rewriteExpression(member)); + } + return Literal.tupleLiteral(rewrittenMembers); + } + + private Expression rewriteRecordLiteral(RecordLiteral record) { + Map rewrittenMembers = new LinkedHashMap<>(); + for (Map.Entry entry : record.members().entrySet()) { + rewrittenMembers.put(entry.getKey(), (Literal) rewriteExpression(entry.getValue())); + } + return Literal.recordLiteral(rewrittenMembers); + } + + private Expression rewriteReference(Reference ref) { + String originalName = ref.getName().toString(); + String uniqueName = resolveReference(originalName); + return Expression.getReference(Identifier.of(uniqueName)); + } + + private Expression rewriteLibraryFunction(LibraryFunction fn) { + List rewrittenArgs = new ArrayList<>(fn.getArguments()); + rewrittenArgs.replaceAll(this::rewriteExpression); + FunctionNode node = FunctionNode.builder() + .name(Node.from(fn.getName())) + .arguments(rewrittenArgs) + .build(); + return fn.getFunctionDefinition().createFunction(node); + } + + private String resolveReference(String originalName) { + // Input parameters are never rewritten + return inputParams.contains(originalName) + ? originalName + : scopeStack.peek().getOrDefault(originalName, originalName); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/BddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/BddTrait.java new file mode 100644 index 00000000000..0386860e040 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/BddTrait.java @@ -0,0 +1,77 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.traits; + +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.traits.AbstractTrait; +import software.amazon.smithy.model.traits.AbstractTraitBuilder; +import software.amazon.smithy.model.traits.Trait; +import software.amazon.smithy.rulesengine.logic.bdd.Bdd; +import software.amazon.smithy.utils.SmithyBuilder; +import software.amazon.smithy.utils.SmithyUnstableApi; +import software.amazon.smithy.utils.ToSmithyBuilder; + +/*** + * Defines an endpoint rule-set using a binary decision diagram (BDD) used to resolve the client's transport endpoint. + */ +@SmithyUnstableApi +public final class BddTrait extends AbstractTrait implements ToSmithyBuilder { + public static final ShapeId ID = ShapeId.from("smithy.rules#bdd"); + + private final Bdd bdd; + + private BddTrait(Builder builder) { + super(ID, builder.getSourceLocation()); + bdd = SmithyBuilder.requiredState("bdd", builder.bdd); + } + + public static Builder builder() { + return new Builder(); + } + + public Bdd getBdd() { + return bdd; + } + + @Override + protected Node createNode() { + return bdd.toNode(); + } + + @Override + public Builder toBuilder() { + return builder().sourceLocation(getSourceLocation()).bdd(bdd); + } + + public static final class Provider extends AbstractTrait.Provider { + public Provider() { + super(ID); + } + + @Override + public Trait createTrait(ShapeId target, Node value) { + BddTrait trait = builder().sourceLocation(value).bdd(Bdd.fromNode(value)).build(); + trait.setNodeCache(value); + return trait; + } + } + + public static final class Builder extends AbstractTraitBuilder { + private Bdd bdd; + + private Builder() {} + + public Builder bdd(Bdd bdd) { + this.bdd = bdd; + return this; + } + + @Override + public BddTrait build() { + return new BddTrait(this); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java new file mode 100644 index 00000000000..e8dc6fe0ada --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java @@ -0,0 +1,112 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.validators; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.validation.AbstractValidator; +import software.amazon.smithy.model.validation.ValidationEvent; +import software.amazon.smithy.rulesengine.logic.bdd.Bdd; +import software.amazon.smithy.rulesengine.traits.BddTrait; + +public final class BddTraitValidator extends AbstractValidator { + @Override + public List validate(Model model) { + if (!model.isTraitApplied(BddTrait.class)) { + return Collections.emptyList(); + } + + List events = new ArrayList<>(); + for (ServiceShape service : model.getServiceShapesWithTrait(BddTrait.class)) { + validateService(events, service, service.expectTrait(BddTrait.class)); + } + + return events; + } + + private void validateService(List events, ServiceShape service, BddTrait trait) { + Bdd bdd = trait.getBdd(); + + // Validate root reference + int rootRef = bdd.getRootRef(); + if (Bdd.isComplemented(rootRef)) { + events.add(error(service, trait, "Root reference cannot be complemented: " + rootRef)); + } + validateReference(events, service, trait, "Root", rootRef, bdd); + + // Validate node references + int[][] nodes = bdd.getNodes(); + for (int i = 0; i < nodes.length; i++) { + // Skip terminal node at index 0 + if (i == 0) { + continue; + } + + // Guard against malformed nodes array + if (nodes[i] == null || nodes[i].length != 3) { + events.add(error(service, trait, String.format("Node %d is malformed", i))); + continue; + } + + int[] node = nodes[i]; + int varIdx = node[0]; + int highRef = node[1]; + int lowRef = node[2]; + + if (varIdx < 0 || varIdx >= bdd.getConditionCount()) { + events.add(error(service, + trait, + String.format( + "Node %d has invalid variable index %d (condition count: %d)", + i, + varIdx, + bdd.getConditionCount()))); + } + + validateReference(events, service, trait, String.format("Node %d high", i), highRef, bdd); + validateReference(events, service, trait, String.format("Node %d low", i), lowRef, bdd); + } + } + + private void validateReference( + List events, + ServiceShape service, + BddTrait trait, + String context, + int ref, + Bdd bdd + ) { + if (ref == 0) { + events.add(error(service, trait, String.format("%s has invalid reference: 0", context))); + } else if (Bdd.isNodeReference(ref)) { + int nodeIndex = Math.abs(ref) - 1; + if (nodeIndex >= bdd.getNodes().length) { + events.add(error(service, + trait, + String.format( + "%s reference %d points to non-existent node %d (node count: %d)", + context, + ref, + nodeIndex, + bdd.getNodes().length))); + } + } else if (Bdd.isResultReference(ref)) { + int resultIndex = ref - Bdd.RESULT_OFFSET; + if (resultIndex >= bdd.getResults().size()) { + events.add(error(service, + trait, + String.format( + "%s reference %d points to non-existent result %d (result count: %d)", + context, + ref, + resultIndex, + bdd.getResults().size()))); + } + } + } +} diff --git a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService index 353f3b95e8e..02900906528 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService +++ b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService @@ -4,3 +4,4 @@ software.amazon.smithy.rulesengine.traits.StaticContextParamsTrait$Provider software.amazon.smithy.rulesengine.traits.OperationContextParamsTrait$Provider software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait$Provider software.amazon.smithy.rulesengine.traits.EndpointTestsTrait$Provider +software.amazon.smithy.rulesengine.traits.BddTrait$Provider diff --git a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator index 2f24c94156b..fa5f972e343 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator +++ b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator @@ -7,3 +7,4 @@ software.amazon.smithy.rulesengine.validators.RuleSetUriValidator software.amazon.smithy.rulesengine.validators.RuleSetParamMissingDocsValidator software.amazon.smithy.rulesengine.validators.RuleSetParameterValidator software.amazon.smithy.rulesengine.validators.RuleSetTestCaseValidator +software.amazon.smithy.rulesengine.validators.BddTraitValidator diff --git a/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy b/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy index c475281bc3f..5926d473029 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy +++ b/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy @@ -7,9 +7,203 @@ namespace smithy.rules @trait(selector: "service") document endpointRuleSet +/// Defines an endpoint rule-set using a binary decision diagram (BDD). +@unstable +@trait(selector: "service") +structure bdd { + /// A map of zero or more endpoint parameter names to their parameter configuration. + @required + parameters: Parameters + + /// An ordered list of unique conditions used throughout the BDD. + @required + conditions: Conditions + + /// An ordered list of results referenced by BDD nodes. The first result is always the terminal node. + @required + results: Results + + /// The root node of where to start evaluating the BDD. + @required + @range(min: -1) + root: Integer + + /// The number of nodes contained in the BDD. + @required + @range(min: 0) + nodeCount: Integer + + /// Base64-encoded array of BDD nodes representing the decision graph structure. + /// + /// The first node (index 0) is always the terminal node `[-1, 1, -1]` and is included in the nodeCount. + /// User-defined nodes start at index 1. + /// + /// Zig-zag encoding transforms signed integers to unsigned: + /// - 0 -> 0, -1 → 1, 1 → 2, -2 → 3, 2 → 4, etc. + /// - Formula: `(n << 1) ^ (n >> 31)` + /// - This ensures small negative numbers use few bytes + /// + /// Each node consists of three varint-encoded integers written sequentially: + /// 1. variable index + /// 2. high reference (when condition is true) + /// 3. low reference (when condition is false) + /// + /// Node Structure [variable, high, low]: + /// - variable: The index of the condition being tested (0 to conditionCount-1) + /// - high: Reference to follow when the condition evaluates to true + /// - low: Reference to follow when the condition evaluates to false + /// + /// Reference Encoding: + /// - 0: Invalid/unused reference (never appears in valid BDDs) + /// - 1: TRUE terminal (treated as "no match" in endpoint resolution) + /// - -1: FALSE terminal (treated as "no match" in endpoint resolution) + /// - 2, 3, 4, ...: Node references pointing to nodes[ref-1] + /// - -2, -3, -4, ...: Complement node references (logical NOT of nodes[abs(ref)-1]) + /// - 2000000+: Result terminals (2000000 + resultIndex) + /// + /// Complement edges: + /// A negative reference represents the logical NOT of the referenced node's entire subgraph. So `-5` means the + /// complement of node 5 (located in the array at index 4, since `index = |ref| - 1`). In this case, evaluate the + /// condition referenced by node 4, and if it is TRUE, use the low reference, and if it's FALSE, use the high + /// reference. This optimization significantly reduces BDD size by allowing a single subgraph to represent both a + /// boolean function and its complement; instead of creating separate nodes for `condition AND other` and + /// `NOT(condition AND other)`, we can reuse the same nodes with complement edges. Complement edges cannot be + /// used on result terminals. + /// + /// Example (before encoding): + /// ``` + /// nodes = [ + /// [ -1, 1, -1], // 0: terminal node + /// [ 0, 3, 2], // 1: if condition[0] then node 3, else node 2 + /// [ 1, 2000001, -1], // 2: if condition[1] then result[1], else FALSE + /// ] + /// ``` + /// + /// After zig-zag + varint + base64: `"AQEBAAYEBAGBwOgPAQ=="` + @required + nodes: String +} + +@private +map Parameters { + key: String + value: Parameter +} + +/// A rules input parameter. +@private +structure Parameter { + /// The parameter type. + @required + type: ParameterType + + /// True if the parameter is deprecated. + deprecated: Boolean + + /// Documentation about the parameter. + documentation: String + + /// Specifies the default value for the parameter if not set. + /// Parameters with defaults MUST also be marked as required. The type of the provided default MUST match type. + default: Document + + /// Specifies a named built-in value that is sourced and provided to the endpoint provider by a caller. + builtIn: String + + /// Specifies that the parameter is required to be provided to the endpoint provider. + required: Boolean +} + +/// The kind of parameter. +enum ParameterType { + STRING = "string" + BOOLEAN = "boolean" + STRING_ARRAY = "stringArray" +} + +@private +list Conditions { + member: Condition +} + +@private +structure Condition { + /// The name of the function to be executed. + @required + fn: String + + /// The arguments for the function. + /// An array of one or more of the following types: string, bool, array, Reference object, or Function object + @required + argv: DocumentList + + /// The optional destination variable to assign the functions result to. + assign: String +} + +@private +list DocumentList { + member: Document +} + +@private +list Results { + member: Result +} + +@private +structure Result { + /// Result type. + @required + type: ResultType + + /// An optional description of the result. + documentation: String + + /// Provided if type is "error". + error: Document + + /// Provided if type is "endpoint". + endpoint: EndpointObject +} + +@private +enum ResultType { + ENDPOINT = "endpoint" + ERROR = "error" +} + +@private +structure EndpointObject { + /// The endpoint url. This MUST specify a scheme and hostname and MAY contain port and base path components. + /// A string value MAY be a Template string. Any value for this property MUST resolve to a string. + @required + url: Document + + /// A map containing zero or more key value property pairs. Endpoint properties MAY be arbitrarily deep and + /// contain other maps and arrays. + properties: EndpointProperties + + /// A map of transport header names to their respective values. A string value in an array MAY be a + /// template string. + headers: EndpointObjectHeaders +} + +@private +map EndpointProperties { + key: String + value: Document +} + +@private +map EndpointObjectHeaders { + key: String + value: DocumentList +} + /// Defines endpoint test-cases for validating a client's endpoint rule-set. @unstable -@trait(selector: "service[trait|smithy.rules#endpointRuleSet]") +@trait(selector: "service :is([trait|smithy.rules#endpointRuleSet], [trait|smithy.rules#bdd])") structure endpointTests { /// The endpoint tests schema version. @required diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/value/ToObjectTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/value/ToObjectTest.java new file mode 100644 index 00000000000..6bec1847317 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/value/ToObjectTest.java @@ -0,0 +1,87 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.value; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.utils.ListUtils; +import software.amazon.smithy.utils.MapUtils; + +class ToObjectTest { + @Test + void testStringValueToObject() { + Value value = Value.stringValue("hello"); + + assertEquals("hello", value.toObject()); + } + + @Test + void testIntegerValueToObject() { + Value value = Value.integerValue(42); + + assertEquals(42, value.toObject()); + } + + @Test + void testBooleanValueToObject() { + assertEquals(Boolean.TRUE, Value.booleanValue(true).toObject()); + assertEquals(Boolean.FALSE, Value.booleanValue(false).toObject()); + } + + @Test + void testEmptyValueToObject() { + Value value = Value.emptyValue(); + + assertNull(value.toObject()); + } + + @Test + void testArrayValueToObject() { + Value arrayValue = Value.arrayValue(ListUtils.of( + Value.stringValue("a"), + Value.integerValue(1), + Value.booleanValue(true))); + + Object result = arrayValue.toObject(); + assertInstanceOf(List.class, result); + + List list = (List) result; + assertEquals(3, list.size()); + assertEquals("a", list.get(0)); + assertEquals(1, list.get(1)); + assertEquals(true, list.get(2)); + } + + @Test + void testRecordValueToObject() { + Map map = new LinkedHashMap<>(); + map.put(Identifier.of("name"), Value.stringValue("test")); + map.put(Identifier.of("count"), Value.integerValue(5)); + map.put(Identifier.of("enabled"), Value.booleanValue(true)); + + Value recordValue = Value.recordValue(map); + Object result = recordValue.toObject(); + + assertInstanceOf(Map.class, result); + Map resultMap = (Map) result; + assertEquals("test", resultMap.get("name")); + assertEquals(5, resultMap.get("count")); + assertEquals(true, resultMap.get("enabled")); + } + + @Test + void testEmptyCollections() { + assertEquals(ListUtils.of(), Value.arrayValue(ListUtils.of()).toObject()); + assertEquals(MapUtils.of(), Value.recordValue(MapUtils.of()).toObject()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImplTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImplTest.java new file mode 100644 index 00000000000..95c94170d0f --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImplTest.java @@ -0,0 +1,126 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.BooleanLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +public class ConditionInfoImplTest { + + @Test + void testSimpleIsSetCondition() { + Condition condition = Condition.builder() + .fn(IsSet.ofExpressions(Literal.of("{Region}"))) + .build(); + + ConditionInfo info = ConditionInfo.from(condition); + + assertEquals(condition, info.getCondition()); + + assertEquals(4, info.getComplexity()); + assertEquals(1, info.getReferences().size()); + assertTrue(info.getReferences().contains("Region")); + assertNull(info.getReturnVariable()); + } + + @Test + void testConditionWithVariableBinding() { + Condition condition = Condition.builder() + .fn(IsSet.ofExpressions(Literal.of("{Region}"))) + .result(Identifier.of("RegionExists")) + .build(); + + ConditionInfo info = ConditionInfo.from(condition); + + assertEquals("RegionExists", info.getReturnVariable()); + } + + @Test + void testComplexNestedCondition() { + // Test nested function calls + Condition condition = Condition.builder() + .fn(Not.ofExpressions( + BooleanEquals.ofExpressions( + IsSet.ofExpressions(Literal.of("{Region}")), + BooleanLiteral.of(true)))) + .build(); + + ConditionInfo info = ConditionInfo.from(condition); + + assertEquals(11, info.getComplexity()); + assertEquals(1, info.getReferences().size()); + assertTrue(info.getReferences().contains("Region")); + } + + @Test + void testTemplateStringComplexity() { + Condition condition = Condition.builder() + .fn(StringEquals.ofExpressions( + Literal.of("{Endpoint}"), + StringLiteral.of("https://{Service}.{Region}.amazonaws.com"))) + .build(); + + ConditionInfo info = ConditionInfo.from(condition); + + assertTrue(info.getComplexity() > StringEquals.getDefinition().getCostHeuristic()); + assertEquals(3, info.getReferences().size()); + assertTrue(info.getReferences().contains("Endpoint")); + assertTrue(info.getReferences().contains("Service")); + assertTrue(info.getReferences().contains("Region")); + } + + @Test + void testGetAttrNestedComplexity() { + Condition condition = Condition.builder() + .fn(GetAttr.ofExpressions(GetAttr.ofExpressions(Literal.of("{ComplexObject}"), "nested"), "value")) + .build(); + + ConditionInfo info = ConditionInfo.from(condition); + + assertEquals(8, info.getComplexity()); + assertEquals(1, info.getReferences().size()); + assertTrue(info.getReferences().contains("ComplexObject")); + } + + @Test + void testEquals() { + Condition condition1 = Condition.builder().fn(IsSet.ofExpressions(Literal.of("{Region}"))).build(); + Condition condition2 = Condition.builder().fn(IsSet.ofExpressions(Literal.of("{Region}"))).build(); + + ConditionInfo info1 = ConditionInfo.from(condition1); + ConditionInfo info2 = ConditionInfo.from(condition2); + + assertEquals(info1, info2); + assertEquals(info1.hashCode(), info2.hashCode()); + } + + @Test + void testToString() { + Condition condition = Condition.builder() + .fn(IsSet.ofExpressions(Literal.of("{Region}"))) + .result(Identifier.of("RegionExists")) + .build(); + + ConditionInfo info = ConditionInfo.from(condition); + + String str = info.toString(); + assertTrue(str.contains("isSet")); + assertTrue(str.contains("Region")); + assertTrue(str.contains("RegionExists")); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionReferenceTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionReferenceTest.java new file mode 100644 index 00000000000..1430c15b599 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionReferenceTest.java @@ -0,0 +1,138 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +public class ConditionReferenceTest { + + private ConditionInfo baseConditionInfo; + private Condition simpleCondition; + + @BeforeEach + void setUp() { + simpleCondition = Condition.builder().fn(IsSet.ofExpressions(Literal.of("{Region}"))).build(); + baseConditionInfo = ConditionInfo.from(simpleCondition); + } + + @Test + void testBasicConstruction() { + ConditionReference ref = new ConditionReference(baseConditionInfo, false); + + assertFalse(ref.isNegated()); + assertEquals(simpleCondition, ref.getCondition()); + } + + @Test + void testNegatedConstruction() { + ConditionReference ref = new ConditionReference(baseConditionInfo, true); + + assertTrue(ref.isNegated()); + assertEquals(simpleCondition, ref.getCondition()); + } + + @Test + void testNegateMethod() { + ConditionReference ref = new ConditionReference(baseConditionInfo, false); + ConditionReference negated = ref.negate(); + + assertFalse(ref.isNegated()); + assertTrue(negated.isNegated()); + assertEquals(ref.getCondition(), negated.getCondition()); + } + + @Test + void testDoubleNegation() { + ConditionReference ref = new ConditionReference(baseConditionInfo, false); + ConditionReference doubleNegated = ref.negate().negate(); + + assertFalse(doubleNegated.isNegated()); + assertEquals(ref.getCondition(), doubleNegated.getCondition()); + } + + @Test + void testGetReturnVariable() { + Condition condWithVariable = Condition.builder() + .fn(IsSet.ofExpressions(Literal.of("{Region}"))) + .result(Identifier.of("RegionSet")) + .build(); + + ConditionInfo info = ConditionInfo.from(condWithVariable); + ConditionReference ref = new ConditionReference(info, false); + + assertEquals("RegionSet", ref.getReturnVariable()); + } + + @Test + void testEquals() { + ConditionReference ref1 = new ConditionReference(baseConditionInfo, false); + ConditionReference ref2 = new ConditionReference(baseConditionInfo, false); + + assertEquals(ref1, ref2); + } + + @Test + void testNotEqualsWithDifferentNegation() { + ConditionReference ref1 = new ConditionReference(baseConditionInfo, false); + ConditionReference ref2 = new ConditionReference(baseConditionInfo, true); + + assertNotEquals(ref1, ref2); + } + + @Test + void testNotEqualsWithDifferentCondition() { + Condition otherCondition = Condition.builder().fn(IsSet.ofExpressions(Literal.of("{Bucket}"))).build(); + ConditionInfo otherInfo = ConditionInfo.from(otherCondition); + ConditionReference ref1 = new ConditionReference(baseConditionInfo, false); + ConditionReference ref2 = new ConditionReference(otherInfo, false); + + assertNotEquals(ref1, ref2); + } + + @Test + void testHashCode() { + ConditionReference ref1 = new ConditionReference(baseConditionInfo, false); + ConditionReference ref2 = new ConditionReference(baseConditionInfo, false); + + assertEquals(ref1.hashCode(), ref2.hashCode()); + } + + @Test + void testHashCodeDifferentForNegated() { + ConditionReference ref1 = new ConditionReference(baseConditionInfo, false); + ConditionReference ref2 = new ConditionReference(baseConditionInfo, true); + + // Hash codes should be different for negated vs non-negated + assertNotEquals(ref1.hashCode(), ref2.hashCode()); + } + + @Test + void testToString() { + ConditionReference ref = new ConditionReference(baseConditionInfo, false); + String str = ref.toString(); + + assertFalse(str.startsWith("!")); + assertTrue(str.contains("isSet")); + } + + @Test + void testToStringNegated() { + ConditionReference ref = new ConditionReference(baseConditionInfo, true); + String str = ref.toString(); + + assertTrue(str.startsWith("!")); + assertTrue(str.contains("isSet")); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluatorTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluatorTest.java new file mode 100644 index 00000000000..34fed409058 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/RuleBasedConditionEvaluatorTest.java @@ -0,0 +1,72 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.evaluation.RuleEvaluator; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +class RuleBasedConditionEvaluatorTest { + + @Test + void testEvaluatesConditions() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("param1")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("param2")).build(); + Condition[] conditions = {cond1, cond2}; + + // Create a mock evaluator that returns true for first condition, false for second + RuleEvaluator mockEvaluator = new RuleEvaluator() { + @Override + public Value evaluateCondition(Condition condition) { + if (condition == cond1) { + return Value.booleanValue(true); + } else { + return Value.booleanValue(false); + } + } + }; + + RuleBasedConditionEvaluator evaluator = new RuleBasedConditionEvaluator(mockEvaluator, conditions); + + assertTrue(evaluator.test(0)); + assertFalse(evaluator.test(1)); + } + + @Test + void testHandlesEmptyValue() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("param")).build(); + Condition[] conditions = {cond}; + + RuleEvaluator mockEvaluator = new RuleEvaluator() { + @Override + public Value evaluateCondition(Condition condition) { + return Value.emptyValue(); + } + }; + + RuleBasedConditionEvaluator evaluator = new RuleBasedConditionEvaluator(mockEvaluator, conditions); + assertFalse(evaluator.test(0)); + } + + @Test + void testHandlesNonBooleanTruthyValue() { + Condition cond = Condition.builder().fn(TestHelpers.parseUrl("https://example.com")).build(); + Condition[] conditions = {cond}; + + RuleEvaluator mockEvaluator = new RuleEvaluator() { + @Override + public Value evaluateCondition(Condition condition) { + return Value.stringValue("some-string"); + } + }; + + RuleBasedConditionEvaluator evaluator = new RuleBasedConditionEvaluator(mockEvaluator, conditions); + assertTrue(evaluator.test(0)); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java new file mode 100644 index 00000000000..093fa7e49a8 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java @@ -0,0 +1,46 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic; + +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.ParseUrl; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; + +public final class TestHelpers { + + private TestHelpers() {} + + public static LibraryFunction isSet(String paramName) { + return IsSet.ofExpressions(Expression.getReference(Identifier.of(paramName))); + } + + public static LibraryFunction stringEquals(String paramName, String value) { + return StringEquals.ofExpressions( + Expression.getReference(Identifier.of(paramName)), + StringLiteral.of(value)); + } + + public static LibraryFunction booleanEquals(String paramName, boolean value) { + return BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of(paramName)), + Literal.booleanLiteral(value)); + } + + public static LibraryFunction parseUrl(String urlTemplate) { + return ParseUrl.ofExpressions(Literal.stringLiteral(Template.fromString(urlTemplate))); + } + + public static Endpoint endpoint(String url) { + return Endpoint.builder().url(Expression.of(url)).build(); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java new file mode 100644 index 00000000000..c94a9ec05a8 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java @@ -0,0 +1,533 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class BddBuilderTest { + + private BddBuilder builder; + + @BeforeEach + void setUp() { + builder = new BddBuilder(); + } + + @Test + void testTerminals() { + assertEquals(1, builder.makeTrue()); + assertEquals(-1, builder.makeFalse()); + + // Terminals are constants + assertEquals(1, builder.makeTrue()); + assertEquals(-1, builder.makeFalse()); + } + + @Test + void testNodeReduction() { + builder.setConditionCount(2); + + // Node with identical branches should be reduced + int reduced = builder.makeNode(0, builder.makeTrue(), builder.makeTrue()); + assertEquals(1, reduced); // Should return TRUE directly + + // Verify no new node was created + assertEquals(1, builder.getNodes().size()); + } + + @Test + void testComplementCanonicalization() { + builder.setConditionCount(2); + + // Create node with complement on low branch + int node1 = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + + // Create equivalent node with flipped branches and complement + // This should canonicalize to the same node but with complement + int node2 = builder.makeNode(0, builder.makeFalse(), builder.makeTrue()); + + assertEquals(-node1, node2); // Should be complement of first node + + // Only one actual node should be created + assertEquals(2, builder.getNodes().size()); // terminal + 1 node + } + + @Test + void testNodeDeduplication() { + builder.setConditionCount(2); + + // Create same node twice + int node1 = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int node2 = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + + assertEquals(node1, node2); // Should return same reference + assertEquals(2, builder.getNodes().size()); // No duplicate created + } + + @Test + void testResultNodes() { + builder.setConditionCount(2); + + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + + // Result nodes should have distinct references + assertTrue(result0 != result1); + assertTrue(builder.isResult(result0)); + assertTrue(builder.isResult(result1)); + assertFalse(builder.isResult(builder.makeTrue())); + + // Result refs should be encoded with RESULT_OFFSET + assertEquals(Bdd.RESULT_OFFSET, result0); + assertEquals(Bdd.RESULT_OFFSET + 1, result1); + } + + @Test + void testNegation() { + builder.setConditionCount(2); + + int node = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int negated = builder.negate(node); + + assertEquals(-node, negated); + assertEquals(node, builder.negate(negated)); // Double negation + + // Cannot negate terminals + assertEquals(-1, builder.negate(builder.makeTrue())); + assertEquals(1, builder.negate(builder.makeFalse())); + } + + @Test + void testNegateResult() { + builder.setConditionCount(2); + int result = builder.makeResult(0); + + assertThrows(IllegalArgumentException.class, () -> builder.negate(result)); + } + + @Test + void testAndOperation() { + builder.setConditionCount(2); + + // TRUE AND TRUE = TRUE + assertEquals(1, builder.and(builder.makeTrue(), builder.makeTrue())); + + // TRUE AND FALSE = FALSE + assertEquals(-1, builder.and(builder.makeTrue(), builder.makeFalse())); + + // FALSE AND x = FALSE + assertEquals(-1, builder.and(builder.makeFalse(), builder.makeTrue())); + assertEquals(-1, builder.and(builder.makeFalse(), builder.makeFalse())); + + // x AND x = x + int node = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + assertEquals(node, builder.and(node, node)); + } + + @Test + void testOrOperation() { + builder.setConditionCount(2); + + // FALSE OR FALSE = FALSE + assertEquals(-1, builder.or(builder.makeFalse(), builder.makeFalse())); + + // TRUE OR x = TRUE + assertEquals(1, builder.or(builder.makeTrue(), builder.makeFalse())); + assertEquals(1, builder.or(builder.makeTrue(), builder.makeTrue())); + + // FALSE OR TRUE = TRUE + assertEquals(1, builder.or(builder.makeFalse(), builder.makeTrue())); + + // x OR x = x + int node = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + assertEquals(node, builder.or(node, node)); + } + + @Test + void testIteBasicCases() { + builder.setConditionCount(2); + + // ITE(TRUE, g, h) = g + int g = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int h = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + assertEquals(g, builder.ite(builder.makeTrue(), g, h)); + + // ITE(FALSE, g, h) = h + assertEquals(h, builder.ite(builder.makeFalse(), g, h)); + + // ITE(f, g, g) = g + int f = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + assertEquals(g, builder.ite(f, g, g)); + } + + @Test + void testIteWithComplement() { + builder.setConditionCount(2); + + int f = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int g = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + int h = builder.makeNode(1, builder.makeFalse(), builder.makeTrue()); + + // ITE with complemented condition should swap branches + int result1 = builder.ite(f, g, h); + int result2 = builder.ite(builder.negate(f), h, g); + assertEquals(result1, result2); + } + + @Test + void testResultInIte() { + builder.setConditionCount(1); + + int cond = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + + // ITE with result terminals + int ite = builder.ite(cond, result0, result1); + assertTrue(ite != 0); + + // Condition cannot be a result + assertThrows(IllegalArgumentException.class, + () -> builder.ite(result0, builder.makeTrue(), builder.makeFalse())); + } + + @Test + void testSetConditionCountRequired() { + // Cannot create result without setting condition count + assertThrows(IllegalStateException.class, () -> builder.makeResult(0)); + } + + @Test + void testGetVariable() { + builder.setConditionCount(3); + + assertEquals(-1, builder.getVariable(builder.makeTrue())); + assertEquals(-1, builder.getVariable(builder.makeFalse())); + + int node = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + assertEquals(1, builder.getVariable(node)); + assertEquals(1, builder.getVariable(Math.abs(node))); // Use absolute value for complement + + // Test result references + int result = builder.makeResult(0); + assertEquals(3, builder.getVariable(result)); // conditionCount + 0 + } + + @Test + void testReduceSimpleBdd() { + builder.setConditionCount(3); + + // makeNode already reduces nodes with identical branches + // So we need to create a different scenario for reduction + int a = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + int b = builder.makeNode(1, a, builder.makeFalse()); + int root = builder.makeNode(0, b, a); + + int nodesBefore = builder.getNodes().size(); + builder.reduce(root); + + // Structure should be preserved if already optimal + assertEquals(nodesBefore, builder.getNodes().size()); + } + + @Test + void testReduceNoChange() { + builder.setConditionCount(2); + + // Create already-reduced BDD + int right = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + int root = builder.makeNode(0, right, builder.makeFalse()); + + int nodesBefore = builder.getNodes().size(); + builder.reduce(root); + + // No change expected + assertEquals(nodesBefore, builder.getNodes().size()); + } + + @Test + void testReduceTerminals() { + // Reducing terminals should return them unchanged + assertEquals(1, builder.reduce(builder.makeTrue())); + assertEquals(-1, builder.reduce(builder.makeFalse())); + } + + @Test + void testReduceResults() { + builder.setConditionCount(1); + int result = builder.makeResult(0); + + // Reducing result nodes should return them unchanged + assertEquals(result, builder.reduce(result)); + } + + @Test + void testReduceWithComplement() { + builder.setConditionCount(3); + + // Create BDD with complement edges + int a = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + int b = builder.makeNode(1, a, builder.negate(a)); + int root = builder.makeNode(0, b, builder.makeFalse()); + + // Reduce without complement first + int reduced = builder.reduce(root); + + // Now test reducing with complemented root + int complementRoot = builder.negate(root); + int reducedComplement = builder.reduce(complementRoot); + + // The result should be the complement of the reduced root + assertEquals(builder.negate(reduced), reducedComplement); + + // Verify the structure is preserved + assertTrue(builder.getNodes().size() > 1); + } + + @Test + void testReduceClearsCache() { + builder.setConditionCount(2); + + // Create nodes and perform ITE to populate cache - use only boolean nodes + int a = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int b = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + int ite1 = builder.ite(a, b, builder.makeFalse()); + + // Reduce + builder.reduce(ite1); + + // Cache should be cleared, so same ITE creates new result + // Recreate the nodes since reduce may have changed internal state + a = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + b = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); + int ite2 = builder.ite(a, b, builder.makeFalse()); + assertTrue(ite2 != 0); // Should get a valid reference + } + + @Test + void testReduceSharedSubgraphs() { + builder.setConditionCount(3); + + // Create BDD with shared subgraphs - use only boolean nodes + int shared = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + int left = builder.makeNode(1, shared, builder.makeFalse()); + int right = builder.makeNode(1, builder.makeTrue(), shared); + int root = builder.makeNode(0, left, right); + + builder.reduce(root); + + // Shared subgraph should remain shared after reduction + List nodes = builder.getNodes(); + + // Verify structure is maintained - at least one node should exist + assertTrue(nodes.size() > 1); + } + + @Test + void testReducePreservesResultNodes() { + builder.setConditionCount(2); + + // Create BDD with result terminals + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + int cond = builder.makeNode(0, result0, result1); + int root = builder.makeNode(1, cond, builder.makeFalse()); + + int reduced = builder.reduce(root); + + // Result refs should be preserved in the hi/lo branches + boolean foundResult0 = false; + boolean foundResult1 = false; + + // Check the nodes for result references + for (int[] node : builder.getNodes()) { + if (node[1] == result0 || node[2] == result0) { + foundResult0 = true; + } + if (node[1] == result1 || node[2] == result1) { + foundResult1 = true; + } + } + + assertTrue(foundResult0); + assertTrue(foundResult1); + } + + @Test + void testReduceActuallyReduces() { + builder.setConditionCount(3); + + // First create some nodes + int bottom = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + int middle = builder.makeNode(1, bottom, builder.makeFalse()); + int root = builder.makeNode(0, middle, bottom); + + int beforeSize = builder.getNodes().size(); + builder.reduce(root); + int afterSize = builder.getNodes().size(); + + // In this case, no reduction should occur since makeNode already optimized + assertEquals(beforeSize, afterSize); + } + + @Test + void testReduceWithPreExistingComplementStructure() { + builder.setConditionCount(3); + + // Create a structure where reduce will encounter complement on low during rebuild + // First create base nodes + int a = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + + // When we create this node, makeNode will canonicalize it + // The actual stored node will have the complement bit on the reference, not in the node + int b = builder.makeNode(1, builder.makeTrue(), builder.negate(a)); + + // This creates a scenario where during reduce's rebuild, + // makeNodeInNew might encounter the complement + int root = builder.makeNode(0, b, a); + + // Force a reduce operation + int reduced = builder.reduce(root); + + // The BDD should be functionally equivalent + // We can't make strong assertions about node count since reduce may optimize + assertTrue(reduced != 0); + + // Verify the BDD still evaluates correctly + // by checking that it's not a constant + assertNotEquals(reduced, builder.makeTrue()); + assertNotEquals(reduced, builder.makeFalse()); + } + + @Test + void testCofactorRecursive() { + builder.setConditionCount(3); + + // Create a multi-level BDD with only boolean nodes (no results) + int bottom = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); + int middle = builder.makeNode(1, bottom, builder.makeFalse()); + int root = builder.makeNode(0, middle, bottom); + + // Cofactor with respect to variable 1 (appears deeper in BDD) + int cofactorTrue = builder.cofactor(root, 1, true); + int cofactorFalse = builder.cofactor(root, 1, false); + + // The cofactors should be different + assertTrue(cofactorTrue != cofactorFalse); + + // Verify structure is simplified + assertTrue(builder.getNodes().size() > 1); + } + + @Test + void testCofactorWithResults() { + builder.setConditionCount(2); + + // Create BDD with result terminals + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + int node = builder.makeNode(0, result0, result1); + + // Cofactor should select appropriate result + assertEquals(result0, builder.cofactor(node, 0, true)); + assertEquals(result1, builder.cofactor(node, 0, false)); + } + + @Test + void testReduceWithResults() { + builder.setConditionCount(2); + + // Create a BDD that uses results properly + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + + // Create condition nodes that branch to results + int node = builder.makeNode(1, result0, result1); + int root = builder.makeNode(0, node, builder.makeFalse()); + + int reduced = builder.reduce(root); + + // The structure should be preserved + assertNotEquals(0, reduced); // Should not be invalid + assertFalse(builder.isResult(reduced)); // Root should still be a condition node + } + + @Test + void testIteWithResultsInBranches() { + builder.setConditionCount(2); + + // Create a condition and two results + int cond = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + + // ITE should handle results in then/else branches + int ite = builder.ite(cond, result0, result1); + + // The result should be a node that branches to the two results + assertTrue(ite > 0); + assertFalse(builder.isResult(ite)); + } + + @Test + void testResultMaskNoCollisions() { + builder.setConditionCount(3); + + // Create many nodes to ensure no collision with result encoding + int node1 = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int node2 = builder.makeNode(1, node1, builder.makeFalse()); + int node3 = builder.makeNode(2, node2, node1); + + // Create results + int result0 = builder.makeResult(0); + int result1 = builder.makeResult(1); + + // Verify no collisions + assertNotEquals(node1, result0); + assertNotEquals(node2, result0); + assertNotEquals(node3, result0); + assertNotEquals(node1, result1); + assertNotEquals(node2, result1); + assertNotEquals(node3, result1); + + // Verify correct identification + assertFalse(builder.isResult(node1)); + assertFalse(builder.isResult(node2)); + assertFalse(builder.isResult(node3)); + assertTrue(builder.isResult(result0)); + assertTrue(builder.isResult(result1)); + } + + @Test + void testReset() { + builder.setConditionCount(2); + + // Create some state + int node = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + int result = builder.makeResult(0); + + // Reset + builder.reset(); + + // Verify state is cleared + assertEquals(1, builder.getNodes().size()); // Only terminal + assertThrows(IllegalStateException.class, () -> builder.makeResult(0)); + + // Can use builder again + builder.setConditionCount(1); + int newNode = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + assertNotEquals(0, newNode); // Should get a valid reference + assertNotEquals(1, Math.abs(newNode)); // Should not be a terminal + assertNotEquals(-1, Math.abs(newNode)); // Should not be a terminal + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java new file mode 100644 index 00000000000..f34bc2a97e9 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java @@ -0,0 +1,190 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; + +class BddCompilerTest { + + // Common parameters used across tests + private static final Parameter REGION_PARAM = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + + private static final Parameter BUCKET_PARAM = Parameter.builder() + .name("Bucket") + .type(ParameterType.STRING) + .build(); + + private static final Parameter A_PARAM = Parameter.builder() + .name("A") + .type(ParameterType.STRING) + .build(); + + private static final Parameter B_PARAM = Parameter.builder() + .name("B") + .type(ParameterType.STRING) + .build(); + + @Test + void testCompileSimpleEndpointRule() { + // Single rule with one condition + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(REGION_PARAM).build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + + Bdd bdd = compiler.compile(); + + assertNotNull(bdd); + assertEquals(1, bdd.getConditions().size()); + // Results include: endpoint when condition true, no match when false, + // and possibly a no match for the overall fallthrough + assertTrue(bdd.getResults().size() >= 2); + assertTrue(bdd.getRootRef() > 0); + } + + @Test + void testCompileErrorRule() { + // Error rule instead of endpoint + Rule rule = ErrorRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .error("Bucket is required"); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(BUCKET_PARAM).build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + + Bdd bdd = compiler.compile(); + + assertEquals(1, bdd.getConditions().size()); + // Similar to endpoint rule + assertTrue(bdd.getResults().size() >= 2); + } + + @Test + void testCompileTreeRule() { + // Nested tree rule + Rule nestedRule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://bucket.example.com")); + Rule treeRule = TreeRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .treeRule(nestedRule); + Parameters params = Parameters.builder().addParameter(REGION_PARAM).addParameter(BUCKET_PARAM).build(); + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(treeRule).build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + + Bdd bdd = compiler.compile(); + + assertEquals(2, bdd.getConditions().size()); + assertTrue(bdd.getNodes().length > 2); // Should have multiple nodes + } + + @Test + void testCompileWithCustomOrdering() { + // Multiple conditions to test ordering + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("B")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + Parameters params = Parameters.builder().addParameter(A_PARAM).addParameter(B_PARAM).build(); + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(rule).build(); + + Cfg cfg = Cfg.from(ruleSet); + + // Use fixed ordering (B before A) + Condition condA = rule.getConditions().get(0); + Condition condB = rule.getConditions().get(1); + ConditionOrderingStrategy customOrdering = ConditionOrderingStrategy.fixed(Arrays.asList(condB, condA)); + + BddCompiler compiler = new BddCompiler(cfg, customOrdering, new BddBuilder()); + Bdd bdd = compiler.compile(); + + // Verify ordering was applied + assertEquals(condB, bdd.getConditions().get(0)); + assertEquals(condA, bdd.getConditions().get(1)); + } + + @Test + void testCompileEmptyRuleSet() { + // No rules + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(Parameters.builder().build()).build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + Bdd bdd = compiler.compile(); + + assertEquals(0, bdd.getConditions().size()); + // Even with no rules, there's still a result (no match) + assertFalse(bdd.getResults().isEmpty()); + // Should have at least terminal node + assertNotEquals(0, bdd.getNodes().length); + } + + @Test + void testCompileSameResultMultiplePaths() { + // Two rules leading to same endpoint + Rule rule1 = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Rule rule2 = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(REGION_PARAM) + .addParameter(BUCKET_PARAM) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule1) + .addRule(rule2) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + + Bdd bdd = compiler.compile(); + + // The BDD compiler might create separate result nodes even for same endpoint + // depending on how the CFG is structured + assertEquals(3, bdd.getResults().size()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluatorTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluatorTest.java new file mode 100644 index 00000000000..dc9af6c90f4 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluatorTest.java @@ -0,0 +1,215 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.utils.ListUtils; + +class BddEvaluatorTest { + + private static final Parameters EMPTY = Parameters.builder().build(); + + @Test + void testEvaluateTerminalTrue() { + // BDD with just TRUE terminal + int[][] nodes = new int[][] {{-1, 1, -1}}; + Bdd bdd = new Bdd(EMPTY, ListUtils.of(), ListUtils.of(), nodes, 1); + + BddEvaluator evaluator = BddEvaluator.from(bdd); + int result = evaluator.evaluate(idx -> true); + + assertEquals(-1, result); // TRUE terminal returns -1 (TRUE isn't valid in our MTBDD) + } + + @Test + void testEvaluateTerminalFalse() { + // BDD with just FALSE terminal + int[][] nodes = new int[][] {{-1, 1, -1}}; + Bdd bdd = new Bdd(EMPTY, ListUtils.of(), ListUtils.of(), nodes, -1); + + BddEvaluator evaluator = BddEvaluator.from(bdd); + int result = evaluator.evaluate(idx -> true); + + assertEquals(-1, result); // FALSE terminal returns -1 (same as TRUE; FALSE isn't valid in our MTBDD). + } + + @Test + void testEvaluateSingleConditionTrue() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("param")).build(); + Rule rule = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + // BDD: if condition then result1 else no-match + // With new encoding: result references are encoded as RESULT_OFFSET + resultIndex + int result1Ref = Bdd.RESULT_OFFSET + 1; + + int[][] nodes = new int[][] { + {-1, 1, -1}, // 0: terminal + {0, result1Ref, -1} // 1: condition node (high=result1, low=FALSE) + }; + Bdd bdd = new Bdd(EMPTY, ListUtils.of(cond), ListUtils.of(null, rule), nodes, 2); + + BddEvaluator evaluator = BddEvaluator.from(bdd); + + // When condition is true, should return result 1 + assertEquals(1, evaluator.evaluate(idx -> true)); + + // When condition is false, should return -1 (no match) + assertEquals(-1, evaluator.evaluate(idx -> false)); + } + + @Test + void testEvaluateComplementedNode() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("param1")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("param2")).build(); + Rule rule = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + // BDD with a complemented reference to an internal node + // We want: if cond1 then NOT(cond2) else false + // Which means: if cond1 && !cond2 then result1 else no-match + int result1Ref = Bdd.RESULT_OFFSET + 1; + + int[][] nodes = new int[][] { + {-1, 1, -1}, // 0: terminal + {0, -3, -1}, // 1: cond1 node (high=-3 (complement of ref 3 = node 2), low=FALSE) + {1, -1, result1Ref} // 2: cond2 node (high=FALSE, low=result1) + }; + // Root is 2 (reference to node at index 1) + Bdd bdd = new Bdd(EMPTY, ListUtils.of(cond1, cond2), ListUtils.of(null, rule), nodes, 2); + + BddEvaluator evaluator = BddEvaluator.from(bdd); + + // When cond1=true, we follow high branch to -3 (complement of node 2) + // The complement flips the node's branch selection + // If cond2=true, with complement we take the "false" branch (which is low) -> result1 + // If cond2=false, with complement we take the "true" branch (which is high) -> FALSE + ConditionEvaluator bothTrue = idx -> true; + assertEquals(1, evaluator.evaluate(bothTrue)); // cond1=true, cond2=true -> result1 + + ConditionEvaluator firstTrueSecondFalse = idx -> idx == 0; + assertEquals(-1, evaluator.evaluate(firstTrueSecondFalse)); // cond1=true, cond2=false -> FALSE + + // When cond1=false, we follow low branch to FALSE + ConditionEvaluator firstFalse = idx -> false; + assertEquals(-1, evaluator.evaluate(firstFalse)); // cond1=false -> FALSE + } + + @Test + void testEvaluateMultipleConditions() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("param1")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("param2")).build(); + Rule rule1 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://a.com")); + Rule rule2 = ErrorRule.builder().error("hi"); + + // BDD: if cond1 then (if cond2 then result1 else result2) else no-match + // Result references encoded with RESULT_OFFSET + int result1Ref = Bdd.RESULT_OFFSET + 1; + int result2Ref = Bdd.RESULT_OFFSET + 2; + + int[][] nodes = new int[][] { + {-1, 1, -1}, // 0: terminal + {0, 3, -1}, // 1: cond1 node (high=cond2 node, low=FALSE) + {1, result1Ref, result2Ref} // 2: cond2 node (high=result1, low=result2) + }; + Bdd bdd = new Bdd(EMPTY, ListUtils.of(cond1, cond2), ListUtils.of(null, rule1, rule2), nodes, 2); + + BddEvaluator evaluator = BddEvaluator.from(bdd); + ConditionEvaluator condEval = idx -> idx == 0; // only first condition is true + + int result = evaluator.evaluate(condEval); + assertEquals(2, result); // Should get result2 since cond2 is false + } + + @Test + void testEvaluateNoMatchResult() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("param")).build(); + + // BDD with explicit no-match result (index 0) + // Result0 reference encoded with RESULT_OFFSET + int result0Ref = Bdd.RESULT_OFFSET; + + int[][] nodes = new int[][] { + {-1, 1, -1}, // 0: terminal + {0, -1, result0Ref} // 1: condition node + }; + Bdd bdd = new Bdd(EMPTY, ListUtils.of(cond), ListUtils.of((Rule) null), nodes, 2); + + BddEvaluator evaluator = BddEvaluator.from(bdd); + int result = evaluator.evaluate(idx -> false); + + assertEquals(-1, result); // Result index 0 is treated as no-match + } + + @Test + void testEvaluateWithLargeResultIndex() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("param")).build(); + Rule rule = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + // Test with a larger result index to ensure offset works correctly + int result999Ref = Bdd.RESULT_OFFSET + 999; + + int[][] nodes = new int[][] { + {-1, 1, -1}, // 0: terminal + {0, result999Ref, -1} // 1: condition node + }; + + // Create a results list with 1000 entries (0-999) + Rule[] results = new Rule[1000]; + results[999] = rule; + + Bdd bdd = new Bdd(EMPTY, ListUtils.of(cond), ListUtils.of(results), nodes, 2); + + BddEvaluator evaluator = BddEvaluator.from(bdd); + + // When condition is true, should return result 999 + assertEquals(999, evaluator.evaluate(idx -> true)); + } + + @Test + void testEvaluateComplexBddWithMixedReferences() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("param1")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("param2")).build(); + Condition cond3 = Condition.builder().fn(TestHelpers.isSet("param3")).build(); + Rule rule1 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://a.com")); + Rule rule2 = ErrorRule.builder().error("error"); + + // Complex BDD with multiple conditions, complement edges, and results + int result1Ref = Bdd.RESULT_OFFSET + 1; + int result2Ref = Bdd.RESULT_OFFSET + 2; + + int[][] nodes = new int[][] { + {-1, 1, -1}, // 0: terminal + {0, 3, 4}, // 1: cond1 node + {1, result1Ref, -1}, // 2: cond2 node + {2, result2Ref, -5} // 3: cond3 node (low has complement ref) + }; + + Bdd bdd = new Bdd(EMPTY, + ListUtils.of(cond1, cond2, cond3), + ListUtils.of(null, rule1, rule2), + nodes, + 2); + + BddEvaluator evaluator = BddEvaluator.from(bdd); + + // Test various paths through the BDD + ConditionEvaluator allTrue = idx -> true; + assertEquals(1, evaluator.evaluate(allTrue)); // cond1=T -> cond2=T -> result1 + + ConditionEvaluator firstTrueOnly = idx -> idx == 0; + assertEquals(-1, evaluator.evaluate(firstTrueOnly)); // cond1=T -> cond2=F -> FALSE + + ConditionEvaluator firstFalseThirdTrue = idx -> idx == 2; + assertEquals(2, evaluator.evaluate(firstFalseThirdTrue)); // cond1=F -> cond3=T -> result2 + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java new file mode 100644 index 00000000000..d2b24c04e40 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java @@ -0,0 +1,311 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.utils.ListUtils; + +class BddTest { + + @Test + void testConstructorValidation() { + Parameters params = Parameters.builder().build(); + int[][] nodes = new int[][] {{-1, 1, -1}}; + + // Should reject complemented root (except -1 which is FALSE terminal) + assertThrows(IllegalArgumentException.class, () -> new Bdd(params, ListUtils.of(), ListUtils.of(), nodes, -2)); + + // Should accept positive root + Bdd bdd = new Bdd(params, ListUtils.of(), ListUtils.of(), nodes, 1); + assertEquals(1, bdd.getRootRef()); + + // Should accept FALSE terminal as root + Bdd bdd2 = new Bdd(params, ListUtils.of(), ListUtils.of(), nodes, -1); + assertEquals(-1, bdd2.getRootRef()); + } + + @Test + void testBasicAccessors() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Rule rule = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + int[][] nodes = new int[][] { + {-1, 1, -1}, + {0, 3, -1}, + {1, 1, -1} + }; + + Bdd bdd = new Bdd(params, ListUtils.of(cond), ListUtils.of(rule), nodes, 2); + + assertEquals(params, bdd.getParameters()); + assertEquals(1, bdd.getConditions().size()); + assertEquals(cond, bdd.getConditions().get(0)); + assertEquals(1, bdd.getConditionCount()); + assertEquals(1, bdd.getResults().size()); + assertEquals(rule, bdd.getResults().get(0)); + assertEquals(3, bdd.getNodes().length); + assertEquals(2, bdd.getRootRef()); + } + + @Test + void testFromRuleSet() { + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .addRule(EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com"))) + .build(); + + Bdd bdd = Bdd.from(ruleSet); + + assertTrue(bdd.getConditionCount() > 0); + assertFalse(bdd.getResults().isEmpty()); + assertTrue(bdd.getNodes().length > 1); // At least terminal + one node + } + + @Test + void testFromCfg() { + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .addRule(ErrorRule.builder().error("test error")) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Bdd bdd = Bdd.from(cfg); + + assertEquals(0, bdd.getConditionCount()); // No conditions + assertFalse(bdd.getResults().isEmpty()); + } + + @Test + void testEquals() { + Bdd bdd1 = createSimpleBdd(); + Bdd bdd2 = createSimpleBdd(); + + assertEquals(bdd1, bdd2); + assertEquals(bdd1.hashCode(), bdd2.hashCode()); + + // Different root - use a different value than what createSimpleBdd returns + Bdd bdd3 = new Bdd(bdd1.getParameters(), bdd1.getConditions(), bdd1.getResults(), bdd1.getNodes(), -1); + assertNotEquals(bdd1, bdd3); + + // Different conditions + Condition newCond = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); + Bdd bdd4 = new Bdd(bdd1 + .getParameters(), ListUtils.of(newCond), bdd1.getResults(), bdd1.getNodes(), bdd1.getRootRef()); + assertNotEquals(bdd1, bdd4); + + // Different nodes + int[][] newNodes = new int[][] {{-1, 1, -1}, {0, -1, Bdd.RESULT_OFFSET + 1}}; + Bdd bdd5 = new Bdd(bdd1.getParameters(), bdd1.getConditions(), bdd1.getResults(), newNodes, bdd1.getRootRef()); + assertNotEquals(bdd1, bdd5); + } + + @Test + void testToString() { + Bdd bdd = createSimpleBdd(); + String str = bdd.toString(); + + assertTrue(str.contains("Bdd{")); + assertTrue(str.contains("conditions")); + assertTrue(str.contains("results")); + assertTrue(str.contains("root:")); + assertTrue(str.contains("nodes")); + } + + @Test + void testToStringBuilder() { + Bdd bdd = createSimpleBdd(); + StringBuilder sb = new StringBuilder(); + bdd.toString(sb); + + String str = sb.toString(); + assertTrue(str.contains("C0:")); // Condition index + assertTrue(str.contains("R0:")); // Result index + assertTrue(str.contains("terminal")); // Terminal node + } + + @Test + void testToNodeAndFromNode() { + Bdd original = createSimpleBdd(); + + Node node = original.toNode(); + assertTrue(node.isObjectNode()); + assertTrue(node.expectObjectNode().containsMember("conditions")); + assertTrue(node.expectObjectNode().containsMember("results")); + assertTrue(node.expectObjectNode().containsMember("nodes")); + assertTrue(node.expectObjectNode().containsMember("root")); + + // Original has 2 results: NoMatchRule at 0, endpoint at 1 + // Serialized should only have 1 result (the endpoint) + int serializedResultCount = node.expectObjectNode() + .expectArrayMember("results") + .getElements() + .size(); + assertEquals(1, serializedResultCount); + + Bdd restored = Bdd.fromNode(node); + assertEquals(original.getRootRef(), restored.getRootRef()); + assertEquals(original.getConditionCount(), restored.getConditionCount()); + assertEquals(original.getResults().size(), restored.getResults().size()); + assertEquals(original.getNodes().length, restored.getNodes().length); + + // Verify NoMatchRule was restored at index 0 + assertInstanceOf(NoMatchRule.class, restored.getResults().get(0)); + } + + @Test + void testToStringWithDifferentNodeTypes() { + Parameters params = Parameters.builder().build(); + + // Two conditions + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.booleanEquals("UseFips", true)).build(); + + // Two endpoint results + Rule endpoint1 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + Rule endpoint2 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example-fips.com")); + + // NoMatchRule MUST be at index 0 + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); // Index 0 - always NoMatch + results.add(endpoint1); // Index 1 + results.add(endpoint2); // Index 2 + + // BDD structure referencing the correct indices + int[][] nodes = new int[][] { + {-1, 1, -1}, // 0: terminal node + {0, 2, -1}, // 1: if Region is set, go to node 2, else FALSE + {1, Bdd.RESULT_OFFSET + 2, Bdd.RESULT_OFFSET + 1} // 2: if UseFips, return result 2, else result 1 + }; + + Bdd bdd = new Bdd(params, ListUtils.of(cond1, cond2), results, nodes, 1); + String str = bdd.toString(); + + assertTrue(str.contains("Endpoint:")); + assertTrue(str.contains("C0")); + assertTrue(str.contains("C1")); + assertTrue(str.contains("R0")); + assertTrue(str.contains("R1")); + assertTrue(str.contains("R2")); + assertTrue(str.contains("NoMatchRule")); // R0 will show as NoMatchRule + + // Test serialization doesn't include NoMatchRule + Node serialized = bdd.toNode(); + assertEquals(2, + serialized.expectObjectNode() + .expectArrayMember("results") + .getElements() + .size()); // Only the two endpoints, not NoMatch + } + + private Bdd createSimpleBdd() { + Parameters params = Parameters.builder().build(); + Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + // NoMatchRule MUST be at index 0 + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); // Index 0 - always NoMatch + results.add(endpoint); // Index 1 - the actual endpoint + + int[][] nodes = new int[][] { + {-1, 1, -1}, // 0: terminal + {0, Bdd.RESULT_OFFSET + 1, -1} // 1: if cond true, return result 1 (endpoint), else FALSE + }; + + return new Bdd(params, ListUtils.of(cond), results, nodes, 1); + } + + // Used to regenerate BDD test cases for errorfiles + // @Test + // void generateValidBddEncoding() { + // Parameter region = Parameter.builder() + // .name("Region") + // .type(ParameterType.STRING) + // .required(true) + // .documentation("The AWS region") + // .build(); + // + // Parameter useFips = Parameter.builder() + // .name("UseFips") + // .type(ParameterType.BOOLEAN) + // .required(true) + // .defaultValue(software.amazon.smithy.rulesengine.language.evaluation.value.Value.booleanValue(false)) + // .documentation("Use FIPS endpoints") + // .build(); + // + // Parameters params = Parameters.builder() + // .addParameter(region) + // .addParameter(useFips) + // .build(); + // + // Condition useFipsTrue = Condition.builder() + // .fn(BooleanEquals.ofExpressions( + // Expression.getReference(Identifier.of("UseFips")), + // Expression.of(true))) + // .build(); + // + // // Create endpoints + // Endpoint normalEndpoint = Endpoint.builder() + // .url(Expression.of("https://service.{Region}.amazonaws.com")) + // .build(); + // + // Endpoint fipsEndpoint = Endpoint.builder() + // .url(Expression.of("https://service-fips.{Region}.amazonaws.com")) + // .build(); + // + // Rule fipsRule = EndpointRule.builder() + // .condition(useFipsTrue) + // .endpoint(fipsEndpoint); + // + // Rule normalRule = EndpointRule.builder() + // .endpoint(normalEndpoint); + // + // EndpointRuleSet ruleSet = EndpointRuleSet.builder() + // .parameters(params) + // .rules(Arrays.asList(fipsRule, normalRule)) + // .build(); + // + // Cfg cfg = Cfg.from(ruleSet); + // Bdd bdd = Bdd.from(cfg); + // + // BddTrait trait = BddTrait.builder().bdd(bdd).build(); + // BddTraitValidator validator = new BddTraitValidator(); + // ServiceShape service = ServiceShape.builder().id("foo#Bar").addTrait(trait).build(); + // Model model = Model.builder().addShape(service).build(); + // System.out.println(validator.validate(model)); + // + // System.out.println(bdd); + // + // // Get the base64 encoded nodes + // System.out.println(Node.prettyPrintJson(trait.toNode())); + // } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java new file mode 100644 index 00000000000..90c92a4dd96 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java @@ -0,0 +1,120 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class ConditionDependencyGraphTest { + + @Test + void testBasicVariableDependency() { + // Condition that defines a variable + Condition definer = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + // Condition that uses the variable + Condition user = Condition.builder() + .fn(BooleanEquals.ofExpressions(Expression.of("{hasRegion}"), Expression.of(true))) + .build(); + + Map conditionInfos = new HashMap<>(); + conditionInfos.put(definer, ConditionInfo.from(definer)); + conditionInfos.put(user, ConditionInfo.from(user)); + + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditionInfos); + + // Definer has no dependencies + assertTrue(graph.getDependencies(definer).isEmpty()); + + // User depends on definer + Set userDeps = graph.getDependencies(user); + assertEquals(1, userDeps.size()); + assertTrue(userDeps.contains(definer)); + } + + @Test + void testIsSetDependencyForNonIsSetCondition() { + // isSet condition for a variable + Condition isSetCondition = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + // Non-isSet condition using the same variable + Condition userCondition = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of("{Region}"), Expression.of("us-east-1"))) + .build(); + + Map conditionInfos = new HashMap<>(); + conditionInfos.put(isSetCondition, ConditionInfo.from(isSetCondition)); + conditionInfos.put(userCondition, ConditionInfo.from(userCondition)); + + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditionInfos); + + // Non-isSet condition depends on isSet for undefined variables + Set userDeps = graph.getDependencies(userCondition); + assertEquals(1, userDeps.size()); + assertTrue(userDeps.contains(isSetCondition)); + } + + @Test + void testMultipleDependencies() { + // Define two variables + Condition definer1 = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition definer2 = Condition.builder() + .fn(TestHelpers.isSet("Bucket")) + .result(Identifier.of("hasBucket")) + .build(); + + // Use both variables + Condition user = Condition.builder() + .fn(BooleanEquals.ofExpressions( + BooleanEquals.ofExpressions(Expression.of("{hasRegion}"), Expression.of(true)), + BooleanEquals.ofExpressions(Expression.of("{hasBucket}"), Expression.of(true)))) + .build(); + + Map conditionInfos = new HashMap<>(); + conditionInfos.put(definer1, ConditionInfo.from(definer1)); + conditionInfos.put(definer2, ConditionInfo.from(definer2)); + conditionInfos.put(user, ConditionInfo.from(user)); + + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditionInfos); + + // User depends on both definers + Set userDeps = graph.getDependencies(user); + assertEquals(2, userDeps.size()); + assertTrue(userDeps.contains(definer1)); + assertTrue(userDeps.contains(definer2)); + } + + @Test + void testUnknownConditionReturnsEmptyDependencies() { + Condition known = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Condition unknown = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); + + Map conditionInfos = new HashMap<>(); + conditionInfos.put(known, ConditionInfo.from(known)); + + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditionInfos); + + // Getting dependencies for unknown condition returns empty set + assertTrue(graph.getDependencies(unknown).isEmpty()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java new file mode 100644 index 00000000000..1f2d7a89958 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java @@ -0,0 +1,221 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.ParseUrl; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class DefaultOrderingStrategyTest { + + @Test + void testIsSetComesFirst() { + // isSet should be ordered before other conditions + Condition isSetCond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Condition stringEqualsCond = Condition.builder() + .fn(StringEquals.ofExpressions(Literal.of("{Region}"), Literal.of("us-east-1"))) + .build(); + + Condition[] conditions = {stringEqualsCond, isSetCond}; + Map infos = createInfoMap(conditions); + + List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + + // isSet should come first + assertEquals(isSetCond, ordered.get(0)); + assertEquals(stringEqualsCond, ordered.get(1)); + } + + @Test + void testVariableDefiningConditionsFirst() { + // Conditions that define variables should come before those that don't + Condition definer = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition nonDefiner = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); + + Condition[] conditions = {nonDefiner, definer}; + Map infos = createInfoMap(conditions); + + List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + + // Variable-defining condition should come first + assertEquals(definer, ordered.get(0)); + assertEquals(nonDefiner, ordered.get(1)); + } + + @Test + void testDependencyOrdering() { + // Condition that defines a variable + Condition definer = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + // Condition that uses the variable + Condition user = Condition.builder() + .fn(BooleanEquals.ofExpressions(Literal.of("{hasRegion}"), Literal.of(true))) + .build(); + + Condition[] conditions = {user, definer}; + Map infos = createInfoMap(conditions); + + List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + + // Definer must come before user + assertEquals(definer, ordered.get(0)); + assertEquals(user, ordered.get(1)); + } + + @Test + void testComplexityOrdering() { + // Simple condition + Condition simple = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + + // Complex condition (parseURL has higher cost) + Condition complex = Condition.builder().fn(ParseUrl.ofExpressions(Literal.of("https://example.com"))).build(); + + Condition[] conditions = {complex, simple}; + Map infos = createInfoMap(conditions); + + List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + + // Simple should come before complex + assertEquals(simple, ordered.get(0)); + assertEquals(complex, ordered.get(1)); + } + + @Test + void testCircularDependencyDetection() { + // Create conditions with circular dependency + // Note: This is a pathological case that shouldn't happen in practice + Condition cond1 = Condition.builder() + .fn(BooleanEquals.ofExpressions(Literal.of("{var2}"), Literal.of(true))) + .result(Identifier.of("var1")) + .build(); + + Condition cond2 = Condition.builder() + .fn(BooleanEquals.ofExpressions(Literal.of("{var1}"), Literal.of(true))) + .result(Identifier.of("var2")) + .build(); + + Condition[] conditions = {cond1, cond2}; + Map infos = createInfoMap(conditions); + + assertThrows(IllegalStateException.class, + () -> DefaultOrderingStrategy.orderConditions(conditions, infos)); + } + + @Test + void testMultiLevelDependencies() { + // A -> B -> C dependency chain + Condition condA = Condition.builder() + .fn(TestHelpers.isSet("input")) + .result(Identifier.of("var1")) + .build(); + + Condition condB = Condition.builder() + .fn(BooleanEquals.ofExpressions(Literal.of("{var1}"), Literal.of(true))) + .result(Identifier.of("var2")) + .build(); + + Condition condC = Condition.builder() + .fn(BooleanEquals.ofExpressions(Literal.of("{var2}"), Literal.of(false))) + .build(); + + // Mix up the order + Condition[] conditions = {condC, condA, condB}; + Map infos = createInfoMap(conditions); + + List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + + assertEquals(condA, ordered.get(0)); + assertEquals(condB, ordered.get(1)); + assertEquals(condC, ordered.get(2)); + } + + @Test + void testStableSortForEqualPriority() { + // Two similar conditions with no dependencies use stable sort + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); + + Condition[] conditions = {cond1, cond2}; + Map infos = createInfoMap(conditions); + + List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + + // Order should be deterministic based on toString + assertEquals(2, ordered.size()); + assertTrue(ordered.contains(cond1)); + assertTrue(ordered.contains(cond2)); + } + + @Test + void testEmptyConditions() { + Condition[] conditions = new Condition[0]; + Map infos = new HashMap<>(); + + List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + + assertEquals(0, ordered.size()); + } + + @Test + void testSingleCondition() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + + Condition[] conditions = {cond}; + Map infos = createInfoMap(conditions); + + List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + + assertEquals(1, ordered.size()); + assertEquals(cond, ordered.get(0)); + } + + @Test + void testIsSetDependencyForSameVariable() { + // isSet and value check for same variable + Condition isSet = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + + Condition valueCheck = Condition.builder() + .fn(StringEquals.ofExpressions(Literal.of("{Region}"), Literal.of("us-east-1"))) + .build(); + + // Put value check first to test ordering + Condition[] conditions = {valueCheck, isSet}; + Map infos = createInfoMap(conditions); + + List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + + // isSet must come before value check + assertEquals(isSet, ordered.get(0)); + assertEquals(valueCheck, ordered.get(1)); + } + + private Map createInfoMap(Condition... conditions) { + Map map = new HashMap<>(); + for (Condition cond : conditions) { + map.put(cond, ConditionInfo.from(cond)); + } + return map; + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java new file mode 100644 index 00000000000..6c2cd3d5669 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java @@ -0,0 +1,196 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class NodeReversalTest { + + @Test + void testSingleNodeBdd() { + // BDD with just terminal node + int[][] nodes = new int[][] { + {-1, 1, -1} // terminal + }; + + Bdd original = new Bdd( + Parameters.builder().build(), + new ArrayList<>(), + new ArrayList<>(), + nodes, + 1 // root is TRUE + ); + + NodeReversal reversal = new NodeReversal(); + Bdd reversed = reversal.apply(original); + + // Should be unchanged (2 nodes returns as-is). + assertEquals(1, reversed.getNodes().length); + assertEquals(1, reversed.getRootRef()); + assertArrayEquals(new int[] {-1, 1, -1}, reversed.getNodes()[0]); + } + + @Test + void testComplementEdges() { + // BDD with complement edges + int[][] nodes = new int[][] { + {-1, 1, -1}, // terminal + {0, 3, -2}, // condition 0, high to node 2, low to complement of node 1 + {1, 1, -1} // condition 1 + }; + + Bdd original = new Bdd( + Parameters.builder().build(), + Arrays.asList( + Condition.builder().fn(TestHelpers.isSet("Region")).build(), + Condition.builder().fn(TestHelpers.isSet("Bucket")).build()), + new ArrayList<>(), + nodes, + 2 // root points to node 1 + ); + + NodeReversal reversal = new NodeReversal(); + Bdd reversed = reversal.apply(original); + + // Mapping: 0->0, 1->2, 2->1 + // Ref mapping: 2->3, 3->2, -2->-3 + assertEquals(3, reversed.getRootRef()); + + // Check complement edge is properly remapped + int[] reversedNode2 = reversed.getNodes()[2]; + assertEquals(0, reversedNode2[0]); // condition index unchanged + assertEquals(2, reversedNode2[1]); // high ref 3 -> 2 + assertEquals(-3, reversedNode2[2]); // complement low ref -2 -> -3 + } + + @Test + void testResultNodes() { + // BDD with result terminals + int[][] nodes = new int[][] { + {-1, 1, -1}, // terminal + {0, Bdd.RESULT_OFFSET + 1, Bdd.RESULT_OFFSET}, // condition 0 at index 1 + {2, 1, -1}, // result 0 at index 2 + {3, 1, -1} // result 1 at index 3 + }; + + List results = Arrays.asList( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")), + ErrorRule.builder().error("Error occurred")); + + Bdd original = new Bdd( + Parameters.builder().build(), + Arrays.asList( + Condition.builder().fn(TestHelpers.isSet("Region")).build(), + Condition.builder().fn(TestHelpers.isSet("Bucket")).build()), + results, + nodes, + 2 // root points to node 1 + ); + + NodeReversal reversal = new NodeReversal(); + Bdd reversed = reversal.apply(original); + + assertEquals(4, reversed.getNodes().length); + assertEquals(4, reversed.getRootRef()); // root was ref 2, now ref 4 + + // Terminal stays at 0 + assertArrayEquals(new int[] {-1, 1, -1}, reversed.getNodes()[0]); + + // Original node 3 now at index 1 + assertArrayEquals(new int[] {3, 1, -1}, reversed.getNodes()[1]); + + // Original node 2 stays at index 2 + assertArrayEquals(new int[] {2, 1, -1}, reversed.getNodes()[2]); + + // Original node 1 now at index 3 + int[] conditionNode = reversed.getNodes()[3]; + assertEquals(0, conditionNode[0]); // condition index unchanged + assertEquals(Bdd.RESULT_OFFSET + 1, conditionNode[1]); // result references unchanged + assertEquals(Bdd.RESULT_OFFSET, conditionNode[2]); // result references unchanged + } + + @Test + void testFourNodeExample() { + // Simple 4-node example to verify reference mapping + int[][] nodes = new int[][] { + {-1, 1, -1}, // 0: terminal + {0, 3, 4}, // 1: points to nodes 2 and 3 + {1, 1, -1}, // 2: + {2, 1, -1} // 3: + }; + + Bdd original = new Bdd( + Parameters.builder().build(), + Arrays.asList( + Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("B")).build(), + Condition.builder().fn(TestHelpers.isSet("C")).build()), + new ArrayList<>(), + nodes, + 2 // root points to node 1 + ); + + NodeReversal reversal = new NodeReversal(); + Bdd reversed = reversal.apply(original); + + // Mapping: 0->0, 1->3, 2->2, 3->1 + // Ref mapping: 2->4, 3->3, 4->2 + assertEquals(4, reversed.getRootRef()); // root ref 2 -> 4 + + int[] nodeAtIndex3 = reversed.getNodes()[3]; // original node 1 + assertEquals(0, nodeAtIndex3[0]); + assertEquals(3, nodeAtIndex3[1]); // ref 3 stays 3 + assertEquals(2, nodeAtIndex3[2]); // ref 4 -> 2 + } + + @Test + void testImmutability() { + // Ensure original BDD is not modified + int[][] originalNodes = new int[][] { + {-1, 1, -1}, + {0, 3, -1}, + {1, 1, -1} + }; + + Bdd original = new Bdd( + Parameters.builder().build(), + Arrays.asList( + Condition.builder().fn(TestHelpers.isSet("Region")).build(), + Condition.builder().fn(TestHelpers.isSet("Bucket")).build()), + new ArrayList<>(), + originalNodes, + 2); + + // Clone original node arrays for comparison + int[][] originalNodesCopy = new int[original.getNodes().length][]; + for (int i = 0; i < original.getNodes().length; i++) { + originalNodesCopy[i] = original.getNodes()[i].clone(); + } + + NodeReversal reversal = new NodeReversal(); + Bdd reversed = reversal.apply(original); + + // Verify original is unchanged + assertEquals(originalNodesCopy.length, original.getNodes().length); + for (int i = 0; i < originalNodesCopy.length; i++) { + assertArrayEquals(originalNodesCopy[i], original.getNodes()[i]); + } + + assertNotSame(original.getNodes(), reversed.getNodes()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java new file mode 100644 index 00000000000..9cf4714c493 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java @@ -0,0 +1,152 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class OrderConstraintsTest { + + @Test + void testIndependentConditions() { + // Two conditions with no dependencies + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); + + Map infos = createInfoMap(cond1, cond2); + ConditionDependencyGraph graph = new ConditionDependencyGraph(infos); + List conditions = Arrays.asList(cond1, cond2); + + OrderConstraints constraints = new OrderConstraints(graph, conditions); + + // Both conditions can be placed anywhere + assertTrue(constraints.canMove(0, 1)); + assertTrue(constraints.canMove(1, 0)); + + assertEquals(0, constraints.getMinValidPosition(0)); + assertEquals(1, constraints.getMaxValidPosition(0)); + assertEquals(0, constraints.getMinValidPosition(1)); + assertEquals(1, constraints.getMaxValidPosition(1)); + } + + @Test + void testDependentConditions() { + // cond1 defines var, cond2 uses it + Condition cond1 = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition cond2 = Condition.builder() + .fn(BooleanEquals.ofExpressions(Literal.of("{hasRegion}"), Literal.of(true))) + .build(); + + Map infos = createInfoMap(cond1, cond2); + ConditionDependencyGraph graph = new ConditionDependencyGraph(infos); + List conditions = Arrays.asList(cond1, cond2); + + OrderConstraints constraints = new OrderConstraints(graph, conditions); + + // cond1 can only stay in place (cannot move past its dependent) + assertTrue(constraints.canMove(0, 0)); // Stay in place + assertFalse(constraints.canMove(0, 1)); // Cannot move past cond2 + + // cond2 cannot move before cond1 + assertFalse(constraints.canMove(1, 0)); + assertTrue(constraints.canMove(1, 1)); // Stay in place + + assertEquals(0, constraints.getMinValidPosition(0)); + assertEquals(0, constraints.getMaxValidPosition(0)); // Must come before cond2 + assertEquals(1, constraints.getMinValidPosition(1)); // Must come after cond1 + assertEquals(1, constraints.getMaxValidPosition(1)); + } + + @Test + void testChainedDependencies() { + // A -> B -> C dependency chain + Condition condA = Condition.builder().fn(TestHelpers.isSet("input")).result(Identifier.of("var1")).build(); + Condition condB = Condition.builder() + .fn(BooleanEquals.ofExpressions(Literal.of("{var1}"), Literal.of(true))) + .result(Identifier.of("var2")) + .build(); + Condition condC = Condition.builder() + .fn(BooleanEquals.ofExpressions(Literal.of("{var2}"), Literal.of(false))) + .build(); + + Map infos = createInfoMap(condA, condB, condC); + ConditionDependencyGraph graph = new ConditionDependencyGraph(infos); + List conditions = Arrays.asList(condA, condB, condC); + + OrderConstraints constraints = new OrderConstraints(graph, conditions); + + // A can only be at position 0 + assertEquals(0, constraints.getMinValidPosition(0)); + assertEquals(0, constraints.getMaxValidPosition(0)); + + // B must be between A and C + assertEquals(1, constraints.getMinValidPosition(1)); + assertEquals(1, constraints.getMaxValidPosition(1)); + + // C must be last + assertEquals(2, constraints.getMinValidPosition(2)); + assertEquals(2, constraints.getMaxValidPosition(2)); + + // No movement possible in this rigid chain + assertFalse(constraints.canMove(0, 1)); + assertFalse(constraints.canMove(1, 0)); + assertFalse(constraints.canMove(1, 2)); + assertFalse(constraints.canMove(2, 1)); + } + + @Test + void testCanMoveToSamePosition() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Map infos = createInfoMap(cond); + ConditionDependencyGraph graph = new ConditionDependencyGraph(infos); + List conditions = Collections.singletonList(cond); + + OrderConstraints constraints = new OrderConstraints(graph, conditions); + + // Moving to same position is always allowed + assertTrue(constraints.canMove(0, 0)); + } + + @Test + void testMismatchedSizes() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Map infos = createInfoMap(cond1); + ConditionDependencyGraph graph = new ConditionDependencyGraph(infos); + + // Try to create constraints with more conditions than in graph + List conditions = Arrays.asList( + cond1, + Condition.builder().fn(TestHelpers.isSet("Bucket")).build()); + + assertThrows(IllegalArgumentException.class, () -> new OrderConstraints(graph, conditions)); + } + + private Map createInfoMap(Condition... conditions) { + Map map = new HashMap<>(); + for (Condition cond : conditions) { + map.put(cond, ConditionInfo.from(cond)); + } + return map; + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java new file mode 100644 index 00000000000..b213d1f59d4 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java @@ -0,0 +1,134 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; + +// Does some basic checks, but doesn't get too specific so we can easily change the sifting algorithm. +class SiftingOptimizationTest { + + @Test + void testBasicOptimization() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("A").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("B").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("C").type(ParameterType.STRING).build()) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("B")).build(), + Condition.builder().fn(TestHelpers.isSet("C")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Bdd originalBdd = Bdd.from(cfg); + + SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); + Bdd optimizedBdd = optimizer.apply(originalBdd); + + // Basic checks + assertEquals(originalBdd.getConditions().size(), optimizedBdd.getConditions().size()); + assertEquals(originalBdd.getResults().size(), optimizedBdd.getResults().size()); + + // Size should be same or smaller + assertTrue(optimizedBdd.getNodes().length <= originalBdd.getNodes().length); + } + + @Test + void testDependenciesPreserved() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Input").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.isSet("Input")) + .result(Identifier.of("hasInput")) + .build(), + Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasInput")), + Literal.of(true))) + .build(), + Condition.builder() + .fn(TestHelpers.isSet("Region")) + .build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Bdd originalBdd = Bdd.from(cfg); + + SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); + Bdd optimizedBdd = optimizer.apply(originalBdd); + + // Find the positions of dependent conditions + int hasInputPos = -1; + int usesInputPos = -1; + for (int i = 0; i < optimizedBdd.getConditions().size(); i++) { + Condition cond = optimizedBdd.getConditions().get(i); + if (cond.getResult().isPresent() && + cond.getResult().get().toString().equals("hasInput")) { + hasInputPos = i; + } else if (cond.getFunction().toString().contains("hasInput")) { + usesInputPos = i; + } + } + + // Verify dependency is preserved: definer comes before user + assertTrue(hasInputPos < usesInputPos, + "Condition defining hasInput must come before condition using it"); + } + + @Test + void testSingleCondition() { + // Test a single condition edge case + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(rule).build(); + Cfg cfg = Cfg.from(ruleSet); + Bdd originalBdd = Bdd.from(cfg); + + SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); + Bdd optimizedBdd = optimizer.apply(originalBdd); + + // Should be unchanged + assertEquals(originalBdd.getNodes().length, optimizedBdd.getNodes().length); + assertEquals(1, optimizedBdd.getConditions().size()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java new file mode 100644 index 00000000000..2099989d8f1 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java @@ -0,0 +1,280 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionReference; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class CfgBuilderTest { + + private CfgBuilder builder; + private EndpointRuleSet ruleSet; + + @BeforeEach + void setUp() { + Parameter region = Parameter.builder().name("region").type(ParameterType.STRING).build(); + Parameter useFips = Parameter.builder() + .name("useFips") + .type(ParameterType.BOOLEAN) + .defaultValue(Value.booleanValue(false)) + .required(true) + .build(); + ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(region).addParameter(useFips).build()) + .build(); + + builder = new CfgBuilder(ruleSet); + } + + @Test + void buildRequiresNonNullRoot() { + assertThrows(NullPointerException.class, () -> builder.build(null)); + } + + @Test + void buildCreatesValidCfg() { + CfgNode root = ResultNode.terminal(); + Cfg cfg = builder.build(root); + + assertNotNull(cfg); + assertSame(root, cfg.getRoot()); + assertEquals(ruleSet, cfg.getRuleSet()); + } + + @Test + void createResultNodesCachesIdenticalRules() { + Rule rule1 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + Rule rule2 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + CfgNode node1 = builder.createResult(rule1); + CfgNode node2 = builder.createResult(rule2); + + assertSame(node1, node2); + } + + @Test + void createResultNodesDistinguishesDifferentRules() { + Rule rule1 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example1.com")); + Rule rule2 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example2.com")); + + CfgNode node1 = builder.createResult(rule1); + CfgNode node2 = builder.createResult(rule2); + + assertNotSame(node1, node2); + } + + @Test + void createResultStripsConditionsBeforeCaching() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + Rule ruleWithCondition = EndpointRule.builder() + .condition(cond) + .endpoint(TestHelpers.endpoint("https://example.com")); + Rule ruleWithoutCondition = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + CfgNode node1 = builder.createResult(ruleWithCondition); + CfgNode node2 = builder.createResult(ruleWithoutCondition); + + assertSame(node1, node2); + } + + @Test + void createConditionCachesIdenticalNodes() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + CfgNode trueBranch = ResultNode.terminal(); + CfgNode falseBranch = ResultNode.terminal(); + + CfgNode node1 = builder.createCondition(cond, trueBranch, falseBranch); + CfgNode node2 = builder.createCondition(cond, trueBranch, falseBranch); + + assertSame(node1, node2); + } + + @Test + void createConditionDistinguishesDifferentBranches() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + CfgNode trueBranch1 = builder.createResult( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://true1.com"))); + CfgNode trueBranch2 = builder.createResult( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://true2.com"))); + CfgNode falseBranch = ResultNode.terminal(); + + CfgNode node1 = builder.createCondition(cond, trueBranch1, falseBranch); + CfgNode node2 = builder.createCondition(cond, trueBranch2, falseBranch); + + assertNotSame(node1, node2); + } + + @Test + void createConditionReferenceHandlesSimpleCondition() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + + ConditionReference ref = builder.createConditionReference(cond); + + assertNotNull(ref); + assertEquals(cond, ref.getCondition()); + assertFalse(ref.isNegated()); + } + + @Test + void createConditionReferenceCachesIdenticalConditions() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("region")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("region")).build(); + + ConditionReference ref1 = builder.createConditionReference(cond1); + ConditionReference ref2 = builder.createConditionReference(cond2); + + assertSame(ref1, ref2); + } + + @Test + void createConditionReferenceHandlesNegation() { + Condition innerCond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + Condition negatedCond = Condition.builder().fn(Not.ofExpressions(innerCond.getFunction())).build(); + ConditionReference ref = builder.createConditionReference(negatedCond); + + assertNotNull(ref); + assertTrue(ref.isNegated()); + assertEquals(innerCond.getFunction(), ref.getCondition().getFunction()); + } + + @Test + void createConditionReferenceSharesInfoForNegatedAndNonNegated() { + Condition cond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + Condition negatedCond = Condition.builder().fn(Not.ofExpressions(cond.getFunction())).build(); + + ConditionReference ref1 = builder.createConditionReference(cond); + ConditionReference ref2 = builder.createConditionReference(negatedCond); + + assertEquals(ref1.getCondition(), ref2.getCondition()); + assertFalse(ref1.isNegated()); + assertTrue(ref2.isNegated()); + } + + @Test + void createConditionReferenceHandlesBooleanEqualsCanonicalizations() { + // Test booleanEquals(useFips, false) -> booleanEquals(useFips, true) with negation + Expression ref = Expression.getReference(Identifier.of("useFips")); + Condition cond = Condition.builder().fn(BooleanEquals.ofExpressions(ref, false)).build(); + + ConditionReference condRef = builder.createConditionReference(cond); + + // Should be canonicalized to booleanEquals(useFips, true) with negation + assertTrue(condRef.isNegated()); + assertInstanceOf(BooleanEquals.class, condRef.getCondition().getFunction()); + + BooleanEquals fn = (BooleanEquals) condRef.getCondition().getFunction(); + assertEquals(ref, fn.getArguments().get(0)); + assertEquals(Literal.booleanLiteral(true), fn.getArguments().get(1)); + } + + @Test + void createConditionReferenceDoesNotCanonicalizeWithoutDefault() { + // Test that booleanEquals(region, false) is not canonicalized (no default) + Expression ref = Expression.getReference(Identifier.of("region")); + Condition cond = Condition.builder().fn(BooleanEquals.ofExpressions(ref, false)).build(); + + ConditionReference condRef = builder.createConditionReference(cond); + + assertFalse(condRef.isNegated()); + assertEquals(cond.getFunction(), condRef.getCondition().getFunction()); + } + + @Test + void createConditionReferenceHandlesCommutativeCanonicalizations() { + Expression ref = Expression.getReference(Identifier.of("region")); + + // Create conditions with different argument orders + Condition cond1 = Condition.builder().fn(StringEquals.ofExpressions(ref, "us-east-1")).build(); + Condition cond2 = Condition.builder().fn(StringEquals.ofExpressions(Expression.of("us-east-1"), ref)).build(); + + ConditionReference ref1 = builder.createConditionReference(cond1); + ConditionReference ref2 = builder.createConditionReference(cond2); + + // Both should produce equivalent canonicalized references. + // They should have the same underlying condition after canonicalization + assertEquals(ref1.getCondition(), ref2.getCondition()); + assertEquals(ref1.isNegated(), ref2.isNegated()); + } + + @Test + void createConditionReferenceHandlesVariableBinding() { + Condition cond = Condition.builder() + .fn(TestHelpers.parseUrl("{url}")) + .result(Identifier.of("parsedUrl")) + .build(); + + ConditionReference ref = builder.createConditionReference(cond); + + assertNotNull(ref); + assertEquals("parsedUrl", ref.getReturnVariable()); + } + + @Test + void createConditionHandlesComplexNesting() { + // Build a nested structure to test caching + CfgNode endpoint1 = builder.createResult( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://endpoint1.com"))); + CfgNode endpoint2 = builder.createResult( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://endpoint2.com"))); + CfgNode errorNode = builder.createResult(ErrorRule.builder().error("Invalid configuration")); + + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("region")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.stringEquals("region", "us-east-1")).build(); + + // Create nested conditions + CfgNode inner = builder.createCondition(cond2, endpoint1, endpoint2); + CfgNode outer = builder.createCondition(cond1, inner, errorNode); + + assertInstanceOf(ConditionNode.class, outer); + ConditionNode outerNode = (ConditionNode) outer; + assertEquals(cond1, outerNode.getCondition().getCondition()); + assertSame(inner, outerNode.getTrueBranch()); + assertSame(errorNode, outerNode.getFalseBranch()); + } + + @Test + void createConditionReferenceIgnoresNegationWithVariableBinding() { + // Negation with variable binding should not unwrap + Condition innerCond = Condition.builder().fn(TestHelpers.isSet("region")).build(); + + Condition negatedWithBinding = Condition.builder() + .fn(Not.ofExpressions(innerCond.getFunction())) + .result(Identifier.of("notRegionSet")) + .build(); + + ConditionReference ref = builder.createConditionReference(negatedWithBinding); + + // Should not be treated as simple negation due to variable binding + assertFalse(ref.isNegated()); + assertInstanceOf(Not.class, ref.getCondition().getFunction()); + assertEquals("notRegionSet", ref.getReturnVariable()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java new file mode 100644 index 00000000000..ba2921510d0 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java @@ -0,0 +1,267 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Set; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.ConditionReference; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class CfgTest { + + @Test + void gettersReturnConstructorValues() { + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .build(); + CfgNode root = ResultNode.terminal(); + + Cfg cfg = new Cfg(ruleSet, root); + + assertSame(ruleSet, cfg.getRuleSet()); + assertSame(root, cfg.getRoot()); + } + + @Test + void fromCreatesSimpleCfg() { + EndpointRule rule = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + assertNotNull(cfg); + assertNotNull(cfg.getRoot()); + assertEquals(ruleSet, cfg.getRuleSet()); + + // Root should be a result node for a simple endpoint rule + assertInstanceOf(ResultNode.class, cfg.getRoot()); + ResultNode resultNode = (ResultNode) cfg.getRoot(); + assertEquals(rule.withoutConditions(), resultNode.getResult()); + } + + @Test + void fromCreatesConditionNode() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("region").type(ParameterType.STRING).build()) + .build(); + + EndpointRule rule = EndpointRule.builder() + .condition(Condition.builder().fn(TestHelpers.isSet("region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + // Root should be a condition node + assertInstanceOf(ConditionNode.class, cfg.getRoot()); + ConditionNode condNode = (ConditionNode) cfg.getRoot(); + assertEquals("isSet(region)", condNode.getCondition().getCondition().toString()); + } + + @Test + void fromHandlesMultipleRules() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("region").type(ParameterType.STRING).build()) + .build(); + + // TreeRule with isSet check followed by stringEquals + Rule treeRule = TreeRule.builder() + .condition(Condition.builder().fn(TestHelpers.isSet("region")).build()) + .treeRule( + EndpointRule.builder() + .condition( + Condition.builder().fn(TestHelpers.stringEquals("region", "us-east-1")).build()) + .endpoint(TestHelpers.endpoint("https://us-east-1.example.com")), + EndpointRule.builder() + .condition( + Condition.builder().fn(TestHelpers.stringEquals("region", "eu-west-1")).build()) + .endpoint(TestHelpers.endpoint("https://eu-west-1.example.com"))); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(treeRule) + .addRule(ErrorRule.builder().error("Unknown region")) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + assertInstanceOf(ConditionNode.class, cfg.getRoot()); + } + + @Test + void getConditionDataCachesResult() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("region").type(ParameterType.STRING).build()) + .build(); + + EndpointRule rule = EndpointRule.builder() + .condition(Condition.builder().fn(TestHelpers.isSet("region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + ConditionData data1 = cfg.getConditionData(); + ConditionData data2 = cfg.getConditionData(); + + assertNotNull(data1); + assertSame(data1, data2); + assertEquals(1, data1.getConditions().length); + } + + @Test + void iteratorVisitsAllNodes() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("region").type(ParameterType.STRING).build()) + .build(); + + Rule rule1 = EndpointRule.builder() + .condition(Condition.builder().fn(TestHelpers.isSet("region")).build()) + .endpoint(TestHelpers.endpoint("https://with-region.com")); + + Rule rule2 = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://no-region.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule1) + .addRule(rule2) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + Set visited = new HashSet<>(); + for (CfgNode node : cfg) { + visited.add(node); + } + + // Should have at least 3 nodes: condition node and 2 result nodes + assertTrue(visited.size() >= 3); + } + + @Test + void iteratorHandlesEmptyCfg() { + CfgNode root = ResultNode.terminal(); + Cfg cfg = new Cfg(null, root); + + List nodes = new ArrayList<>(); + for (CfgNode node : cfg) { + nodes.add(node); + } + + assertEquals(1, nodes.size()); + assertSame(root, nodes.get(0)); + } + + @Test + void iteratorThrowsNoSuchElementException() { + CfgNode root = ResultNode.terminal(); + Cfg cfg = new Cfg(null, root); + + Iterator iterator = cfg.iterator(); + assertTrue(iterator.hasNext()); + iterator.next(); + assertFalse(iterator.hasNext()); + + assertThrows(NoSuchElementException.class, iterator::next); + } + + @Test + void iteratorDoesNotVisitNodesTwice() { + // Create a diamond-shaped CFG where multiple paths lead to the same node + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("a").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("b").type(ParameterType.STRING).build()) + .build(); + + CfgBuilder builder = new CfgBuilder(EndpointRuleSet.builder() + .parameters(params) + .build()); + + CfgNode sharedResult = builder.createResult( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://shared.com"))); + + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("a")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("b")).build(); + + ConditionReference ref1 = builder.createConditionReference(cond1); + ConditionReference ref2 = builder.createConditionReference(cond2); + + // Both conditions can lead to the same result + CfgNode branch1 = builder.createCondition(ref2, sharedResult, sharedResult); + CfgNode root = builder.createCondition(ref1, branch1, sharedResult); + + Cfg cfg = builder.build(root); + + List visitedNodes = new ArrayList<>(); + for (CfgNode node : cfg) { + visitedNodes.add(node); + } + + // Count occurrences of sharedResult + long sharedResultCount = visitedNodes.stream() + .filter(node -> node == sharedResult) + .count(); + + assertEquals(1, sharedResultCount, "Shared node should only be visited once"); + } + + @Test + void equalsAndHashCodeBasedOnRoot() { + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .build(); + CfgNode root1 = ResultNode.terminal(); + CfgNode root2 = new ResultNode( + EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com"))); + + Cfg cfg1a = new Cfg(ruleSet, root1); + Cfg cfg1b = new Cfg(ruleSet, root1); + Cfg cfg2 = new Cfg(ruleSet, root2); + + // Same root + assertEquals(cfg1a, cfg1b); + assertEquals(cfg1a.hashCode(), cfg1b.hashCode()); + + // Different root + assertNotEquals(cfg1a, cfg2); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDataTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDataTest.java new file mode 100644 index 00000000000..886c5eabf91 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDataTest.java @@ -0,0 +1,170 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.ConditionInfo; +import software.amazon.smithy.rulesengine.logic.TestHelpers; + +class ConditionDataTest { + + @Test + void extractsConditionsFromSimpleCfg() { + // Build a simple ruleset with two conditions + Condition cond1 = Condition.builder() + .fn(TestHelpers.isSet("param1")) + .build(); + + // For stringEquals, we need to ensure param2 is set first + Rule rule = TreeRule.builder() + .condition(cond1) + .treeRule( + TreeRule.builder() + .condition(Condition.builder().fn(TestHelpers.isSet("param2")).build()) + .treeRule( + EndpointRule.builder() + .condition(Condition.builder() + .fn(TestHelpers.stringEquals("param2", "value")) + .build()) + .endpoint(TestHelpers.endpoint("https://example.com")))); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("param1").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("param2").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .addRule(rule) + .parameters(params) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + ConditionData data = ConditionData.from(cfg); + + // Verify condition extraction + Condition[] conditions = data.getConditions(); + assertEquals(3, conditions.length); // isSet(param1), isSet(param2), stringEquals + + // Verify condition infos + Map infos = data.getConditionInfos(); + assertEquals(3, infos.size()); + } + + @Test + void deduplicatesIdenticalConditions() { + // Create identical conditions used in different rules + Condition cond = Condition.builder() + .fn(TestHelpers.isSet("param")) + .build(); + + Rule rule1 = EndpointRule.builder().conditions(cond).endpoint(TestHelpers.endpoint("https://endpoint1.com")); + Rule rule2 = EndpointRule.builder().conditions(cond).endpoint(TestHelpers.endpoint("https://endpoint2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("param").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule1) + .addRule(rule2) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + ConditionData data = ConditionData.from(cfg); + + // Should only have one condition despite being used twice + assertEquals(1, data.getConditions().length); + assertEquals(cond, data.getConditions()[0]); + } + + @Test + void handlesNestedTreeRules() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("param1")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("param2")).build(); + Condition cond3 = Condition.builder().fn(TestHelpers.isSet("param3")).build(); + + Rule innerRule = TreeRule.builder() + .conditions(cond2) + .treeRule(EndpointRule.builder() + .condition(cond3) + .endpoint(TestHelpers.endpoint("https://example.com"))); + + Rule outerRule = TreeRule.builder().condition(cond1).treeRule(innerRule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("param1").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("param2").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("param3").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .addRule(outerRule) + .parameters(params) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + ConditionData data = ConditionData.from(cfg); + + // Should extract all three conditions + assertEquals(3, data.getConditions().length); + assertEquals(3, data.getConditionInfos().size()); + } + + @Test + void handlesCfgWithOnlyResults() { + // Rule with no conditions, just a result + Rule rule = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://default.com")); + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + ConditionData data = ConditionData.from(cfg); + + // Should have no conditions + assertEquals(0, data.getConditions().length); + assertTrue(data.getConditionInfos().isEmpty()); + } + + @Test + void cachesResultOnCfg() { + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("param").type(ParameterType.STRING).build()) + .build(); + + EndpointRule rule = EndpointRule.builder() + .condition(Condition.builder().fn(TestHelpers.isSet("param")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + Cfg cfg = Cfg.from(ruleSet); + + // First call should create the data + ConditionData data1 = cfg.getConditionData(); + assertNotNull(data1); + + // Second call should return the same instance + ConditionData data2 = cfg.getConditionData(); + assertSame(data1, data2); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableDisambiguatorTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableDisambiguatorTest.java new file mode 100644 index 00000000000..0ad5494d09d --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableDisambiguatorTest.java @@ -0,0 +1,236 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +public class VariableDisambiguatorTest { + + @Test + void testNoDisambiguationNeeded() { + // When variables are not shadowed, they should remain unchanged + Parameter bucketParam = Parameter.builder() + .name("Bucket") + .type(ParameterType.STRING) + .build(); + + Condition condition1 = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of("Bucket"), Expression.of("mybucket"))) + .result("bucketMatches") + .build(); + + EndpointRule rule = EndpointRule.builder() + .conditions(Collections.singletonList(condition1)) + .endpoint(endpoint("https://example.com")); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(bucketParam).build()) + .rules(Collections.singletonList(rule)) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + // Should be unchanged + assertEquals(original, result); + } + + @Test + void testSimpleShadowing() { + // Test when the same variable name is bound to different expressions + Parameter param = Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build(); + + // Create rules that will have shadowing after disambiguation + List rules = Arrays.asList( + createRuleWithBinding("Input", "a", "temp", "https://branch1.com"), + createRuleWithBinding("Input", "b", "temp", "https://branch2.com")); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(param).build()) + .rules(rules) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + // The second "temp" should be renamed + List resultRules = result.getRules(); + assertEquals(2, resultRules.size()); + + // First rule should keep "temp" + EndpointRule resultRule1 = (EndpointRule) resultRules.get(0); + assertEquals("temp", resultRule1.getConditions().get(0).getResult().get().toString()); + + // Second rule should have "temp_1" + EndpointRule resultRule2 = (EndpointRule) resultRules.get(1); + assertEquals("temp_1", resultRule2.getConditions().get(0).getResult().get().toString()); + } + + @Test + void testMultipleShadowsOfSameVariable() { + // Test when a variable is shadowed multiple times + Parameter param = Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build(); + + List rules = Arrays.asList( + createRuleWithBinding("Input", "x", "temp", "https://1.com"), + createRuleWithBinding("Input", "y", "temp", "https://2.com"), + createRuleWithBinding("Input", "z", "temp", "https://3.com")); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(param).build()) + .rules(rules) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + List resultRules = result.getRules(); + assertEquals("temp", (resultRules.get(0)).getConditions().get(0).getResult().get().toString()); + assertEquals("temp_1", (resultRules.get(1)).getConditions().get(0).getResult().get().toString()); + assertEquals("temp_2", (resultRules.get(2)).getConditions().get(0).getResult().get().toString()); + } + + @Test + void testErrorRuleHandling() { + // Test that error rules are handled correctly + Parameter param = Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build(); + + Condition cond = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of("Input"), Expression.of("error"))) + .result("hasError") + .build(); + + ErrorRule errorRule = ErrorRule.builder() + .conditions(Collections.singletonList(cond)) + .error(Expression.of("Error occurred")); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(param).build()) + .rules(Collections.singletonList(errorRule)) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + // Should handle error rules without issues + assertEquals(1, result.getRules().size()); + assertInstanceOf(ErrorRule.class, result.getRules().get(0)); + } + + @Test + void testTreeRuleHandling() { + // Test tree rules with unique variable names at each level + Parameter param = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + + // Outer condition with one variable + Condition outerCond = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of("Region"), Expression.of("us-*"))) + .result("isUS") + .build(); + + // Inner rules with their own variables + EndpointRule innerRule1 = createRuleWithBinding("Region", "us-east-1", "isEast", "https://east.com"); + EndpointRule innerRule2 = createRuleWithBinding("Region", "us-west-2", "isWest", "https://west.com"); + + TreeRule treeRule = TreeRule.builder() + .conditions(Collections.singletonList(outerCond)) + .treeRule(innerRule1, innerRule2); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(param).build()) + .rules(Collections.singletonList(treeRule)) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + // Check structure is preserved + assertInstanceOf(TreeRule.class, result.getRules().get(0)); + TreeRule resultTree = (TreeRule) result.getRules().get(0); + assertEquals(2, resultTree.getRules().size()); + } + + @Test + void testParameterShadowingAttempt() { + // Test that attempting to shadow a parameter gets disambiguated + Parameter bucketParam = Parameter.builder() + .name("Bucket") + .type(ParameterType.STRING) + .build(); + + // Create a condition that assigns to "Bucket_shadow" to avoid direct conflict + Condition shadowingCond = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of("Bucket"), Expression.of("test"))) + .result("Bucket_shadow") + .build(); + + EndpointRule rule = EndpointRule.builder() + .conditions(Collections.singletonList(shadowingCond)) + .endpoint(endpoint("https://example.com")); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(Parameters.builder().addParameter(bucketParam).build()) + .rules(Collections.singletonList(rule)) + .version("1.0") + .build(); + + EndpointRuleSet result = SsaTransform.transform(original); + + // Should handle without issues + EndpointRule resultRule = (EndpointRule) result.getRules().get(0); + assertEquals("Bucket_shadow", resultRule.getConditions().get(0).getResult().get().toString()); + } + + private static EndpointRule createRuleWithBinding(String param, String value, String resultVar, String url) { + Condition cond = Condition.builder() + .fn(StringEquals.ofExpressions(Expression.of(param), Expression.of(value))) + .result(resultVar) + .build(); + + return EndpointRule.builder() + .conditions(Collections.singletonList(cond)) + .endpoint(endpoint(url)); + } + + private static Expression expr(String value) { + return Literal.stringLiteral(Template.fromString(value)); + } + + private static Endpoint endpoint(String value) { + return Endpoint.builder().url(expr(value)).build(); + } +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/substring.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/substring.smithy index 42d93b922d4..f76b3d279bc 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/substring.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/substring.smithy @@ -158,7 +158,7 @@ use smithy.rules#endpointTests "documentation": "unicode characters always return `None`", "params": { "TestCaseId": "1", - "Input": "abcdef\uD83D\uDC31" + "Input": "\uD83D\uDC31abcdef" }, "expect": { "error": "No tests matched" @@ -168,7 +168,7 @@ use smithy.rules#endpointTests "documentation": "non-ascii cause substring to always return `None`", "params": { "TestCaseId": "1", - "Input": "abcdef\u0080" + "Input": "ab\u0080cdef" }, "expect": { "error": "No tests matched" diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/invalid-rules/empty-rule.json5 b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/invalid-rules/empty-rule.json5 index d9f6d0bffb2..c508ad3f5cc 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/invalid-rules/empty-rule.json5 +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/invalid-rules/empty-rule.json5 @@ -1,7 +1,7 @@ // when parsing endpoint ruleset // while parsing rule // at invalid-rules/empty-rule.json5:15 -// Missing expected member `conditions`. +// Missing expected member `type`. { "version": "1.2", "parameters": { diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json index 92b6d8292e5..eda21ee8440 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json @@ -9,7 +9,6 @@ }, "rules": [ { - "conditions": [], "documentation": "base rule", "endpoint": { "url": "https://{Region}.amazonaws.com", diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors new file mode 100644 index 00000000000..5ead7e2607a --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors @@ -0,0 +1 @@ +[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#bdd`: Input byte array has wrong 4-byte ending unit | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy new file mode 100644 index 00000000000..e99bbd51f0f --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy @@ -0,0 +1,35 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#bdd + +@bdd({ + parameters: { + Region: { + type: "string" + required: true + documentation: "The AWS region" + } + } + conditions: [ + { + fn: "isSet" + argv: [{ref: "Region"}] + } + ] + results: [ + { + type: "endpoint" + endpoint: { + url: "https://service.{Region}.amazonaws.com" + } + } + ] + nodes: "ABCD=" // invalid base64 + nodeCount: 3 + root: 1 +}) +service ValidBddService { + version: "2022-01-01" +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors new file mode 100644 index 00000000000..a94a00b4490 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors @@ -0,0 +1 @@ +[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#bdd`: Failed to decode BDD nodes | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy new file mode 100644 index 00000000000..c89c0d4967e --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy @@ -0,0 +1,51 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#bdd + +@bdd({ + parameters: { + Region: { + type: "string" + required: true + documentation: "The AWS region" + } + UseFips: { + type: "boolean" + required: true + default: false + documentation: "Use FIPS endpoints" + } + } + conditions: [ + { + fn: "isSet" + argv: [{ref: "Region"}] + } + { + fn: "booleanEquals" + argv: [{ref: "UseFips"}, true] + } + ] + results: [ + { + type: "endpoint" + endpoint: { + url: "https://service.{Region}.amazonaws.com" + } + } + { + type: "endpoint" + endpoint: { + url: "https://service-fips.{Region}.amazonaws.com" + } + } + ] + nodes: "AQB" // bad data, valid base64 + nodeCount: 3 + root: 1 +}) +service ValidBddService { + version: "2022-01-01" +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors new file mode 100644 index 00000000000..2cdf5489406 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors @@ -0,0 +1 @@ +[ERROR] smithy.example#InvalidRootRefService: Error creating trait `smithy.rules#bdd`: Root reference cannot be complemented: -5 | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy new file mode 100644 index 00000000000..a22a1ea2a45 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy @@ -0,0 +1,21 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#bdd + +@bdd({ + parameters: {} + conditions: [] + results: [] + nodes: "AAAA" // Base64 encoded empty node array + root: -5 // Invalid negative root reference (only -1 is allowed for FALSE) + nodeCount: 0 +}) +service InvalidRootRefService { + version: "2022-01-01" + operations: [GetThing] +} + +@readonly +operation GetThing {} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.errors new file mode 100644 index 00000000000..a4999d204d9 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.errors @@ -0,0 +1 @@ +[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#bdd`: Extra data found after decoding 2 nodes. 9 bytes remaining. | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.smithy new file mode 100644 index 00000000000..c22a4416804 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.smithy @@ -0,0 +1,51 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#bdd + +@bdd({ + parameters: { + Region: { + type: "string" + required: true + documentation: "The AWS region" + } + UseFips: { + type: "boolean" + required: true + default: false + documentation: "Use FIPS endpoints" + } + } + conditions: [ + { + fn: "isSet" + argv: [{ref: "Region"}] + } + { + fn: "booleanEquals" + argv: [{ref: "UseFips"}, true] + } + ] + results: [ + { + type: "endpoint" + endpoint: { + url: "https://service.{Region}.amazonaws.com" + } + } + { + type: "endpoint" + endpoint: { + url: "https://service-fips.{Region}.amazonaws.com" + } + } + ] + nodes: "AQIBAAQBAoKS9AGAkvQB" + nodeCount: 2 + root: 1 +}) +service ValidBddService { + version: "2022-01-01" +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors new file mode 100644 index 00000000000..d55caf64ef8 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors @@ -0,0 +1 @@ +[WARNING] smithy.example#ValidBddService: This shape applies a trait that is unstable: smithy.rules#bdd | UnstableTrait.smithy.rules#bdd diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy new file mode 100644 index 00000000000..637a8bd4a6b --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy @@ -0,0 +1,56 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#bdd + +@bdd({ + "parameters": { + "Region": { + "required": true, + "documentation": "The AWS region", + "type": "string" + }, + "UseFips": { + "required": true, + "default": false, + "documentation": "Use FIPS endpoints", + "type": "boolean" + } + }, + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + { + "ref": "UseFips" + }, + true + ] + } + ], + "results": [ + { + "endpoint": { + "url": "https://service-fips.{Region}.amazonaws.com", + "properties": {}, + "headers": {} + }, + "type": "endpoint" + }, + { + "endpoint": { + "url": "https://service.{Region}.amazonaws.com", + "properties": {}, + "headers": {} + }, + "type": "endpoint" + } + ], + "root": 2, + "nodes": "AQIBAIKEr1+EhK9f", + "nodeCount": 2 +}) +service ValidBddService { + version: "2022-01-01" +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors index 053df0d3352..b0779649d6f 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors @@ -1 +1 @@ -[ERROR] smithy.example#InvalidService: Trait `smithy.rules#endpointTests` cannot be applied to `smithy.example#InvalidService`. This trait may only be applied to shapes that match the following selector: service[trait|smithy.rules#endpointRuleSet] | TraitTarget +[ERROR] smithy.example#InvalidService: Trait `smithy.rules#endpointTests` cannot be applied to `smithy.example#InvalidService`. This trait may only be applied to shapes that match the following selector: service :is([trait|smithy.rules#endpointRuleSet], [trait|smithy.rules#bdd]) | TraitTarget From 448d89f5dd9e6c8e31ab7c8241e21e280e2f44c5 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Wed, 23 Jul 2025 16:11:27 -0500 Subject: [PATCH 02/23] Add separate BddFormatter --- .../smithy/rulesengine/logic/bdd/Bdd.java | 163 ++++++------- .../rulesengine/logic/bdd/BddFormatter.java | 217 ++++++++++++++++++ .../smithy/rulesengine/logic/bdd/BddTest.java | 12 - 3 files changed, 288 insertions(+), 104 deletions(-) create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java index ef7fd0005d5..48802443e06 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java @@ -4,6 +4,11 @@ */ package software.amazon.smithy.rulesengine.logic.bdd; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -247,107 +252,81 @@ public int hashCode() { @Override public String toString() { - return toString(new StringBuilder()).toString(); - } - - /** - * Appends a string representation to the given StringBuilder. - * - * @param sb the StringBuilder to append to - * @return the given string builder. - */ - public StringBuilder toString(StringBuilder sb) { - // Calculate max width needed for first column identifiers - int maxConditionIdx = conditions.size() - 1; - int maxResultIdx = results.size() - 1; - - // Width needed for "C" + maxConditionIdx or "R" + maxResultIdx - int conditionWidth = maxConditionIdx >= 0 ? String.valueOf(maxConditionIdx).length() + 1 : 2; - int resultWidth = maxResultIdx >= 0 ? String.valueOf(maxResultIdx).length() + 1 : 2; - int varWidth = Math.max(conditionWidth, resultWidth); - - sb.append("Bdd{\n"); - - // Conditions - sb.append(" conditions (").append(getConditionCount()).append("):\n"); - for (int i = 0; i < conditions.size(); i++) { - sb.append(String.format(" %" + varWidth + "s: %s%n", "C" + i, conditions.get(i))); - } - - // Results - sb.append(" results (").append(results.size()).append("):\n"); - for (int i = 0; i < results.size(); i++) { - sb.append(String.format(" %" + varWidth + "s: ", "R" + i)); - appendResult(sb, results.get(i)); - sb.append("\n"); - } + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + Writer writer = new OutputStreamWriter(baos, StandardCharsets.UTF_8); + + // Calculate width for condition/result indices + int maxConditionIdx = conditions.size() - 1; + int maxResultIdx = results.size() - 1; + int conditionWidth = maxConditionIdx >= 0 ? String.valueOf(maxConditionIdx).length() + 1 : 2; + int resultWidth = maxResultIdx >= 0 ? String.valueOf(maxResultIdx).length() + 1 : 2; + int varWidth = Math.max(conditionWidth, resultWidth); + + writer.write("Bdd{\n"); + + // Write conditions + writer.write(" conditions ("); + writer.write(String.valueOf(getConditionCount())); + writer.write("):\n"); + for (int i = 0; i < conditions.size(); i++) { + writer.write(String.format(" %" + varWidth + "s: %s%n", "C" + i, conditions.get(i))); + } - // Root - sb.append(" root: ").append(formatReference(rootRef)).append("\n"); - - // Nodes - sb.append(" nodes (").append(nodes.length).append("):\n"); - - // Calculate width needed for node indices - int indexWidth = String.valueOf(nodes.length - 1).length(); - - for (int i = 0; i < nodes.length; i++) { - sb.append(String.format(" %" + indexWidth + "d: ", i)); - if (i == 0) { - sb.append("terminal"); - } else { - int[] node = nodes[i]; - int varIdx = node[0]; - sb.append("["); - - // Use the calculated width for variable/result references - if (varIdx < conditions.size()) { - sb.append(String.format("%" + varWidth + "s", "C" + varIdx)); - } else { - sb.append(String.format("%" + varWidth + "s", "R" + (varIdx - conditions.size()))); - } - - // Format the references with consistent spacing - sb.append(", ") - .append(String.format("%6s", formatReference(node[1]))) - .append(", ") - .append(String.format("%6s", formatReference(node[2]))) - .append("]"); + // Write results + writer.write(" results ("); + writer.write(String.valueOf(results.size())); + writer.write("):\n"); + for (int i = 0; i < results.size(); i++) { + writer.write(String.format(" %" + varWidth + "s: ", "R" + i)); + appendResult(writer, results.get(i)); + writer.write("\n"); } - sb.append("\n"); - } - sb.append("}"); - return sb; + // Write root + writer.write(" root: "); + writer.write(BddFormatter.formatReference(rootRef)); + writer.write("\n"); + + // Write nodes header + writer.write(" nodes ("); + writer.write(String.valueOf(nodes.length)); + writer.write("):\n"); + + writer.flush(); + + // Use BddFormatter for nodes - no need to strip anything since we control the indent + BddFormatter.builder() + .writer(writer) + .nodes(nodes) + .rootRef(rootRef) + .conditionCount(conditions.size()) + .resultCount(results.size()) + .indent(" ") + .build() + .format(); + + writer.write("}"); + writer.flush(); + + return baos.toString(StandardCharsets.UTF_8.name()); + } catch (IOException e) { + // Should never happen with ByteArrayOutputStream + throw new RuntimeException("Failed to format BDD", e); + } } - private void appendResult(StringBuilder sb, Rule result) { + private void appendResult(Writer writer, Rule result) throws IOException { if (result == null) { - sb.append("(no match)"); + writer.write("(no match)"); } else if (result instanceof EndpointRule) { - sb.append("Endpoint: ").append(((EndpointRule) result).getEndpoint().getUrl()); + writer.write("Endpoint: "); + writer.write(((EndpointRule) result).getEndpoint().getUrl().toString()); } else if (result instanceof ErrorRule) { - sb.append("Error: ").append(((ErrorRule) result).getError()); - } else { - sb.append(result.getClass().getSimpleName()); - } - } - - private String formatReference(int ref) { - if (ref == 0) { - return "INVALID"; - } else if (ref == 1) { - return "TRUE"; - } else if (ref == -1) { - return "FALSE"; - } else if (ref >= Bdd.RESULT_OFFSET) { - // This is a result reference - int resultIdx = ref - Bdd.RESULT_OFFSET; - return "R" + resultIdx; - } else if (ref < 0) { - return "!" + (Math.abs(ref) - 1); + writer.write("Error: "); + writer.write(((ErrorRule) result).getError().toString()); } else { - return String.valueOf(ref - 1); + writer.write(result.getClass().getSimpleName()); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java new file mode 100644 index 00000000000..fd415876f08 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java @@ -0,0 +1,217 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.UncheckedIOException; +import java.io.Writer; +import java.nio.charset.StandardCharsets; + +/** + * Formats BDD node structures to a stream without building the entire representation in memory. + */ +public final class BddFormatter { + + private final Writer writer; + private final int[][] nodes; + private final int rootRef; + private final int conditionCount; + private final int resultCount; + private final String indent; + + private BddFormatter(Builder builder) { + this.writer = builder.writer; + this.nodes = builder.nodes; + this.rootRef = builder.rootRef; + this.conditionCount = builder.conditionCount; + this.resultCount = builder.resultCount; + this.indent = builder.indent; + } + + /** + * Creates a builder for BddFormatter. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Formats the BDD node structure. + */ + public void format() { + try { + // Calculate formatting widths + FormatContext ctx = calculateFormatContext(); + + // Write root + writer.write(indent); + writer.write("Root: "); + writer.write(formatReference(rootRef)); + writer.write("\n"); + + // Write nodes + writer.write(indent); + writer.write("Nodes:\n"); + + for (int i = 0; i < nodes.length; i++) { + writer.write(indent); + writer.write(" "); + writer.write(String.format("%" + ctx.indexWidth + "d", i)); + writer.write(": "); + + if (i == 0 && nodes[i][0] == -1) { + writer.write("terminal"); + } else { + formatNode(nodes[i], ctx); + } + writer.write("\n"); + } + + writer.flush(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private FormatContext calculateFormatContext() { + int maxVarIdx = -1; + + // Scan nodes to find max variable index + for (int i = 1; i < nodes.length; i++) { + int varIdx = nodes[i][0]; + if (varIdx >= 0) { + maxVarIdx = Math.max(maxVarIdx, varIdx); + } + } + + // Calculate widths + int conditionWidth = conditionCount > 0 ? String.valueOf(conditionCount - 1).length() + 1 : 2; + int resultWidth = resultCount > 0 ? String.valueOf(resultCount - 1).length() + 1 : 2; + int varWidth = Math.max(Math.max(conditionWidth, resultWidth), String.valueOf(maxVarIdx).length()); + int indexWidth = String.valueOf(nodes.length - 1).length(); + + return new FormatContext(varWidth, indexWidth); + } + + private void formatNode(int[] node, FormatContext ctx) throws IOException { + writer.write("["); + + // Variable reference + int varIdx = node[0]; + String varRef = formatVariableIndex(varIdx); + writer.write(String.format("%" + ctx.varWidth + "s", varRef)); + + // High and low references + writer.write(", "); + writer.write(String.format("%6s", formatReference(node[1]))); + writer.write(", "); + writer.write(String.format("%6s", formatReference(node[2]))); + writer.write("]"); + } + + private String formatVariableIndex(int varIdx) { + if (conditionCount > 0 && varIdx < conditionCount) { + return "C" + varIdx; + } else if (conditionCount > 0 && resultCount > 0) { + return "R" + (varIdx - conditionCount); + } else { + return String.valueOf(varIdx); + } + } + + /** + * Formats a BDD reference (node pointer) to a human-readable string. + * + * @param ref the reference to format + * @return the formatted reference string + */ + public static String formatReference(int ref) { + if (ref == 0) { + return "INVALID"; + } else if (ref == 1) { + return "TRUE"; + } else if (ref == -1) { + return "FALSE"; + } else if (ref >= Bdd.RESULT_OFFSET) { + return "R" + (ref - Bdd.RESULT_OFFSET); + } else if (ref < 0) { + return "!" + (-ref - 1); + } else { + return String.valueOf(ref - 1); + } + } + + private static class FormatContext { + final int varWidth; + final int indexWidth; + + FormatContext(int varWidth, int indexWidth) { + this.varWidth = varWidth; + this.indexWidth = indexWidth; + } + } + + /** + * Builder for BddFormatter. + */ + public static final class Builder { + private Writer writer; + private int[][] nodes; + private int rootRef; + private int conditionCount = 0; + private int resultCount = 0; + private String indent = ""; + + private Builder() {} + + public Builder writer(Writer writer) { + this.writer = writer; + return this; + } + + public Builder writer(OutputStream out) { + return writer(new OutputStreamWriter(out, StandardCharsets.UTF_8)); + } + + public Builder nodes(int[][] nodes) { + this.nodes = nodes; + return this; + } + + public Builder rootRef(int rootRef) { + this.rootRef = rootRef; + return this; + } + + public Builder conditionCount(int conditionCount) { + this.conditionCount = conditionCount; + return this; + } + + public Builder resultCount(int resultCount) { + this.resultCount = resultCount; + return this; + } + + public Builder indent(String indent) { + this.indent = indent; + return this; + } + + public BddFormatter build() { + if (writer == null) { + throw new IllegalStateException("writer is required"); + } + if (nodes == null) { + throw new IllegalStateException("nodes are required"); + } + return new BddFormatter(this); + } + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java index d2b24c04e40..8493ae0ea25 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java @@ -140,18 +140,6 @@ void testToString() { assertTrue(str.contains("nodes")); } - @Test - void testToStringBuilder() { - Bdd bdd = createSimpleBdd(); - StringBuilder sb = new StringBuilder(); - bdd.toString(sb); - - String str = sb.toString(); - assertTrue(str.contains("C0:")); // Condition index - assertTrue(str.contains("R0:")); // Result index - assertTrue(str.contains("terminal")); // Terminal node - } - @Test void testToNodeAndFromNode() { Bdd original = createSimpleBdd(); From af9cd9a4f2a28768670a81393676a3eb712601ff Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Thu, 24 Jul 2025 20:34:19 -0500 Subject: [PATCH 03/23] Add BDD validation, same as ruleset trait --- .../RuleSetAwsBuiltInValidator.java | 39 ++-- .../language/evaluation/RuleEvaluator.java | 91 ++++++-- .../language/evaluation/TestEvaluator.java | 21 +- .../language/syntax/parameters/Parameter.java | 3 +- .../EndpointTestsTraitValidator.java | 125 ++++++----- .../OperationContextParamsTraitValidator.java | 5 +- .../RuleSetAuthSchemesValidator.java | 206 ++++++++++-------- .../validators/RuleSetBuiltInValidator.java | 60 +++-- .../RuleSetParamMissingDocsValidator.java | 29 ++- .../validators/RuleSetParameterValidator.java | 104 ++++----- .../validators/RuleSetTestCaseValidator.java | 46 ++-- .../validators/RuleSetUriValidator.java | 114 ++++------ .../StaticContextParamsTraitValidator.java | 5 +- ...e.amazon.smithy.model.validation.Validator | 6 +- .../invalid/invalid-endpoint-uri.errors | 3 + .../invalid/invalid-endpoint-uri.smithy | 47 ++++ .../rulesengine/language/minimal-ruleset.json | 2 +- .../traits/errorfiles/bdd/bdd-valid.errors | 1 + .../traits/errorfiles/bdd/bdd-valid.smithy | 5 + 19 files changed, 551 insertions(+), 361 deletions(-) rename smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/{traits => validators}/EndpointTestsTraitValidator.java (64%) rename smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/{traits => validators}/OperationContextParamsTraitValidator.java (97%) rename smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/{traits => validators}/StaticContextParamsTraitValidator.java (89%) create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.smithy diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java index e0c5cdd0663..e9005a82598 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java @@ -6,7 +6,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.Optional; import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.Model; @@ -14,8 +13,8 @@ import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.aws.language.functions.AwsBuiltIns; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.traits.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.utils.SetUtils; @@ -33,36 +32,30 @@ public class RuleSetAwsBuiltInValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - events.addAll(validateRuleSetAwsBuiltIns(serviceShape, - serviceShape.expectTrait(EndpointRuleSetTrait.class) - .getEndpointRuleSet())); + + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { + EndpointRuleSetTrait trait = s.expectTrait(EndpointRuleSetTrait.class); + validateRuleSetAwsBuiltIns(events, s, trait.getEndpointRuleSet().getParameters()); + } + + for (ServiceShape s : model.getServiceShapesWithTrait(BddTrait.class)) { + validateRuleSetAwsBuiltIns(events, s, s.expectTrait(BddTrait.class).getBdd().getParameters()); } + return events; } - private List validateRuleSetAwsBuiltIns(ServiceShape serviceShape, EndpointRuleSet ruleSet) { - List events = new ArrayList<>(); - for (Parameter parameter : ruleSet.getParameters()) { + private void validateRuleSetAwsBuiltIns(List events, ServiceShape s, Iterable params) { + for (Parameter parameter : params) { if (parameter.isBuiltIn()) { - validateBuiltIn(serviceShape, parameter.getBuiltIn().get(), parameter).ifPresent(events::add); + validateBuiltIn(events, s, parameter.getBuiltIn().get(), parameter); } } - return events; } - private Optional validateBuiltIn( - ServiceShape serviceShape, - String builtInName, - FromSourceLocation source - ) { - if (ADDITIONAL_CONSIDERATION_BUILT_INS.contains(builtInName)) { - return Optional.of(danger( - serviceShape, - source, - String.format(ADDITIONAL_CONSIDERATION_MESSAGE, builtInName), - builtInName)); + private void validateBuiltIn(List events, ServiceShape s, String name, FromSourceLocation source) { + if (ADDITIONAL_CONSIDERATION_BUILT_INS.contains(name)) { + events.add(danger(s, source, String.format(ADDITIONAL_CONSIDERATION_MESSAGE, name), name)); } - return Optional.empty(); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index a913bb9a3d3..a7842c47c4a 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -21,8 +21,13 @@ import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.RuleValueVisitor; +import software.amazon.smithy.rulesengine.logic.RuleBasedConditionEvaluator; +import software.amazon.smithy.rulesengine.logic.bdd.Bdd; +import software.amazon.smithy.rulesengine.logic.bdd.BddEvaluator; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -45,6 +50,44 @@ public static Value evaluate(EndpointRuleSet ruleset, Map par return new RuleEvaluator().evaluateRuleSet(ruleset, parameterArguments); } + /** + * Initializes a new {@link RuleEvaluator} instances, and evaluates the provided BDD and parameter arguments. + * + * @param bdd The endpoint bdd. + * @param parameterArguments The rule-set parameter identifiers and values to evaluate the BDD against. + * @return The resulting value from the final matched rule. + */ + public static Value evaluate(Bdd bdd, Map parameterArguments) { + return new RuleEvaluator().evaluateBdd(bdd, parameterArguments); + } + + private Value evaluateBdd(Bdd bdd, Map parameterArguments) { + return scope.inScope(() -> { + for (Parameter parameter : bdd.getParameters()) { + parameter.getDefault().ifPresent(value -> scope.insert(parameter.getName(), value)); + } + + parameterArguments.forEach(scope::insert); + BddEvaluator evaluator = BddEvaluator.from(bdd); + Condition[] conds = bdd.getConditions().toArray(new Condition[0]); + RuleBasedConditionEvaluator conditionEvaluator = new RuleBasedConditionEvaluator(this, conds); + int result = evaluator.evaluate(conditionEvaluator); + + if (result <= 0) { + throw new RuntimeException("No BDD result matched"); + } + + Rule rule = bdd.getResults().get(result); + if (rule instanceof EndpointRule) { + return resolveEndpoint(this, ((EndpointRule) rule).getEndpoint()); + } else if (rule instanceof ErrorRule) { + return resolveError(this, ((ErrorRule) rule).getError()); + } else { + throw new RuntimeException("Invalid BDD rule result: " + rule); + } + }); + } + /** * Evaluate the provided ruleset and parameter arguments. * @@ -175,32 +218,40 @@ public Value visitTreeRule(List rules) { @Override public Value visitErrorRule(Expression error) { - return error.accept(self); + return resolveError(self, error); } @Override public Value visitEndpointRule(Endpoint endpoint) { - EndpointValue.Builder builder = EndpointValue.builder() - .sourceLocation(endpoint) - .url(endpoint.getUrl() - .accept(RuleEvaluator.this) - .expectStringValue() - .getValue()); - - for (Map.Entry entry : endpoint.getProperties().entrySet()) { - builder.putProperty(entry.getKey().toString(), entry.getValue().accept(RuleEvaluator.this)); - } - - for (Map.Entry> entry : endpoint.getHeaders().entrySet()) { - List values = new ArrayList<>(); - for (Expression expression : entry.getValue()) { - values.add(expression.accept(RuleEvaluator.this).expectStringValue().getValue()); - } - builder.putHeader(entry.getKey(), values); - } - return builder.build(); + return resolveEndpoint(self, endpoint); } }); }); } + + private static Value resolveEndpoint(RuleEvaluator self, Endpoint endpoint) { + EndpointValue.Builder builder = EndpointValue.builder() + .sourceLocation(endpoint) + .url(endpoint.getUrl() + .accept(self) + .expectStringValue() + .getValue()); + + for (Map.Entry entry : endpoint.getProperties().entrySet()) { + builder.putProperty(entry.getKey().toString(), entry.getValue().accept(self)); + } + + for (Map.Entry> entry : endpoint.getHeaders().entrySet()) { + List values = new ArrayList<>(); + for (Expression expression : entry.getValue()) { + values.add(expression.accept(self).expectStringValue().getValue()); + } + builder.putHeader(entry.getKey(), values); + } + return builder.build(); + } + + private static Value resolveError(RuleEvaluator self, Expression error) { + return error.accept(self); + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java index b80efa02c61..5cd6b13009f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java @@ -13,6 +13,7 @@ import software.amazon.smithy.rulesengine.language.evaluation.value.EndpointValue; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.logic.bdd.Bdd; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestExpectation; import software.amazon.smithy.rulesengine.traits.ExpectedEndpoint; @@ -34,12 +35,30 @@ private TestEvaluator() {} * @param testCase The test case. */ public static void evaluate(EndpointRuleSet ruleset, EndpointTestCase testCase) { + Value result = RuleEvaluator.evaluate(ruleset, createParams(testCase)); + processResult(result, testCase); + } + + /** + * Evaluate the given rule-set and test case. Throws an exception in the event the test case does not pass. + * + * @param bdd The BDD to be tested. + * @param testCase The test case. + */ + public static void evaluate(Bdd bdd, EndpointTestCase testCase) { + Value result = RuleEvaluator.evaluate(bdd, createParams(testCase)); + processResult(result, testCase); + } + + private static Map createParams(EndpointTestCase testCase) { Map parameters = new LinkedHashMap<>(); for (Map.Entry entry : testCase.getParams().getMembers().entrySet()) { parameters.put(Identifier.of(entry.getKey()), Value.fromNode(entry.getValue())); } - Value result = RuleEvaluator.evaluate(ruleset, parameters); + return parameters; + } + private static void processResult(Value result, EndpointTestCase testCase) { StringBuilder messageBuilder = new StringBuilder("while executing test case"); if (testCase.getDocumentation().isPresent()) { messageBuilder.append(" ").append(testCase.getDocumentation().get()); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java index 3dbe14dbdde..1a36efb4bcf 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/parameters/Parameter.java @@ -24,6 +24,7 @@ import software.amazon.smithy.utils.ListUtils; import software.amazon.smithy.utils.SmithyBuilder; import software.amazon.smithy.utils.SmithyUnstableApi; +import software.amazon.smithy.utils.StringUtils; import software.amazon.smithy.utils.ToSmithyBuilder; /** @@ -272,7 +273,7 @@ public Node toNode() { if (documentation != null) { node.withMember(DOCUMENTATION, documentation); } - node.withMember(TYPE, type.toString()); + node.withMember(TYPE, StringUtils.uncapitalize(type.toString())); return node.build(); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointTestsTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java similarity index 64% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointTestsTraitValidator.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java index bb8fd6f37f5..cfba5f764ab 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointTestsTraitValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java @@ -2,7 +2,7 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rulesengine.traits; +package software.amazon.smithy.rulesengine.validators; import java.util.ArrayList; import java.util.HashMap; @@ -18,8 +18,13 @@ import software.amazon.smithy.model.validation.NodeValidationVisitor; import software.amazon.smithy.model.validation.Severity; import software.amazon.smithy.model.validation.ValidationEvent; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.traits.BddTrait; +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; +import software.amazon.smithy.rulesengine.traits.EndpointTestCase; +import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput; +import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -39,61 +44,81 @@ public List validate(Model model) { operationNameMap.put(operationShape.getId().getName(), operationShape); } - // Precompute the built-ins and their default states, as this will - // be used frequently in downstream validation. - List builtInParamsWithDefaults = new ArrayList<>(); - List builtInParamsWithoutDefaults = new ArrayList<>(); - EndpointRuleSet ruleSet = serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); - for (Parameter parameter : ruleSet.getParameters()) { - if (parameter.isBuiltIn()) { - if (parameter.getDefault().isPresent()) { - builtInParamsWithDefaults.add(parameter); - } else { - builtInParamsWithoutDefaults.add(parameter); - } + serviceShape.getTrait(EndpointRuleSetTrait.class).ifPresent(trait -> { + validateEndpointRuleSet(events, + model, + serviceShape, + trait.getEndpointRuleSet().getParameters(), + operationNameMap); + }); + + serviceShape.getTrait(BddTrait.class).ifPresent(trait -> { + validateEndpointRuleSet(events, model, serviceShape, trait.getBdd().getParameters(), operationNameMap); + }); + } + + return events; + } + + private void validateEndpointRuleSet( + List events, + Model model, + ServiceShape serviceShape, + Parameters parameters, + Map operationNameMap + ) { + // Precompute the built-ins and their default states, as this will + // be used frequently in downstream validation. + List builtInParamsWithDefaults = new ArrayList<>(); + List builtInParamsWithoutDefaults = new ArrayList<>(); + + for (Parameter parameter : parameters) { + if (parameter.isBuiltIn()) { + if (parameter.getDefault().isPresent()) { + builtInParamsWithDefaults.add(parameter); + } else { + builtInParamsWithoutDefaults.add(parameter); } } + } - for (EndpointTestCase testCase : serviceShape.expectTrait(EndpointTestsTrait.class).getTestCases()) { - // If values for built-in parameters don't match the default, they MUST - // be specified in the operation inputs. Precompute the ones that don't match - // and capture their value. - Map builtInParamsWithNonDefaultValues = - getBuiltInParamsWithNonDefaultValues(builtInParamsWithDefaults, testCase); - - for (EndpointTestOperationInput testOperationInput : testCase.getOperationInputs()) { - String operationName = testOperationInput.getOperationName(); - - // It's possible for an operation defined to not be in the service closure. - if (!operationNameMap.containsKey(operationName)) { - events.add(error(serviceShape, - testOperationInput, - String.format("Test case operation `%s` does not exist in service `%s`", - operationName, - serviceShape.getId()))); - continue; - } - - // Still emit events if the operation exists, but was just not bound. - validateConfiguredBuiltInValues(serviceShape, - builtInParamsWithNonDefaultValues, - testOperationInput, - events); - validateBuiltInsWithoutDefaultsHaveValues(serviceShape, - builtInParamsWithoutDefaults, - testCase, - testOperationInput, - events); + for (EndpointTestCase testCase : serviceShape.expectTrait(EndpointTestsTrait.class).getTestCases()) { + // If values for built-in parameters don't match the default, they MUST + // be specified in the operation inputs. Precompute the ones that don't match + // and capture their value. + Map builtInParamsWithNonDefaultValues = + getBuiltInParamsWithNonDefaultValues(builtInParamsWithDefaults, testCase); + + for (EndpointTestOperationInput testOperationInput : testCase.getOperationInputs()) { + String operationName = testOperationInput.getOperationName(); - StructureShape inputShape = model.expectShape( - operationNameMap.get(operationName).getInputShape(), - StructureShape.class); - validateOperationInput(model, serviceShape, inputShape, testCase, testOperationInput, events); + // It's possible for an operation defined to not be in the service closure. + if (!operationNameMap.containsKey(operationName)) { + events.add(error(serviceShape, + testOperationInput, + String.format("Test case operation `%s` does not exist in service `%s`", + operationName, + serviceShape.getId()))); + continue; } + + // Still emit events if the operation exists, but was just not bound. + validateConfiguredBuiltInValues(serviceShape, + builtInParamsWithNonDefaultValues, + testOperationInput, + events); + validateBuiltInsWithoutDefaultsHaveValues(serviceShape, + builtInParamsWithoutDefaults, + testCase, + testOperationInput, + events); + + StructureShape inputShape = model.expectShape( + operationNameMap.get(operationName).getInputShape(), + StructureShape.class); + validateOperationInput(model, serviceShape, inputShape, testCase, testOperationInput, events); } } - - return events; } private Map getBuiltInParamsWithNonDefaultValues( diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/OperationContextParamsTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/OperationContextParamsTraitValidator.java similarity index 97% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/OperationContextParamsTraitValidator.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/OperationContextParamsTraitValidator.java index e06df8df172..59f1524591e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/OperationContextParamsTraitValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/OperationContextParamsTraitValidator.java @@ -2,7 +2,7 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rulesengine.traits; +package software.amazon.smithy.rulesengine.validators; import java.util.ArrayList; import java.util.Collections; @@ -36,6 +36,9 @@ import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.analysis.OperationContextParamsChecker; +import software.amazon.smithy.rulesengine.traits.ContextIndex; +import software.amazon.smithy.rulesengine.traits.OperationContextParamDefinition; +import software.amazon.smithy.rulesengine.traits.OperationContextParamsTrait; import software.amazon.smithy.utils.ListUtils; import software.amazon.smithy.utils.SmithyUnstableApi; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java index 1c678f82bd2..99fdae7b137 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java @@ -8,11 +8,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; -import java.util.function.BiFunction; -import java.util.stream.Collectors; -import java.util.stream.Stream; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.ServiceShape; @@ -20,11 +16,13 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; -import software.amazon.smithy.rulesengine.language.TraversingVisitor; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.traits.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; -import software.amazon.smithy.utils.ListUtils; /** * Validator which verifies an endpoint with an authSchemes property conforms to a strict schema. @@ -35,111 +33,141 @@ public final class RuleSetAuthSchemesValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - Validator validator = new Validator(serviceShape); - events.addAll(validator.visitRuleset( - serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet()) - .collect(Collectors.toList())); + for (ServiceShape serviceShape : model.getServiceShapes()) { + visitRuleset(events, serviceShape, serviceShape.getTrait(EndpointRuleSetTrait.class).orElse(null)); + visitBdd(events, serviceShape, serviceShape.getTrait(BddTrait.class).orElse(null)); } return events; } - private class Validator extends TraversingVisitor { - private final ServiceShape serviceShape; - - Validator(ServiceShape serviceShape) { - this.serviceShape = serviceShape; + private void visitRuleset(List events, ServiceShape serviceShape, EndpointRuleSetTrait trait) { + if (trait != null) { + for (Rule rule : trait.getEndpointRuleSet().getRules()) { + traverse(events, serviceShape, rule); + } } + } - @Override - public Stream visitEndpoint(Endpoint endpoint) { - List events = new ArrayList<>(); - - Literal authSchemes = endpoint.getProperties().get(Identifier.of("authSchemes")); - if (authSchemes != null) { - BiFunction emitter = getEventEmitter(); - Optional> authSchemeList = authSchemes.asTupleLiteral(); - if (!authSchemeList.isPresent()) { - return Stream.of(emitter.apply(authSchemes, - String.format("Expected `authSchemes` to be a list, found: `%s`", authSchemes))); + private void visitBdd(List events, ServiceShape serviceShape, BddTrait trait) { + if (trait != null) { + for (Rule result : trait.getBdd().getResults()) { + if (result instanceof EndpointRule) { + visitEndpoint(events, serviceShape, (EndpointRule) result); } + } + } + } - Set authSchemeNames = new HashSet<>(); - Set duplicateAuthSchemeNames = new HashSet<>(); - for (Literal authSchemeEntry : authSchemeList.get()) { - Optional> authSchemeMap = authSchemeEntry.asRecordLiteral(); - if (authSchemeMap.isPresent()) { - // Validate the name property so that we can also check that they're unique. - Map authScheme = authSchemeMap.get(); - Optional event = validateAuthSchemeName(authScheme, authSchemeEntry); - if (event.isPresent()) { - events.add(event.get()); - continue; - } - String schemeName = authScheme.get(NAME).asStringLiteral().get().expectLiteral(); - if (!authSchemeNames.add(schemeName)) { - duplicateAuthSchemeNames.add(schemeName); - } - - events.addAll(validateAuthScheme(schemeName, authScheme, authSchemeEntry)); - } else { - events.add(emitter.apply(authSchemes, - String.format("Expected `authSchemes` to be a list of objects, but found: `%s`", - authSchemeEntry))); - } + private void traverse(List events, ServiceShape service, Rule rule) { + if (rule instanceof EndpointRule) { + visitEndpoint(events, service, (EndpointRule) rule); + } else if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (Rule child : treeRule.getRules()) { + traverse(events, service, child); + } + } + } + + private void visitEndpoint(List events, ServiceShape service, EndpointRule endpointRule) { + Endpoint endpoint = endpointRule.getEndpoint(); + Literal authSchemes = endpoint.getProperties().get(Identifier.of("authSchemes")); + + if (authSchemes != null) { + List authSchemeList = authSchemes.asTupleLiteral().orElse(null); + if (authSchemeList == null) { + events.add(error(service, + authSchemes, + String.format( + "Expected `authSchemes` to be a list, found: `%s`", + authSchemes))); + return; + } + + Set authSchemeNames = new HashSet<>(); + Set duplicateAuthSchemeNames = new HashSet<>(); + for (Literal authSchemeEntry : authSchemeList) { + Map authSchemeMap = authSchemeEntry.asRecordLiteral().orElse(null); + if (authSchemeMap == null) { + events.add(error(service, + authSchemes, + String.format( + "Expected `authSchemes` to be a list of objects, but found: `%s`", + authSchemeEntry))); + continue; } - // Emit events for each duplicated auth scheme name. - for (String duplicateAuthSchemeName : duplicateAuthSchemeNames) { - events.add(emitter.apply(authSchemes, - String.format("Found duplicate `name` of `%s` in the " - + "`authSchemes` list", duplicateAuthSchemeName))); + String schemeName = validateAuthSchemeName(events, service, authSchemeMap, authSchemeEntry); + if (schemeName != null) { + if (!authSchemeNames.add(schemeName)) { + duplicateAuthSchemeNames.add(schemeName); + } + validateAuthScheme(events, service, schemeName, authSchemeMap, authSchemeEntry); } } - return events.stream(); - } - - private Optional validateAuthSchemeName( - Map authScheme, - FromSourceLocation sourceLocation - ) { - if (!authScheme.containsKey(NAME) || !authScheme.get(NAME).asStringLiteral().isPresent()) { - return Optional.of(error(serviceShape, - sourceLocation, - String.format("Expected `authSchemes` to have a `name` key with a string value but it did not: " - + "`%s`", authScheme))); + // Emit events for each duplicated auth scheme name. + for (String duplicateAuthSchemeName : duplicateAuthSchemeNames) { + events.add(error(service, + authSchemes, + String.format( + "Found duplicate `name` of `%s` in the `authSchemes` list", + duplicateAuthSchemeName))); } - return Optional.empty(); } + } - private List validateAuthScheme( - String schemeName, - Map authScheme, - FromSourceLocation sourceLocation - ) { - List events = new ArrayList<>(); + private String validateAuthSchemeName( + List events, + ServiceShape service, + Map authScheme, + FromSourceLocation sourceLocation + ) { + Literal nameLiteral = authScheme.get(NAME); + if (nameLiteral == null) { + events.add(error(service, + sourceLocation, + String.format( + "Expected `authSchemes` to have a `name` key with a string value but it did not: `%s`", + authScheme))); + return null; + } - BiFunction emitter = getEventEmitter(); + String name = nameLiteral.asStringLiteral().map(s -> s.expectLiteral()).orElse(null); + if (name == null) { + events.add(error(service, + sourceLocation, + String.format( + "Expected `authSchemes` to have a `name` key with a string value but it did not: `%s`", + authScheme))); + return null; + } - boolean validatedAuth = false; - for (AuthSchemeValidator authSchemeValidator : EndpointRuleSet.getAuthSchemeValidators()) { - if (authSchemeValidator.test(schemeName)) { - events.addAll(authSchemeValidator.validateScheme(authScheme, sourceLocation, emitter)); - validatedAuth = true; - } - } + return name; + } - if (validatedAuth) { - return events; + private void validateAuthScheme( + List events, + ServiceShape service, + String schemeName, + Map authScheme, + FromSourceLocation sourceLocation + ) { + boolean validatedAuth = false; + for (AuthSchemeValidator authSchemeValidator : EndpointRuleSet.getAuthSchemeValidators()) { + if (authSchemeValidator.test(schemeName)) { + events.addAll(authSchemeValidator.validateScheme(authScheme, + sourceLocation, + (location, message) -> error(service, location, message))); + validatedAuth = true; } - return ListUtils.of(warning(serviceShape, - String.format("Did not find a validator for the `%s` " - + "auth scheme", schemeName))); } - private BiFunction getEventEmitter() { - return (sourceLocation, message) -> error(serviceShape, sourceLocation, message); + if (!validatedAuth) { + events.add(warning(service, + String.format( + "Did not find a validator for the `%s` auth scheme", + schemeName))); } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java index 430412c1d1f..12a5f7d473e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java @@ -7,7 +7,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.Optional; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.node.StringNode; @@ -16,6 +15,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.traits.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput; @@ -28,67 +28,63 @@ public final class RuleSetBuiltInValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - events.addAll(validateRuleSetBuiltIns(serviceShape, - serviceShape.expectTrait(EndpointRuleSetTrait.class) - .getEndpointRuleSet())); + + for (ServiceShape s : model.getServiceShapesWithTrait(BddTrait.class)) { + validateParams(events, s, s.expectTrait(BddTrait.class).getBdd().getParameters()); + } + + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { + validateParams(events, s, s.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet().getParameters()); } - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointTestsTrait.class)) { - events.addAll(validateTestTraitBuiltIns(serviceShape, serviceShape.expectTrait(EndpointTestsTrait.class))); + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointTestsTrait.class)) { + validateTestBuiltIns(events, s, s.expectTrait(EndpointTestsTrait.class)); } + return events; } - private List validateRuleSetBuiltIns(ServiceShape serviceShape, EndpointRuleSet ruleSet) { - List events = new ArrayList<>(); - for (Parameter parameter : ruleSet.getParameters()) { + private void validateParams(List events, ServiceShape service, Iterable params) { + for (Parameter parameter : params) { if (parameter.isBuiltIn()) { - validateBuiltIn(serviceShape, parameter.getBuiltIn().get(), parameter, "RuleSet") - .ifPresent(events::add); + validateBuiltIn(events, service, parameter.getBuiltIn().get(), parameter, "RuleSet"); } } - return events; } - private List validateTestTraitBuiltIns(ServiceShape serviceShape, EndpointTestsTrait testSuite) { - List events = new ArrayList<>(); + private void validateTestBuiltIns(List events, ServiceShape service, EndpointTestsTrait suite) { int testIndex = 0; - for (EndpointTestCase testCase : testSuite.getTestCases()) { + for (EndpointTestCase testCase : suite.getTestCases()) { int inputIndex = 0; for (EndpointTestOperationInput operationInput : testCase.getOperationInputs()) { for (StringNode builtInNode : operationInput.getBuiltInParams().getMembers().keySet()) { - validateBuiltIn(serviceShape, + validateBuiltIn(events, + service, builtInNode.getValue(), operationInput, "TestCase", String.valueOf(testIndex), "Inputs", - String.valueOf(inputIndex)) - .ifPresent(events::add); + String.valueOf(inputIndex)); } inputIndex++; } testIndex++; } - return events; } - private Optional validateBuiltIn( - ServiceShape serviceShape, - String builtInName, + private void validateBuiltIn( + List events, + ServiceShape service, + String name, FromSourceLocation source, String... eventIdSuffixes ) { - if (!EndpointRuleSet.hasBuiltIn(builtInName)) { - return Optional.of(error(serviceShape, - source, - String.format( - "The `%s` built-in used is not registered, valid built-ins: %s", - builtInName, - EndpointRuleSet.getKeyString()), - String.join(".", Arrays.asList(eventIdSuffixes)))); + if (!EndpointRuleSet.hasBuiltIn(name)) { + String msg = String.format("The `%s` built-in used is not registered, valid built-ins: %s", + name, + EndpointRuleSet.getKeyString()); + events.add(error(service, source, msg, String.join(".", Arrays.asList(eventIdSuffixes)))); } - return Optional.empty(); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java index 5263f27e7aa..1022925078a 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java @@ -10,34 +10,43 @@ import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.traits.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; /** - * Validator to ensure that all parameters have documentation. + * Validator to ensure that all parameters have documentation (in BDD and ruleset). */ public final class RuleSetParamMissingDocsValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - events.addAll(validateRuleSet(serviceShape, - serviceShape.expectTrait(EndpointRuleSetTrait.class) - .getEndpointRuleSet())); + for (ServiceShape serviceShape : model.getServiceShapes()) { + visitRuleset(events, serviceShape, serviceShape.getTrait(EndpointRuleSetTrait.class).orElse(null)); + visitBdd(events, serviceShape, serviceShape.getTrait(BddTrait.class).orElse(null)); } return events; } - public List validateRuleSet(ServiceShape serviceShape, EndpointRuleSet ruleSet) { - List events = new ArrayList<>(); - for (Parameter parameter : ruleSet.getParameters()) { + private void visitRuleset(List events, ServiceShape serviceShape, EndpointRuleSetTrait trait) { + if (trait != null) { + visitParams(events, serviceShape, trait.getEndpointRuleSet().getParameters()); + } + } + + private void visitBdd(List events, ServiceShape serviceShape, BddTrait trait) { + if (trait != null) { + visitParams(events, serviceShape, trait.getBdd().getParameters()); + } + } + + public void visitParams(List events, ServiceShape serviceShape, Iterable parameters) { + for (Parameter parameter : parameters) { if (!parameter.getDocumentation().isPresent()) { events.add(warning(serviceShape, parameter, String.format("Parameter `%s` does not have documentation", parameter.getName()))); } } - return events; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java index 54d445bac93..8eeaa16dbb2 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java @@ -9,7 +9,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import software.amazon.smithy.model.FromSourceLocation; import software.amazon.smithy.model.Model; @@ -23,22 +22,21 @@ import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.analysis.OperationContextParamsChecker; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.traits.BddTrait; import software.amazon.smithy.rulesengine.traits.ClientContextParamDefinition; import software.amazon.smithy.rulesengine.traits.ClientContextParamsTrait; import software.amazon.smithy.rulesengine.traits.ContextParamTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; +import software.amazon.smithy.rulesengine.traits.OperationContextParamDefinition; import software.amazon.smithy.rulesengine.traits.OperationContextParamsTrait; import software.amazon.smithy.rulesengine.traits.StaticContextParamDefinition; import software.amazon.smithy.rulesengine.traits.StaticContextParamsTrait; import software.amazon.smithy.utils.ListUtils; -import software.amazon.smithy.utils.Pair; /** * Validator for rule-set parameters. @@ -47,45 +45,52 @@ public final class RuleSetParameterValidator extends AbstractValidator { @Override public List validate(Model model) { TopDownIndex topDownIndex = TopDownIndex.of(model); - List errors = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - // Pull all the parameters used in this service related to endpoints, validating that - // they are of matching types across the traits that can define them. - Pair, Map> errorsParamsPair = validateAndExtractParameters( - model, - serviceShape, - topDownIndex.getContainedOperations(serviceShape)); - errors.addAll(errorsParamsPair.getLeft()); - - // Make sure parameters align across Params <-> RuleSet transitions. - EndpointRuleSet ruleSet = serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); - errors.addAll(validateParametersMatching(serviceShape, - ruleSet.getParameters(), - errorsParamsPair.getRight())); - // Check that tests declare required parameters, only defined parameters, etc. - if (serviceShape.hasTrait(EndpointTestsTrait.ID)) { - errors.addAll(validateTestsParameters( - serviceShape, - serviceShape.expectTrait(EndpointTestsTrait.class), - ruleSet)); + for (ServiceShape service : model.getServiceShapes()) { + EndpointRuleSetTrait epTrait = service.getTrait(EndpointRuleSetTrait.class).orElse(null); + BddTrait bddTrait = service.getTrait(BddTrait.class).orElse(null); + if (epTrait != null) { + validate(model, topDownIndex, service, errors, epTrait, epTrait.getEndpointRuleSet().getParameters()); + } + if (bddTrait != null) { + validate(model, topDownIndex, service, errors, bddTrait, bddTrait.getBdd().getParameters()); } } return errors; } - private Pair, Map> validateAndExtractParameters( + private void validate( Model model, - ServiceShape serviceShape, + TopDownIndex topDownIndex, + ServiceShape service, + List errors, + FromSourceLocation sourceLocation, + Iterable parameters + ) { + // Pull all the parameters used in this service related to endpoints, validating that + // they are of matching types across the traits that can define them. + Set operations = topDownIndex.getContainedOperations(service); + Map modelParams = validateAndExtractParameters(errors, model, service, operations); + // Make sure parameters align across Params <-> RuleSet transitions. + validateParametersMatching(errors, service, sourceLocation, parameters, modelParams); + // Check that tests declare required parameters, only defined parameters, etc. + if (service.hasTrait(EndpointTestsTrait.ID)) { + validateTestsParameters(errors, service, service.expectTrait(EndpointTestsTrait.class), parameters); + } + } + + private Map validateAndExtractParameters( + List errors, + Model model, + ServiceShape service, Set containedOperations ) { - List errors = new ArrayList<>(); Map endpointParams = new HashMap<>(); - if (serviceShape.hasTrait(ClientContextParamsTrait.ID)) { - ClientContextParamsTrait trait = serviceShape.expectTrait(ClientContextParamsTrait.class); + if (service.hasTrait(ClientContextParamsTrait.ID)) { + ClientContextParamsTrait trait = service.expectTrait(ClientContextParamsTrait.class); for (Map.Entry entry : trait.getParameters().entrySet()) { endpointParams.put(entry.getKey(), Parameter.builder() @@ -120,10 +125,14 @@ private Pair, Map> validateAndExtractPa if (operationShape.hasTrait(OperationContextParamsTrait.ID)) { OperationContextParamsTrait trait = operationShape.expectTrait(OperationContextParamsTrait.class); - trait.getParameters().forEach((name, p) -> { - Optional maybeType = OperationContextParamsChecker - .inferParameterType(p, operationShape, model); - maybeType.ifPresent(parameterType -> { + for (Map.Entry entry : trait.getParameters().entrySet()) { + String name = entry.getKey(); + OperationContextParamDefinition p = entry.getValue(); + ParameterType parameterType = OperationContextParamsChecker + .inferParameterType(p, operationShape, model) + .orElse(null); + + if (parameterType != null) { if (endpointParams.containsKey(name) && endpointParams.get(name).getType() != parameterType) { errors.add(parameterError(operationShape, trait, @@ -136,8 +145,8 @@ private Pair, Map> validateAndExtractPa .type(parameterType) .build()); } - }); - }); + } + } } StructureShape input = model.expectShape(operationShape.getInputShape(), StructureShape.class); @@ -173,15 +182,16 @@ private Pair, Map> validateAndExtractPa } } - return Pair.of(errors, endpointParams); + return endpointParams; } - private List validateParametersMatching( + private void validateParametersMatching( + List errors, ServiceShape serviceShape, - Parameters ruleSetParams, + FromSourceLocation sourceLocation, + Iterable ruleSetParams, Map modelParams ) { - List errors = new ArrayList<>(); Set matchedParams = new HashSet<>(); for (Parameter parameter : ruleSetParams) { String name = parameter.getName().toString(); @@ -213,27 +223,25 @@ private List validateParametersMatching( for (Map.Entry entry : modelParams.entrySet()) { if (!matchedParams.contains(entry.getKey())) { errors.add(parameterError(serviceShape, - serviceShape.expectTrait(EndpointRuleSetTrait.class), + sourceLocation, "RuleSet.UnmatchedName", String.format("Parameter `%s` exists in service model but not in ruleset, existing params: %s", entry.getKey(), matchedParams))); } } - - return errors; } - private List validateTestsParameters( + private void validateTestsParameters( + List errors, ServiceShape serviceShape, EndpointTestsTrait trait, - EndpointRuleSet ruleSet + Iterable parameters ) { - List errors = new ArrayList<>(); Set rulesetParamNames = new HashSet<>(); Map> testSuiteParams = extractTestSuiteParameters(trait.getTestCases()); - for (Parameter parameter : ruleSet.getParameters()) { + for (Parameter parameter : parameters) { String name = parameter.getName().toString(); rulesetParamNames.add(name); boolean testSuiteHasParam = testSuiteParams.containsKey(name); @@ -278,8 +286,6 @@ private List validateTestsParameters( String.format("Test parameter `%s` is not defined in ruleset", entry.getKey()))); } } - - return errors; } private Map> extractTestSuiteParameters(List testCases) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java index bfccecaf19c..157da6199c8 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java @@ -12,6 +12,8 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; +import software.amazon.smithy.rulesengine.logic.bdd.Bdd; +import software.amazon.smithy.rulesengine.traits.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; @@ -23,22 +25,38 @@ public class RuleSetTestCaseValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - if (serviceShape.hasTrait(EndpointTestsTrait.ID)) { - EndpointRuleSet ruleSet = serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); - EndpointTestsTrait testsTrait = serviceShape.expectTrait(EndpointTestsTrait.class); - - // Test/Rule evaluation throws RuntimeExceptions when evaluating, wrap these - // up into ValidationEvents for automatic validation. - for (EndpointTestCase endpointTestCase : testsTrait.getTestCases()) { - try { - TestEvaluator.evaluate(ruleSet, endpointTestCase); - } catch (RuntimeException e) { - events.add(error(serviceShape, endpointTestCase, e.getMessage())); - } - } + for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointTestsTrait.class)) { + EndpointTestsTrait testsTrait = serviceShape.expectTrait(EndpointTestsTrait.class); + if (serviceShape.hasTrait(EndpointRuleSetTrait.class)) { + validate(serviceShape, testsTrait, events); + } else if (serviceShape.hasTrait(BddTrait.class)) { + validateBdd(serviceShape, testsTrait, events); } } return events; } + + // Test/Rule evaluation throws RuntimeExceptions when evaluating, wrap these + // up into ValidationEvents for automatic validation. + private void validate(ServiceShape serviceShape, EndpointTestsTrait testsTrait, List events) { + EndpointRuleSet ruleSet = serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); + for (EndpointTestCase endpointTestCase : testsTrait.getTestCases()) { + try { + TestEvaluator.evaluate(ruleSet, endpointTestCase); + } catch (RuntimeException e) { + events.add(error(serviceShape, endpointTestCase, e.getMessage())); + } + } + } + + private void validateBdd(ServiceShape serviceShape, EndpointTestsTrait testsTrait, List events) { + Bdd bdd = serviceShape.expectTrait(BddTrait.class).getBdd(); + for (EndpointTestCase endpointTestCase : testsTrait.getTestCases()) { + try { + TestEvaluator.evaluate(bdd, endpointTestCase); + } catch (RuntimeException e) { + events.add(error(serviceShape, endpointTestCase, e.getMessage())); + } + } + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java index c78a3a5b04a..12d0ef6a3c7 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java @@ -6,22 +6,19 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; -import java.util.stream.Stream; import software.amazon.smithy.model.Model; import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.Endpoint; -import software.amazon.smithy.rulesengine.language.TraversingVisitor; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.LiteralVisitor; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.traits.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; -import software.amazon.smithy.utils.OptionalUtils; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -32,75 +29,60 @@ public final class RuleSetUriValidator extends AbstractValidator { @Override public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape serviceShape : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { - events.addAll(new UriSchemeVisitor(serviceShape) - .visitRuleset(serviceShape.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet()) - .collect(Collectors.toList())); + for (ServiceShape serviceShape : model.getServiceShapes()) { + visitRuleset(events, serviceShape, serviceShape.getTrait(EndpointRuleSetTrait.class).orElse(null)); + visitBdd(events, serviceShape, serviceShape.getTrait(BddTrait.class).orElse(null)); } return events; } - private final class UriSchemeVisitor extends TraversingVisitor { - private final ServiceShape serviceShape; - private boolean checkingEndpoint = false; - - UriSchemeVisitor(ServiceShape serviceShape) { - this.serviceShape = serviceShape; - } - - @Override - public Stream visitEndpoint(Endpoint endpoint) { - checkingEndpoint = true; - Stream errors = endpoint.getUrl().accept(this); - checkingEndpoint = false; - return errors; + private void visitRuleset(List events, ServiceShape serviceShape, EndpointRuleSetTrait trait) { + if (trait != null) { + for (Rule rule : trait.getEndpointRuleSet().getRules()) { + traverse(events, serviceShape, rule); + } } + } - @Override - public Stream visitLiteral(Literal literal) { - return literal.accept(new LiteralVisitor>() { - @Override - public Stream visitBoolean(boolean b) { - return Stream.empty(); - } - - @Override - public Stream visitString(Template value) { - return OptionalUtils.stream(validateTemplate(value)); - } - - @Override - public Stream visitRecord(Map members) { - return Stream.empty(); + private void visitBdd(List events, ServiceShape serviceShape, BddTrait trait) { + if (trait != null) { + for (Rule result : trait.getBdd().getResults()) { + if (result instanceof EndpointRule) { + visitEndpoint(events, serviceShape, (EndpointRule) result); } + } + } + } - @Override - public Stream visitTuple(List members) { - return Stream.empty(); - } + private void traverse(List events, ServiceShape service, Rule rule) { + if (rule instanceof EndpointRule) { + visitEndpoint(events, service, (EndpointRule) rule); + } else if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (Rule child : treeRule.getRules()) { + traverse(events, service, child); + } + } + } - @Override - public Stream visitInteger(int value) { - return Stream.empty(); - } - }); + private void visitEndpoint(List events, ServiceShape serviceShape, EndpointRule endpointRule) { + Endpoint endpoint = endpointRule.getEndpoint(); + Expression url = endpoint.getUrl(); + if (url instanceof StringLiteral) { + StringLiteral s = (StringLiteral) url; + visitTemplate(events, serviceShape, s.value()); } + } - private Optional validateTemplate(Template template) { - if (checkingEndpoint) { - Template.Part part = template.getParts().get(0); - if (part instanceof Template.Literal) { - String scheme = ((Template.Literal) part).getValue(); - if (!(scheme.startsWith("http://") || scheme.startsWith("https://"))) { - return Optional.of(error(serviceShape, - template, - "URI should start with `http://` or `https://` but the URI started with " - + scheme)); - } - } - // Allow dynamic URIs for now — we should lint that at looks like a scheme at some point + private void visitTemplate(List events, ServiceShape serviceShape, Template template) { + Template.Part part = template.getParts().get(0); + if (part instanceof Template.Literal) { + String scheme = ((Template.Literal) part).getValue(); + if (!(scheme.startsWith("http://") || scheme.startsWith("https://"))) { + events.add(error(serviceShape, + template, + "URI should start with `http://` or `https://` but the URI started with " + scheme)); } - return Optional.empty(); } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/StaticContextParamsTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/StaticContextParamsTraitValidator.java similarity index 89% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/StaticContextParamsTraitValidator.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/StaticContextParamsTraitValidator.java index de3a344c261..8b01f796cc4 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/StaticContextParamsTraitValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/StaticContextParamsTraitValidator.java @@ -2,7 +2,7 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rulesengine.traits; +package software.amazon.smithy.rulesengine.validators; import java.util.ArrayList; import java.util.Collections; @@ -13,6 +13,9 @@ import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; +import software.amazon.smithy.rulesengine.traits.ContextIndex; +import software.amazon.smithy.rulesengine.traits.StaticContextParamDefinition; +import software.amazon.smithy.rulesengine.traits.StaticContextParamsTrait; import software.amazon.smithy.utils.SmithyUnstableApi; /** diff --git a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator index fa5f972e343..d1243825d5b 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator +++ b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator @@ -1,6 +1,6 @@ -software.amazon.smithy.rulesengine.traits.EndpointTestsTraitValidator -software.amazon.smithy.rulesengine.traits.StaticContextParamsTraitValidator -software.amazon.smithy.rulesengine.traits.OperationContextParamsTraitValidator +software.amazon.smithy.rulesengine.validators.EndpointTestsTraitValidator +software.amazon.smithy.rulesengine.validators.StaticContextParamsTraitValidator +software.amazon.smithy.rulesengine.validators.OperationContextParamsTraitValidator software.amazon.smithy.rulesengine.validators.RuleSetAuthSchemesValidator software.amazon.smithy.rulesengine.validators.RuleSetBuiltInValidator software.amazon.smithy.rulesengine.validators.RuleSetUriValidator diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.errors new file mode 100644 index 00000000000..95d19a87d72 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.errors @@ -0,0 +1,3 @@ +[WARNING] example#FizzBuzz: This shape applies a trait that is unstable: smithy.rules#clientContextParams | UnstableTrait.smithy.rules#clientContextParams +[WARNING] example#FizzBuzz: This shape applies a trait that is unstable: smithy.rules#endpointRuleSet | UnstableTrait.smithy.rules#endpointRuleSet +[ERROR] example#FizzBuzz: URI should start with `http://` or `https://` but the URI started with foo://example.com/ | RuleSetUri diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.smithy new file mode 100644 index 00000000000..7df32ec370e --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/invalid-endpoint-uri.smithy @@ -0,0 +1,47 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet + +@clientContextParams( + bar: {type: "string", documentation: "a client string parameter"} +) +@endpointRuleSet({ + version: "1.0", + parameters: { + bar: { + type: "string", + documentation: "docs" + } + }, + rules: [ + { + "documentation": "lorem ipsum dolor", + "conditions": [ + { + "fn": "isSet", + "argv": [ + { + "ref": "bar" + } + ] + } + ], + "type": "endpoint", + "endpoint": { + "url": "foo://example.com/" + } + }, + { + "conditions": [], + "documentation": "error fallthrough", + "error": "endpoint error", + "type": "error" + } + ] +}) +service FizzBuzz { + version: "2022-01-01" +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json index eda21ee8440..24ed8bc5eb1 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json @@ -4,7 +4,7 @@ "Region": { "builtIn": "AWS::Region", "required": true, - "type": "String" + "type": "string" } }, "rules": [ diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors index d55caf64ef8..b72ae3c7545 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors @@ -1 +1,2 @@ [WARNING] smithy.example#ValidBddService: This shape applies a trait that is unstable: smithy.rules#bdd | UnstableTrait.smithy.rules#bdd +[WARNING] smithy.example#ValidBddService: This shape applies a trait that is unstable: smithy.rules#clientContextParams | UnstableTrait.smithy.rules#clientContextParams diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy index 637a8bd4a6b..19b79f90b41 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy @@ -3,7 +3,12 @@ $version: "2.0" namespace smithy.example use smithy.rules#bdd +use smithy.rules#clientContextParams +@clientContextParams( + Region: {type: "string", documentation: "docs"} + UseFips: {type: "boolean", documentation: "docs"} +) @bdd({ "parameters": { "Region": { From 1c05a2f4780a489d104cf06c24c92fccd48fb1c6 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Sat, 26 Jul 2025 22:59:22 -0500 Subject: [PATCH 04/23] Hide node layout of BDD and move domain to trait Rather than have the Bdd class contain Condition, Results, Parameters, etc, it now just deals with nodes. It also now hides the implementation detail of how the BDD nodes are laid out internally. BDD evaluation is internalized to the BDD as well rather than a separate BddEvaluator. This change provides faster evaluation, makes it possible to change the internal node data layout if necessary, and cleans up all the interacts we had with BddTrait (no need to always reach into Bdd). --- .../RuleSetAwsBuiltInValidator.java | 4 +- .../language/evaluation/RuleEvaluator.java | 43 +- .../language/evaluation/TestEvaluator.java | 8 +- .../smithy/rulesengine/logic/bdd/Bdd.java | 332 +++++++--------- .../rulesengine/logic/bdd/BddBuilder.java | 234 +++++++---- .../rulesengine/logic/bdd/BddCompiler.java | 6 +- .../logic/bdd/BddEquivalenceChecker.java | 38 +- .../rulesengine/logic/bdd/BddEvaluator.java | 86 ---- .../rulesengine/logic/bdd/BddFormatter.java | 212 +++++----- .../logic/bdd/BddNodeConsumer.java | 19 + .../rulesengine/logic/bdd/BddNodeHelpers.java | 150 ------- .../rulesengine/logic/bdd/BddTrait.java | 375 ++++++++++++++++++ .../rulesengine/logic/bdd/NodeReversal.java | 46 +-- .../logic/bdd/SiftingOptimization.java | 40 +- .../smithy/rulesengine/traits/BddTrait.java | 77 ---- .../validators/BddTraitValidator.java | 59 +-- .../EndpointTestsTraitValidator.java | 4 +- .../RuleSetAuthSchemesValidator.java | 4 +- .../validators/RuleSetBuiltInValidator.java | 4 +- .../RuleSetParamMissingDocsValidator.java | 4 +- .../validators/RuleSetParameterValidator.java | 4 +- .../validators/RuleSetTestCaseValidator.java | 7 +- .../validators/RuleSetUriValidator.java | 4 +- ...re.amazon.smithy.model.traits.TraitService | 2 +- .../rulesengine/logic/bdd/BddBuilderTest.java | 44 +- .../logic/bdd/BddCompilerTest.java | 55 ++- .../logic/bdd/BddEvaluatorTest.java | 215 ---------- .../smithy/rulesengine/logic/bdd/BddTest.java | 211 ++++------ .../rulesengine/logic/bdd/BddTraitTest.java | 78 ++++ .../logic/bdd/NodeReversalTest.java | 214 +++++----- .../logic/bdd/SiftingOptimizationTest.java | 118 ++++-- 31 files changed, 1351 insertions(+), 1346 deletions(-) delete mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluator.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeConsumer.java delete mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeHelpers.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java delete mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/BddTrait.java delete mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluatorTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java index e9005a82598..d59dae89833 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java @@ -14,7 +14,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.aws.language.functions.AwsBuiltIns; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; -import software.amazon.smithy.rulesengine.traits.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.utils.SetUtils; @@ -39,7 +39,7 @@ public List validate(Model model) { } for (ServiceShape s : model.getServiceShapesWithTrait(BddTrait.class)) { - validateRuleSetAwsBuiltIns(events, s, s.expectTrait(BddTrait.class).getBdd().getParameters()); + validateRuleSetAwsBuiltIns(events, s, s.expectTrait(BddTrait.class).getParameters()); } return events; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index a7842c47c4a..e22635e6f53 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -27,7 +27,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.RuleValueVisitor; import software.amazon.smithy.rulesengine.logic.RuleBasedConditionEvaluator; import software.amazon.smithy.rulesengine.logic.bdd.Bdd; -import software.amazon.smithy.rulesengine.logic.bdd.BddEvaluator; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -50,6 +50,17 @@ public static Value evaluate(EndpointRuleSet ruleset, Map par return new RuleEvaluator().evaluateRuleSet(ruleset, parameterArguments); } + /** + * Initializes a new {@link RuleEvaluator} instances, and evaluates the provided BDD and parameter arguments. + * + * @param trait The trait to evaluate. + * @param args The rule-set parameter identifiers and values to evaluate the BDD against. + * @return The resulting value from the final matched rule. + */ + public static Value evaluate(BddTrait trait, Map args) { + return evaluate(trait.getBdd(), trait.getParameters(), trait.getConditions(), trait.getResults(), args); + } + /** * Initializes a new {@link RuleEvaluator} instances, and evaluates the provided BDD and parameter arguments. * @@ -57,27 +68,39 @@ public static Value evaluate(EndpointRuleSet ruleset, Map par * @param parameterArguments The rule-set parameter identifiers and values to evaluate the BDD against. * @return The resulting value from the final matched rule. */ - public static Value evaluate(Bdd bdd, Map parameterArguments) { - return new RuleEvaluator().evaluateBdd(bdd, parameterArguments); + public static Value evaluate( + Bdd bdd, + Parameters parameters, + List conditions, + List results, + Map parameterArguments + ) { + return new RuleEvaluator().evaluateBdd(bdd, parameters, conditions, results, parameterArguments); } - private Value evaluateBdd(Bdd bdd, Map parameterArguments) { + private Value evaluateBdd( + Bdd bdd, + Parameters parameters, + List conditions, + List results, + Map parameterArguments + ) { return scope.inScope(() -> { - for (Parameter parameter : bdd.getParameters()) { + for (Parameter parameter : parameters) { parameter.getDefault().ifPresent(value -> scope.insert(parameter.getName(), value)); } parameterArguments.forEach(scope::insert); - BddEvaluator evaluator = BddEvaluator.from(bdd); - Condition[] conds = bdd.getConditions().toArray(new Condition[0]); + + Condition[] conds = conditions.toArray(new Condition[0]); RuleBasedConditionEvaluator conditionEvaluator = new RuleBasedConditionEvaluator(this, conds); - int result = evaluator.evaluate(conditionEvaluator); + int result = bdd.evaluate(conditionEvaluator); - if (result <= 0) { + if (result < 0) { throw new RuntimeException("No BDD result matched"); } - Rule rule = bdd.getResults().get(result); + Rule rule = results.get(result); if (rule instanceof EndpointRule) { return resolveEndpoint(this, ((EndpointRule) rule).getEndpoint()); } else if (rule instanceof ErrorRule) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java index 5cd6b13009f..50cefa59276 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java @@ -13,7 +13,7 @@ import software.amazon.smithy.rulesengine.language.evaluation.value.EndpointValue; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.rulesengine.language.syntax.Identifier; -import software.amazon.smithy.rulesengine.logic.bdd.Bdd; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestExpectation; import software.amazon.smithy.rulesengine.traits.ExpectedEndpoint; @@ -40,12 +40,12 @@ public static void evaluate(EndpointRuleSet ruleset, EndpointTestCase testCase) } /** - * Evaluate the given rule-set and test case. Throws an exception in the event the test case does not pass. + * Evaluate the given BDD and test case. Throws an exception in the event the test case does not pass. * - * @param bdd The BDD to be tested. + * @param bdd The BDD trait to be tested. * @param testCase The test case. */ - public static void evaluate(Bdd bdd, EndpointTestCase testCase) { + public static void evaluate(BddTrait bdd, EndpointTestCase testCase) { Value result = RuleEvaluator.evaluate(bdd, createParams(testCase)); processResult(result, testCase); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java index 48802443e06..eab744c8cd5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java @@ -7,107 +7,110 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; +import java.io.UncheckedIOException; import java.io.Writer; import java.nio.charset.StandardCharsets; import java.util.Arrays; -import java.util.List; import java.util.Objects; -import java.util.function.Function; -import software.amazon.smithy.model.node.Node; -import software.amazon.smithy.model.node.ToNode; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import java.util.function.Consumer; +import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; /** - * Binary Decision Diagram (BDD) with complement edges for efficient endpoint rule evaluation. + * Binary Decision Diagram (BDD) with complement edges for efficient rule evaluation. * - *

A BDD provides a compact representation of decision logic where each condition is evaluated at most once along - * any path. Complement edges (negative references) enable further size reduction through node sharing. + *

This class represents a pure BDD structure without any knowledge of the specific + * conditions or results it represents. The interpretation of condition indices and + * result indices is left to the caller. * *

Reference Encoding: *

    *
  • {@code 0}: Invalid/unused reference (never appears in valid BDDs)
  • - *
  • {@code 1}: TRUE terminal; represents boolean true, treated as "no match" in endpoint resolution
  • - *
  • {@code -1}: FALSE terminal; represents boolean false, treated as "no match" in endpoint resolution
  • + *
  • {@code 1}: TRUE terminal
  • + *
  • {@code -1}: FALSE terminal
  • *
  • {@code 2, 3, ...}: Node references (points to nodes array at index ref-1)
  • - *
  • {@code -2, -3, ...}: Complement node references (logical NOT of the referenced node)
  • + *
  • {@code -2, -3, ...}: Complement node references (logical NOT)
  • *
  • {@code 100_000_000+}: Result terminals (100_000_000 + resultIndex)
  • *
- * - *

Result terminals are encoded as special references starting at 100_000_000 (RESULT_OFFSET). - * When evaluating the BDD, any reference >= 100_000_000 represents a result terminal that - * indexes into the results array (resultIndex = ref - 100_000_000). These are not stored - * as nodes in the nodes array. - * - *

Node Format: {@code [variable, high, low]} - *

    - *
  • {@code variable}: Condition index (0 to conditionCount-1)
  • - *
  • {@code high}: Reference to follow when condition evaluates to true
  • - *
  • {@code low}: Reference to follow when condition evaluates to false
  • - *
*/ -public final class Bdd implements ToNode { +public final class Bdd { /** - * Result reference encoding. - * - *

Results start at 100M to avoid collision with node references. + * Result reference encoding offset. */ public static final int RESULT_OFFSET = 100_000_000; - private final Parameters parameters; - private final List conditions; - private final List results; - private final int[][] nodes; + private final int[] variables; + private final int[] highs; + private final int[] lows; private final int rootRef; + private final int conditionCount; + private final int resultCount; /** - * Builds a BDD from an endpoint ruleset. + * Creates a BDD by streaming nodes directly into the structure. * - * @param ruleSet the ruleset to convert - * @return the constructed BDD + * @param rootRef the root reference + * @param conditionCount the number of conditions + * @param resultCount the number of results + * @param nodeCount the exact number of nodes + * @param nodeHandler a handler that will provide nodes via a consumer */ - public static Bdd from(EndpointRuleSet ruleSet) { - return from(Cfg.from(ruleSet)); - } + public Bdd(int rootRef, int conditionCount, int resultCount, int nodeCount, Consumer nodeHandler) { + this.rootRef = rootRef; + this.conditionCount = conditionCount; + this.resultCount = resultCount; - /** - * Builds a BDD from a control flow graph. - * - * @param cfg the control flow graph - * @return the constructed BDD - */ - public static Bdd from(Cfg cfg) { - return from(cfg, new BddBuilder(), ConditionOrderingStrategy.defaultOrdering()); + if (rootRef < 0 && rootRef != -1) { + throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef); + } + + InputNodeConsumer consumer = new InputNodeConsumer(nodeCount); + nodeHandler.accept(consumer); + + this.variables = consumer.variables; + this.highs = consumer.highs; + this.lows = consumer.lows; + + if (consumer.index != nodeCount) { + throw new IllegalStateException("Expected " + nodeCount + " node, but got " + consumer.index); + } } - static Bdd from(Cfg cfg, BddBuilder bddBuilder, ConditionOrderingStrategy orderingStrategy) { - return new BddCompiler(cfg, orderingStrategy, bddBuilder).compile(); + private static final class InputNodeConsumer implements BddNodeConsumer { + private int index = 0; + private final int[] variables; + private final int[] highs; + private final int[] lows; + + private InputNodeConsumer(int nodeCount) { + this.variables = new int[nodeCount]; + this.highs = new int[nodeCount]; + this.lows = new int[nodeCount]; + } + + @Override + public void accept(int var, int high, int low) { + variables[index] = var; + highs[index] = high; + lows[index] = low; + index++; + } } - public Bdd(Parameters params, List conditions, List results, int[][] nodes, int rootRef) { - this.parameters = Objects.requireNonNull(params, "params is null"); - this.conditions = conditions; - this.results = results; - this.nodes = nodes; + Bdd(int[] variables, int[] highs, int[] lows, int rootRef, int conditionCount, int resultCount) { + this.variables = Objects.requireNonNull(variables, "variables is null"); + this.highs = Objects.requireNonNull(highs, "highs is null"); + this.lows = Objects.requireNonNull(lows, "lows is null"); this.rootRef = rootRef; + this.conditionCount = conditionCount; + this.resultCount = resultCount; if (rootRef < 0 && rootRef != -1) { throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef); } - } - /** - * Gets the ordered list of conditions. - * - * @return list of conditions in evaluation order - */ - public List getConditions() { - return conditions; + if (variables.length != highs.length || variables.length != lows.length) { + throw new IllegalArgumentException("Array lengths must match"); + } } /** @@ -116,25 +119,25 @@ public List getConditions() { * @return condition count */ public int getConditionCount() { - return conditions.size(); + return conditionCount; } /** - * Gets the ordered list of results. + * Gets the number of results. * - * @return list of results (null represents no match) + * @return result count */ - public List getResults() { - return results; + public int getResultCount() { + return resultCount; } /** - * Gets the BDD nodes. + * Gets the number of nodes in the BDD. * - * @return array of node triples + * @return the node count */ - public int[][] getNodes() { - return nodes; + public int getNodeCount() { + return variables.length; } /** @@ -147,22 +150,68 @@ public int getRootRef() { } /** - * Get the input parameters of the ruleset. + * Gets the variable index for a node. + * + * @param nodeIndex the node index (0-based) + * @return the variable index + */ + public int getVariable(int nodeIndex) { + return variables[nodeIndex]; + } + + /** + * Gets the high (true) reference for a node. * - * @return input parameters. + * @param nodeIndex the node index (0-based) + * @return the high reference */ - public Parameters getParameters() { - return parameters; + public int getHigh(int nodeIndex) { + return highs[nodeIndex]; } /** - * Applies a transformation to the BDD and return a new BDD. + * Gets the low (false) reference for a node. * - * @param transformer Optimization to apply. - * @return the optimized BDD. + * @param nodeIndex the node index (0-based) + * @return the low reference */ - public Bdd transform(Function transformer) { - return transformer.apply(this); + public int getLow(int nodeIndex) { + return lows[nodeIndex]; + } + + /** + * Write all nodes to the consumer. + * + * @param consumer the consumer to receive the integers + */ + public void getNodes(BddNodeConsumer consumer) { + for (int i = 0; i < variables.length; i++) { + consumer.accept(variables[i], highs[i], lows[i]); + } + } + + /** + * Evaluates the BDD using the provided condition evaluator. + * + * @param ev the condition evaluator + * @return the result index, or -1 for no match + */ + public int evaluate(ConditionEvaluator ev) { + int ref = rootRef; + int[] vars = this.variables; + int[] hi = this.highs; + int[] lo = this.lows; + int off = RESULT_OFFSET; + + // keep walking while ref is a non-terminal node + while ((ref > 1 && ref < off) || (ref < -1 && ref > -off)) { + int idx = ref > 0 ? ref - 1 : -ref - 1; // Math.abs + // test ^ complement, pick hi or lo + ref = (ev.test(vars[idx]) ^ (ref < 0)) ? hi[idx] : lo[idx]; + } + + // +1/-1 => no match + return (ref == 1 || ref == -1) ? -1 : (ref - off); } /** @@ -205,7 +254,6 @@ public static boolean isTerminal(int ref) { * @return true if the reference is complemented */ public static boolean isComplemented(int ref) { - // -1 is FALSE terminal, not a complement return ref < 0 && ref != -1; } @@ -218,34 +266,22 @@ public boolean equals(Object obj) { } Bdd other = (Bdd) obj; return rootRef == other.rootRef - && conditions.equals(other.conditions) - && results.equals(other.results) - && nodesEqual(nodes, other.nodes) - && Objects.equals(parameters, other.parameters); - } - - private static boolean nodesEqual(int[][] a, int[][] b) { - if (a.length != b.length) { - return false; - } - for (int i = 0; i < a.length; i++) { - if (!Arrays.equals(a[i], b[i])) { - return false; - } - } - return true; + && conditionCount == other.conditionCount + && resultCount == other.resultCount + && Arrays.equals(variables, other.variables) + && Arrays.equals(highs, other.highs) + && Arrays.equals(lows, other.lows); } @Override public int hashCode() { - int hash = 31 * rootRef + nodes.length; + int hash = 31 * rootRef + variables.length; // Sample up to 16 nodes distributed across the BDD - int step = Math.max(1, nodes.length / 16); - for (int i = 0; i < nodes.length; i += step) { - int[] node = nodes[i]; - hash = 31 * hash + node[0]; - hash = 31 * hash + node[1]; - hash = 31 * hash + node[2]; + int step = Math.max(1, variables.length / 16); + for (int i = 0; i < variables.length; i += step) { + hash = 31 * hash + variables[i]; + hash = 31 * hash + highs[i]; + hash = 31 * hash + lows[i]; } return hash; } @@ -255,87 +291,11 @@ public String toString() { try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); Writer writer = new OutputStreamWriter(baos, StandardCharsets.UTF_8); - - // Calculate width for condition/result indices - int maxConditionIdx = conditions.size() - 1; - int maxResultIdx = results.size() - 1; - int conditionWidth = maxConditionIdx >= 0 ? String.valueOf(maxConditionIdx).length() + 1 : 2; - int resultWidth = maxResultIdx >= 0 ? String.valueOf(maxResultIdx).length() + 1 : 2; - int varWidth = Math.max(conditionWidth, resultWidth); - - writer.write("Bdd{\n"); - - // Write conditions - writer.write(" conditions ("); - writer.write(String.valueOf(getConditionCount())); - writer.write("):\n"); - for (int i = 0; i < conditions.size(); i++) { - writer.write(String.format(" %" + varWidth + "s: %s%n", "C" + i, conditions.get(i))); - } - - // Write results - writer.write(" results ("); - writer.write(String.valueOf(results.size())); - writer.write("):\n"); - for (int i = 0; i < results.size(); i++) { - writer.write(String.format(" %" + varWidth + "s: ", "R" + i)); - appendResult(writer, results.get(i)); - writer.write("\n"); - } - - // Write root - writer.write(" root: "); - writer.write(BddFormatter.formatReference(rootRef)); - writer.write("\n"); - - // Write nodes header - writer.write(" nodes ("); - writer.write(String.valueOf(nodes.length)); - writer.write("):\n"); - + new BddFormatter(this, writer, "").format(); writer.flush(); - - // Use BddFormatter for nodes - no need to strip anything since we control the indent - BddFormatter.builder() - .writer(writer) - .nodes(nodes) - .rootRef(rootRef) - .conditionCount(conditions.size()) - .resultCount(results.size()) - .indent(" ") - .build() - .format(); - - writer.write("}"); - writer.flush(); - return baos.toString(StandardCharsets.UTF_8.name()); } catch (IOException e) { - // Should never happen with ByteArrayOutputStream - throw new RuntimeException("Failed to format BDD", e); + throw new UncheckedIOException(e); } } - - private void appendResult(Writer writer, Rule result) throws IOException { - if (result == null) { - writer.write("(no match)"); - } else if (result instanceof EndpointRule) { - writer.write("Endpoint: "); - writer.write(((EndpointRule) result).getEndpoint().getUrl().toString()); - } else if (result instanceof ErrorRule) { - writer.write("Error: "); - writer.write(((ErrorRule) result).getError().toString()); - } else { - writer.write(result.getClass().getSimpleName()); - } - } - - public static Bdd fromNode(Node node) { - return BddNodeHelpers.fromNode(node); - } - - @Override - public Node toNode() { - return BddNodeHelpers.toNode(this); - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java index adfd2bd2e27..68af133e085 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java @@ -4,10 +4,8 @@ */ package software.amazon.smithy.rulesengine.logic.bdd; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; /** @@ -26,13 +24,6 @@ *

  • -2, -3, -4, ...: Complement of BDD nodes
  • *
  • Bdd.RESULT_OFFSET+: Result terminals (100_000_000 + resultIndex)
  • * - * - *

    Node storage format: [variableIndex, highRef, lowRef] - * where variableIndex identifies the condition being tested: - *

      - *
    • -1: terminal node marker (only used for index 0)
    • - *
    • 0 to conditionCount-1: condition indices
    • - *
    */ final class BddBuilder { @@ -42,8 +33,13 @@ final class BddBuilder { // ITE operation cache for memoization private final Map iteCache; - // Node storage: index 0 is reserved for the terminal node - private List nodes; + + // Node storage: three separate arrays + private int[] variables; + private int[] highs; + private int[] lows; + private int nodeCount; + // Unique table for node deduplication private Map uniqueTable; // Track the boundary between conditions and results @@ -55,11 +51,16 @@ final class BddBuilder { * Creates a new BDD engine. */ public BddBuilder() { - this.nodes = new ArrayList<>(); + this.variables = new int[4]; + this.highs = new int[4]; + this.lows = new int[4]; + this.nodeCount = 1; this.uniqueTable = new HashMap<>(); this.iteCache = new HashMap<>(); // Initialize with terminal node at index 0 - nodes.add(new int[] {-1, TRUE_REF, FALSE_REF}); + variables[0] = -1; + highs[0] = TRUE_REF; + lows[0] = FALSE_REF; } /** @@ -134,7 +135,7 @@ public int makeNode(int var, int high, int low) { // Complement edge canonicalization: ensure complement only on low branch. // Don't apply this to result nodes or when branches contain results boolean flip = false; - if (!isResultVariable(var) && !isResult(high) && !isResult(low) && isComplement(low)) { + if (isComplement(low) && !isResult(high) && !isResult(low)) { high = negate(high); low = negate(low); flip = true; @@ -150,27 +151,44 @@ public int makeNode(int var, int high, int low) { } // Create new node - return insertNode(var, high, low, flip, nodes, uniqueTable); + return insertNode(var, high, low, flip); } - private int insertNode(int var, int high, int low, boolean flip, List nodes, Map tbl) { - int idx = nodes.size(); - nodes.add(new int[] {var, high, low}); - tbl.put(new TripleKey(var, high, low), idx); + private int insertNode(int var, int high, int low, boolean flip) { + ensureCapacity(); + + int idx = nodeCount; + variables[idx] = var; + highs[idx] = high; + lows[idx] = low; + nodeCount++; + + uniqueTable.put(new TripleKey(var, high, low), idx); int ref = toReference(idx); return flip ? negate(ref) : ref; } + private void ensureCapacity() { + if (nodeCount >= variables.length) { + // Grow by 50% + int newCapacity = variables.length + (variables.length >> 1); + variables = Arrays.copyOf(variables, newCapacity); + highs = Arrays.copyOf(highs, newCapacity); + lows = Arrays.copyOf(lows, newCapacity); + } + } + /** * Negates a BDD reference (logical NOT). * * @param ref the reference to negate * @return the negated reference - * @throws IllegalArgumentException if ref is a result terminal + * @throws IllegalArgumentException if ref is a result terminal or invalid. */ public int negate(int ref) { - if (isResult(ref)) { - throw new IllegalArgumentException("Cannot negate result terminal: " + ref); + if (ref == 0 || isResult(ref)) { + throw new IllegalArgumentException( + "Cannot negate " + (ref == 0 ? "invalid reference: " : "result terminal: ") + ref); } return -ref; } @@ -208,26 +226,6 @@ public boolean isResult(int ref) { return ref >= Bdd.RESULT_OFFSET; } - /** - * Checks if a variable index represents a result. - * - * @param varIdx the variable index - * @return true if result - */ - private boolean isResultVariable(int varIdx) { - return conditionCount != -1 && varIdx >= conditionCount; - } - - /** - * Checks if a reference is any kind of terminal. - * - * @param ref the reference to check - * @return true if any terminal type - */ - private boolean isAnyTerminal(int ref) { - return isTerminal(ref) || isResult(ref); - } - /** * Gets the variable index for a BDD node. * @@ -238,10 +236,13 @@ public int getVariable(int ref) { if (isTerminal(ref)) { return -1; } else if (isResult(ref)) { - // For results, return the virtual variable index (conditionCount + resultIndex) - return conditionCount + (ref - Bdd.RESULT_OFFSET); + return -1; // Result terminals are leaves and don't test variables } else { - return nodes.get(Math.abs(ref) - 1)[0]; + int nodeIndex = Math.abs(ref) - 1; + if (nodeIndex >= nodeCount || nodeIndex < 0) { + throw new IllegalStateException("Invalid node index: " + nodeIndex); + } + return variables[nodeIndex]; } } @@ -255,18 +256,21 @@ public int getVariable(int ref) { */ public int cofactor(int bdd, int varIndex, boolean value) { // Terminals and results are unaffected by cofactoring - if (isAnyTerminal(bdd)) { + if (isTerminal(bdd) || isResult(bdd)) { return bdd; } boolean complemented = isComplement(bdd); int nodeIndex = toNodeIndex(bdd); - int[] node = nodes.get(nodeIndex); - int nodeVar = node[0]; + if (nodeIndex >= nodeCount || nodeIndex < 0) { + throw new IllegalStateException("Invalid node index: " + nodeIndex); + } + + int nodeVar = variables[nodeIndex]; if (nodeVar == varIndex) { // This node tests our variable, so take the appropriate branch - int child = value ? node[1] : node[2]; + int child = value ? highs[nodeIndex] : lows[nodeIndex]; // Only negate if child is not a result return (complemented && !isResult(child)) ? negate(child) : child; } else if (nodeVar > varIndex) { @@ -274,8 +278,8 @@ public int cofactor(int bdd, int varIndex, boolean value) { return bdd; } else { // Variable appears deeper, so recurse on both branches - int high = cofactor(node[1], varIndex, value); - int low = cofactor(node[2], varIndex, value); + int high = cofactor(highs[nodeIndex], varIndex, value); + int low = cofactor(lows[nodeIndex], varIndex, value); int result = makeNode(nodeVar, high, low); return (complemented && !isResult(result)) ? negate(result) : result; } @@ -368,7 +372,7 @@ public int ite(int f, int g, int h) { // Create the actual key, and reserve cache slot to handle recursive calls TripleKey key = new TripleKey(f, g, h); - iteCache.put(key, FALSE_REF); + iteCache.put(key, 0); // invalid place holder // Shannon expansion: find the top variable int v = getTopVariable(f, g, h); @@ -409,27 +413,46 @@ public int reduce(int rootRef) { int absRoot = rootComp ? negate(rootRef) : rootRef; // Prep new storage - int N = nodes.size(); - List newNodes = new ArrayList<>(N); - Map newUnique = new HashMap<>(N * 2); - newNodes.add(new int[] {-1, TRUE_REF, FALSE_REF}); + int[] newVariables = new int[nodeCount]; + int[] newHighs = new int[nodeCount]; + int[] newLows = new int[nodeCount]; + Map newUnique = new HashMap<>(nodeCount * 2); + + // Initialize terminal node + newVariables[0] = -1; + newHighs[0] = TRUE_REF; + newLows[0] = FALSE_REF; + + // Create a mutable counter to track nodes added + int[] newNodeCounter = new int[] {1}; // Start at 1 (terminal already added) // Mapping array - int[] oldToNew = new int[N]; + int[] oldToNew = new int[nodeCount]; Arrays.fill(oldToNew, -1); // Recurse - int newRoot = reduceRec(absRoot, oldToNew, newNodes, newUnique); + int newRoot = reduceRec(absRoot, oldToNew, newVariables, newHighs, newLows, newNodeCounter, newUnique); - // Swap in - this.nodes = newNodes; + // Swap in - use the actual count of nodes created + this.variables = Arrays.copyOf(newVariables, newNodeCounter[0]); + this.highs = Arrays.copyOf(newHighs, newNodeCounter[0]); + this.lows = Arrays.copyOf(newLows, newNodeCounter[0]); + this.nodeCount = newNodeCounter[0]; this.uniqueTable = newUnique; clearCaches(); return rootComp ? negate(newRoot) : newRoot; } - private int reduceRec(int ref, int[] oldToNew, List newNodes, Map newUnique) { + private int reduceRec( + int ref, + int[] oldToNew, + int[] newVariables, + int[] newHighs, + int[] newLows, + int[] newNodeCounter, + Map newUnique + ) { // Handle terminals and results first if (isTerminal(ref)) { return ref; @@ -445,6 +468,11 @@ private int reduceRec(int ref, int[] oldToNew, List newNodes, Map= nodeCount || idx < 0) { + throw new IllegalStateException("Invalid node index: " + idx + " (nodeCount=" + nodeCount + ")"); + } + // Already processed? int mapped = oldToNew[idx]; if (mapped != -1) { @@ -452,31 +480,39 @@ private int reduceRec(int ref, int[] oldToNew, List newNodes, Map newNodes, Map newUnique) { + private int makeNodeInNew( + int var, + int hi, + int lo, + int[] newVariables, + int[] newHighs, + int[] newLows, + int[] newNodeCounter, + Map newUnique + ) { if (hi == lo) { return hi; } // Canonicalize complement edges (but not for result nodes) boolean comp = false; - if (!isResultVariable(var) && !isResult(hi) && !isResult(lo) && isComplement(lo)) { + if (!isResult(hi) && !isResult(lo) && isComplement(lo)) { hi = negate(hi); lo = negate(lo); comp = true; @@ -488,8 +524,21 @@ private int makeNodeInNew(int var, int hi, int lo, List newNodes, Map= newVariables.length) { + throw new IllegalStateException("Insufficient space allocated for reduction"); + } + + newVariables[idx] = var; + newHighs[idx] = hi; + newLows[idx] = lo; + + newUnique.put(new TripleKey(var, hi, lo), idx); + newNodeCounter[0]++; // Increment the node counter + + int ref = toReference(idx); + return comp ? negate(ref) : ref; } } @@ -531,32 +580,51 @@ public void clearCaches() { public BddBuilder reset() { clearCaches(); uniqueTable.clear(); - nodes.clear(); - nodes.add(new int[] {-1, TRUE_REF, FALSE_REF}); + Arrays.fill(variables, 0, nodeCount, 0); + Arrays.fill(highs, 0, nodeCount, 0); + Arrays.fill(lows, 0, nodeCount, 0); + nodeCount = 1; + // Re-initialize terminal node + variables[0] = -1; + highs[0] = TRUE_REF; + lows[0] = FALSE_REF; conditionCount = -1; return this; } /** - * Returns a defensive copy of the node table. + * Get the nodes as a flat array. * - * @return list of node arrays + * @return array of nodes, trimmed to actual size. */ - public List getNodes() { - List copy = new ArrayList<>(nodes.size()); - for (int[] node : nodes) { - copy.add(node.clone()); + public int[] getNodesArray() { + // Convert back to flat array for compatibility + int[] result = new int[nodeCount * 3]; + for (int i = 0; i < nodeCount; i++) { + int baseIdx = i * 3; + result[baseIdx] = variables[i]; + result[baseIdx + 1] = highs[i]; + result[baseIdx + 2] = lows[i]; } - return copy; + return result; } /** - * Get the array of nodes. + * Builds a BDD from the current state of the builder. * - * @return array of nodes. + * @return a new BDD instance + * @throws IllegalStateException if condition count has not been set */ - public int[][] getNodesArray() { - return nodes.toArray(new int[0][]); + Bdd build(int rootRef, int resultCount) { + if (conditionCount == -1) { + throw new IllegalStateException("Condition count must be set before building BDD"); + } + + // Create trimmed copies of the arrays with only the used portion + int[] trimmedVariables = Arrays.copyOf(variables, nodeCount); + int[] trimmedHighs = Arrays.copyOf(highs, nodeCount); + int[] trimmedLows = Arrays.copyOf(lows, nodeCount); + return new Bdd(trimmedVariables, trimmedHighs, trimmedLows, rootRef, conditionCount, resultCount); } private void validateBooleanOperands(int f, int g, String operation) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java index 1d47e379d1f..7268fe8fb10 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java @@ -11,7 +11,6 @@ import java.util.Map; import java.util.Objects; import java.util.logging.Logger; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; @@ -63,15 +62,14 @@ Bdd compile() { noMatchIndex = getOrCreateResultIndex(NoMatchRule.INSTANCE); int rootRef = convertCfgToBdd(cfg.getRoot()); rootRef = bddBuilder.reduce(rootRef); - Parameters parameters = cfg.getRuleSet().getParameters(); - Bdd bdd = new Bdd(parameters, orderedConditions, indexedResults, bddBuilder.getNodesArray(), rootRef); + Bdd bdd = bddBuilder.build(rootRef, indexedResults.size()); long elapsed = System.currentTimeMillis() - start; LOGGER.fine(String.format( "BDD compilation complete: %d conditions, %d results, %d BDD nodes in %dms", orderedConditions.size(), indexedResults.size(), - bddBuilder.getNodes().size() - 1, + bdd.getNodeCount(), elapsed)); return bdd; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java index 0cdf6028d6a..e83dd0c176e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java @@ -19,6 +19,7 @@ import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionData; import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; @@ -39,6 +40,8 @@ public final class BddEquivalenceChecker { private final Cfg cfg; private final Bdd bdd; + private final List conditions; + private final List results; private final List parameters; private final Map conditionToIndex = new HashMap<>(); @@ -48,17 +51,19 @@ public final class BddEquivalenceChecker { private int testsRun = 0; private long startTime; - public static BddEquivalenceChecker of(Cfg cfg, Bdd bdd) { - return new BddEquivalenceChecker(cfg, bdd); + public static BddEquivalenceChecker of(Cfg cfg, Bdd bdd, List conditions, List results) { + return new BddEquivalenceChecker(cfg, bdd, conditions, results); } - private BddEquivalenceChecker(Cfg cfg, Bdd bdd) { + private BddEquivalenceChecker(Cfg cfg, Bdd bdd, List conditions, List results) { this.cfg = cfg; this.bdd = bdd; + this.conditions = conditions; + this.results = results; this.parameters = new ArrayList<>(cfg.getRuleSet().getParameters().toList()); - for (int i = 0; i < bdd.getConditions().size(); i++) { - conditionToIndex.put(bdd.getConditions().get(i), i); + for (int i = 0; i < conditions.size(); i++) { + conditionToIndex.put(conditions.get(i), i); } } @@ -127,7 +132,7 @@ private void verifyResults() { } // Remove the NoMatchRule that's added by default. It's not in the CFG. - Set bddResults = new HashSet<>(bdd.getResults()); + Set bddResults = new HashSet<>(results); bddResults.removeIf(v -> v == NoMatchRule.INSTANCE); if (!cfgResults.equals(bddResults)) { @@ -242,8 +247,8 @@ private void verifyCase(long mask) { errorMsg.append("Test case #").append(testsRun).append("\n"); errorMsg.append("Condition mask: ").append(Long.toBinaryString(mask)).append("\n"); errorMsg.append("\nCondition details:\n"); - for (int i = 0; i < bdd.getConditions().size(); i++) { - Condition condition = bdd.getConditions().get(i); + for (int i = 0; i < conditions.size(); i++) { + Condition condition = conditions.get(i); boolean value = (mask & (1L << i)) != 0; errorMsg.append(" Condition ") .append(i) @@ -261,7 +266,15 @@ private void verifyCase(long mask) { } private Rule evaluateCfgWithMask(ConditionEvaluator maskEvaluator) { - CfgNode result = evaluateCfgNode(cfg.getRoot(), conditionToIndex, maskEvaluator); + // Get the condition data from CFG + ConditionData conditionData = cfg.getConditionData(); + Map cfgConditionToIndex = new HashMap<>(); + Condition[] cfgConditions = conditionData.getConditions(); + for (int i = 0; i < cfgConditions.length; i++) { + cfgConditionToIndex.put(cfgConditions[i], i); + } + + CfgNode result = evaluateCfgNode(cfg.getRoot(), cfgConditionToIndex, maskEvaluator); if (result instanceof ResultNode) { return ((ResultNode) result).getResult(); } @@ -285,7 +298,7 @@ private CfgNode evaluateCfgNode( Integer index = conditionToIndex.get(condition); if (index == null) { - throw new IllegalStateException("Condition not found in BDD: " + condition); + throw new IllegalStateException("Condition not found in CFG: " + condition); } boolean conditionResult = maskEvaluator.test(index); @@ -308,9 +321,8 @@ private CfgNode evaluateCfgNode( private Rule evaluateBdd(long mask) { FixedMaskEvaluator evaluator = new FixedMaskEvaluator(mask); - BddEvaluator bddEvaluator = BddEvaluator.from(bdd); - int resultIndex = bddEvaluator.evaluate(evaluator); - return resultIndex < 0 ? null : bdd.getResults().get(resultIndex); + int resultIndex = bdd.evaluate(evaluator); + return resultIndex < 0 ? null : results.get(resultIndex); } private boolean resultsEqual(Rule r1, Rule r2) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluator.java deleted file mode 100644 index a9aead5155a..00000000000 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluator.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic.bdd; - -import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; - -/** - * Simple BDD evaluator that works directly with BDD nodes. - */ -public final class BddEvaluator { - - private final int[][] nodes; - private final int rootRef; - private final int conditionCount; - - private BddEvaluator(int[][] nodes, int rootRef, int conditionCount) { - this.nodes = nodes; - this.rootRef = rootRef; - this.conditionCount = conditionCount; - } - - /** - * Create evaluator from a Bdd object. - * - * @param bdd the BDD - * @return the evaluator - */ - public static BddEvaluator from(Bdd bdd) { - return from(bdd.getNodes(), bdd.getRootRef(), bdd.getConditionCount()); - } - - /** - * Create evaluator from BDD data. - * - * @param nodes BDD nodes array - * @param rootRef root reference - * @param conditionCount number of conditions - * @return the evaluator - */ - public static BddEvaluator from(int[][] nodes, int rootRef, int conditionCount) { - return new BddEvaluator(nodes, rootRef, conditionCount); - } - - /** - * Evaluates the BDD. - * - * @param evaluator the condition evaluator - * @return the result index, or -1 for no match - */ - public int evaluate(ConditionEvaluator evaluator) { - int resultOff = Bdd.RESULT_OFFSET; - int ref = this.rootRef; - - while (true) { - int abs = Math.abs(ref); - // stop once we hit a terminal (+/-1) or a result node (|ref| >= resultOff) - if (abs <= 1 || abs >= resultOff) { - break; - } - - int[] node = this.nodes[abs - 1]; - int varIdx = node[0]; - int hi = node[1]; - int lo = node[2]; - - // swap branches for a complemented pointer - if (ref < 0) { - int tmp = hi; - hi = lo; - lo = tmp; - } - - ref = evaluator.test(varIdx) ? hi : lo; - } - - // +/-1 means no match. - if (ref == 1 || ref == -1) { - return -1; - } - - int resultIdx = ref - resultOff; - return resultIdx == 0 ? -1 : resultIdx; - } -} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java index fd415876f08..124be11f74d 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddFormatter.java @@ -4,122 +4,162 @@ */ package software.amazon.smithy.rulesengine.logic.bdd; +import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.OutputStream; import java.io.OutputStreamWriter; -import java.io.UncheckedIOException; import java.io.Writer; import java.nio.charset.StandardCharsets; /** - * Formats BDD node structures to a stream without building the entire representation in memory. + * Formats BDD node structures to a writer. */ public final class BddFormatter { + private final Bdd bdd; private final Writer writer; - private final int[][] nodes; - private final int rootRef; - private final int conditionCount; - private final int resultCount; private final String indent; - private BddFormatter(Builder builder) { - this.writer = builder.writer; - this.nodes = builder.nodes; - this.rootRef = builder.rootRef; - this.conditionCount = builder.conditionCount; - this.resultCount = builder.resultCount; - this.indent = builder.indent; + /** + * Creates a BDD formatter. + * + * @param bdd the BDD to format + * @param writer the writer to format to + * @param indent the indentation string + */ + public BddFormatter(Bdd bdd, Writer writer, String indent) { + this.bdd = bdd; + this.writer = writer; + this.indent = indent; } /** - * Creates a builder for BddFormatter. + * Formats a BDD to a string. * - * @return a new builder + * @param bdd the BDD to format + * @return a formatted string representation */ - public static Builder builder() { - return new Builder(); + public static String format(Bdd bdd) { + return format(bdd, ""); } /** - * Formats the BDD node structure. + * Formats a BDD to a string with custom indent. + * + * @param bdd the BDD to format + * @param indent the indentation string + * @return a formatted string representation */ - public void format() { + public static String format(Bdd bdd, String indent) { try { - // Calculate formatting widths - FormatContext ctx = calculateFormatContext(); - - // Write root - writer.write(indent); - writer.write("Root: "); - writer.write(formatReference(rootRef)); - writer.write("\n"); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + Writer writer = new OutputStreamWriter(baos, StandardCharsets.UTF_8); + new BddFormatter(bdd, writer, indent).format(); + writer.flush(); + return baos.toString(StandardCharsets.UTF_8.name()); + } catch (IOException e) { + // Should never happen with ByteArrayOutputStream + throw new RuntimeException("Failed to format BDD", e); + } + } - // Write nodes + /** + * Formats the BDD structure. + * + * @throws IOException if writing fails + */ + public void format() throws IOException { + // Calculate formatting widths + FormatContext ctx = calculateFormatContext(); + + // Write header + writer.write(indent); + writer.write("Bdd {\n"); + + // Write counts + writer.write(indent); + writer.write(" conditions: "); + writer.write(String.valueOf(bdd.getConditionCount())); + writer.write("\n"); + + writer.write(indent); + writer.write(" results: "); + writer.write(String.valueOf(bdd.getResultCount())); + writer.write("\n"); + + // Write root + writer.write(indent); + writer.write(" root: "); + writer.write(formatReference(bdd.getRootRef())); + writer.write("\n"); + + // Write nodes + writer.write(indent); + writer.write(" nodes ("); + writer.write(String.valueOf(bdd.getNodeCount())); + writer.write("):\n"); + + for (int i = 0; i < bdd.getNodeCount(); i++) { writer.write(indent); - writer.write("Nodes:\n"); - - for (int i = 0; i < nodes.length; i++) { - writer.write(indent); - writer.write(" "); - writer.write(String.format("%" + ctx.indexWidth + "d", i)); - writer.write(": "); - - if (i == 0 && nodes[i][0] == -1) { - writer.write("terminal"); - } else { - formatNode(nodes[i], ctx); - } - writer.write("\n"); + writer.write(" "); + writer.write(String.format("%" + ctx.indexWidth + "d", i)); + writer.write(": "); + + if (i == 0 && bdd.getVariable(0) == -1) { + writer.write("terminal"); + } else { + formatNode(i, ctx); } - - writer.flush(); - } catch (IOException e) { - throw new UncheckedIOException(e); + writer.write("\n"); } + + writer.write(indent); + writer.write("}"); } private FormatContext calculateFormatContext() { + int nodeCount = bdd.getNodeCount(); int maxVarIdx = -1; // Scan nodes to find max variable index - for (int i = 1; i < nodes.length; i++) { - int varIdx = nodes[i][0]; + for (int i = 1; i < nodeCount; i++) { + int varIdx = bdd.getVariable(i); if (varIdx >= 0) { maxVarIdx = Math.max(maxVarIdx, varIdx); } } // Calculate widths + int conditionCount = bdd.getConditionCount(); + int resultCount = bdd.getResultCount(); int conditionWidth = conditionCount > 0 ? String.valueOf(conditionCount - 1).length() + 1 : 2; int resultWidth = resultCount > 0 ? String.valueOf(resultCount - 1).length() + 1 : 2; int varWidth = Math.max(Math.max(conditionWidth, resultWidth), String.valueOf(maxVarIdx).length()); - int indexWidth = String.valueOf(nodes.length - 1).length(); + int indexWidth = String.valueOf(nodeCount - 1).length(); return new FormatContext(varWidth, indexWidth); } - private void formatNode(int[] node, FormatContext ctx) throws IOException { + private void formatNode(int nodeIndex, FormatContext ctx) throws IOException { writer.write("["); // Variable reference - int varIdx = node[0]; + int varIdx = bdd.getVariable(nodeIndex); String varRef = formatVariableIndex(varIdx); writer.write(String.format("%" + ctx.varWidth + "s", varRef)); // High and low references writer.write(", "); - writer.write(String.format("%6s", formatReference(node[1]))); + writer.write(String.format("%6s", formatReference(bdd.getHigh(nodeIndex)))); writer.write(", "); - writer.write(String.format("%6s", formatReference(node[2]))); + writer.write(String.format("%6s", formatReference(bdd.getLow(nodeIndex)))); writer.write("]"); } private String formatVariableIndex(int varIdx) { - if (conditionCount > 0 && varIdx < conditionCount) { + if (bdd.getConditionCount() > 0 && varIdx < bdd.getConditionCount()) { return "C" + varIdx; - } else if (conditionCount > 0 && resultCount > 0) { - return "R" + (varIdx - conditionCount); + } else if (bdd.getConditionCount() > 0 && bdd.getResultCount() > 0) { + return "R" + (varIdx - bdd.getConditionCount()); } else { return String.valueOf(varIdx); } @@ -156,62 +196,4 @@ private static class FormatContext { this.indexWidth = indexWidth; } } - - /** - * Builder for BddFormatter. - */ - public static final class Builder { - private Writer writer; - private int[][] nodes; - private int rootRef; - private int conditionCount = 0; - private int resultCount = 0; - private String indent = ""; - - private Builder() {} - - public Builder writer(Writer writer) { - this.writer = writer; - return this; - } - - public Builder writer(OutputStream out) { - return writer(new OutputStreamWriter(out, StandardCharsets.UTF_8)); - } - - public Builder nodes(int[][] nodes) { - this.nodes = nodes; - return this; - } - - public Builder rootRef(int rootRef) { - this.rootRef = rootRef; - return this; - } - - public Builder conditionCount(int conditionCount) { - this.conditionCount = conditionCount; - return this; - } - - public Builder resultCount(int resultCount) { - this.resultCount = resultCount; - return this; - } - - public Builder indent(String indent) { - this.indent = indent; - return this; - } - - public BddFormatter build() { - if (writer == null) { - throw new IllegalStateException("writer is required"); - } - if (nodes == null) { - throw new IllegalStateException("nodes are required"); - } - return new BddFormatter(this); - } - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeConsumer.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeConsumer.java new file mode 100644 index 00000000000..4c34cea386a --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeConsumer.java @@ -0,0 +1,19 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +/** + * Consumer that receives every node in a {@link Bdd}. + */ +public interface BddNodeConsumer { + /** + * Receives a BDD node. + * + * @param var Variable. + * @param high High reference. + * @param low Low reference. + */ + void accept(int var, int high, int low); +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeHelpers.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeHelpers.java deleted file mode 100644 index d3cd0ed153f..00000000000 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddNodeHelpers.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic.bdd; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.DataInputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Base64; -import java.util.List; -import java.util.Set; -import software.amazon.smithy.model.node.Node; -import software.amazon.smithy.model.node.ObjectNode; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.utils.SetUtils; - -final class BddNodeHelpers { - private static final int[] TERMINAL_NODE = new int[] {-1, 1, -1}; - private static final Set ALLOWED_PROPERTIES = SetUtils.of( - "parameters", - "conditions", - "results", - "root", - "nodes", - "nodeCount"); - - private BddNodeHelpers() {} - - static Node toNode(Bdd bdd) { - ObjectNode.Builder builder = ObjectNode.builder(); - - List conditions = new ArrayList<>(); - for (Condition c : bdd.getConditions()) { - conditions.add(c.toNode()); - } - - List results = new ArrayList<>(); - if (!(bdd.getResults().get(0) instanceof NoMatchRule)) { - throw new IllegalArgumentException("BDD must always have a NoMatchRule as the first result"); - } - for (int i = 1; i < bdd.getResults().size(); i++) { - Rule result = bdd.getResults().get(i); - if (result instanceof NoMatchRule) { - throw new IllegalArgumentException("NoMatch rules can only appear at rule index 0. Found at index" + i); - } - results.add(bdd.getResults().get(i).toNode()); - } - - return builder - .withMember("parameters", bdd.getParameters().toNode()) - .withMember("conditions", Node.fromNodes(conditions)) - .withMember("results", Node.fromNodes(results)) - .withMember("root", bdd.getRootRef()) - .withMember("nodes", encodeNodes(bdd)) - .withMember("nodeCount", bdd.getNodes().length) - .build(); - } - - static Bdd fromNode(Node node) { - ObjectNode obj = node.expectObjectNode(); - obj.warnIfAdditionalProperties(ALLOWED_PROPERTIES); - Parameters params = Parameters.fromNode(obj.expectObjectMember("parameters")); - List conditions = obj.expectArrayMember("conditions").getElementsAs(Condition::fromNode); - - // Read the results and prepend NoMatchRule at index 0 - List serializedResults = obj.expectArrayMember("results").getElementsAs(Rule::fromNode); - List results = new ArrayList<>(); - results.add(NoMatchRule.INSTANCE); // Always add no-match at index 0 - results.addAll(serializedResults); - - String nodesBase64 = obj.expectStringMember("nodes").getValue(); - int nodeCount = obj.expectNumberMember("nodeCount").getValue().intValue(); - int[][] nodes = decodeNodes(nodesBase64, nodeCount); - int rootRef = obj.expectNumberMember("root").getValue().intValue(); - return new Bdd(params, conditions, results, nodes, rootRef); - } - - static String encodeNodes(Bdd bdd) { - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(baos)) { - int[][] nodes = bdd.getNodes(); - for (int[] node : nodes) { - writeVarInt(dos, node[0]); - writeVarInt(dos, node[1]); - writeVarInt(dos, node[2]); - } - dos.flush(); - return Base64.getEncoder().encodeToString(baos.toByteArray()); - } catch (IOException e) { - throw new RuntimeException("Failed to encode BDD nodes", e); - } - } - - static int[][] decodeNodes(String base64, int nodeCount) { - if (base64.isEmpty() || nodeCount == 0) { - return new int[][] {TERMINAL_NODE}; - } - - byte[] data = Base64.getDecoder().decode(base64); - int[][] nodes = new int[nodeCount][]; - - try (ByteArrayInputStream bais = new ByteArrayInputStream(data); - DataInputStream dis = new DataInputStream(bais)) { - for (int i = 0; i < nodeCount; i++) { - int varIdx = readVarInt(dis); - int high = readVarInt(dis); - int low = readVarInt(dis); - nodes[i] = new int[] {varIdx, high, low}; - } - if (bais.available() > 0) { - throw new IllegalArgumentException("Extra data found after decoding " + nodeCount + - " nodes. " + bais.available() + " bytes remaining."); - } - return nodes; - } catch (IOException e) { - throw new RuntimeException("Failed to decode BDD nodes", e); - } - } - - // Zig-zag + varint encode of a signed int - private static void writeVarInt(DataOutputStream dos, int value) throws IOException { - int zz = (value << 1) ^ (value >> 31); - while ((zz & ~0x7F) != 0) { - dos.writeByte((zz & 0x7F) | 0x80); - zz >>>= 7; - } - dos.writeByte(zz); - } - - // Decode a signed int from varint + zig-zag. - private static int readVarInt(DataInputStream dis) throws IOException { - int shift = 0, result = 0; - while (true) { - byte b = dis.readByte(); - result |= (b & 0x7F) << shift; - if ((b & 0x80) == 0) - break; - shift += 7; - } - // reverse zig-zag - return (result >>> 1) ^ -(result & 1); - } -} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java new file mode 100644 index 00000000000..4116837bbab --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java @@ -0,0 +1,375 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.model.node.ObjectNode; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.model.traits.AbstractTrait; +import software.amazon.smithy.model.traits.AbstractTraitBuilder; +import software.amazon.smithy.model.traits.Trait; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionData; +import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; +import software.amazon.smithy.utils.SetUtils; +import software.amazon.smithy.utils.SmithyBuilder; +import software.amazon.smithy.utils.ToSmithyBuilder; + +/** + * Trait containing a precompiled BDD with full context for endpoint resolution. + */ +public final class BddTrait extends AbstractTrait implements ToSmithyBuilder { + public static final ShapeId ID = ShapeId.from("smithy.rules#bdd"); + + private static final Set ALLOWED_PROPERTIES = SetUtils.of( + "parameters", + "conditions", + "results", + "root", + "nodes", + "nodeCount"); + + private final Parameters parameters; + private final List conditions; + private final List results; + private final Bdd bdd; + + private BddTrait(Builder builder) { + super(ID, builder.getSourceLocation()); + this.parameters = SmithyBuilder.requiredState("parameters", builder.parameters); + this.conditions = SmithyBuilder.requiredState("conditions", builder.conditions); + this.results = SmithyBuilder.requiredState("results", builder.results); + this.bdd = SmithyBuilder.requiredState("bdd", builder.bdd); + } + + /** + * Creates a BddTrait from a control flow graph. + * + * @param cfg the control flow graph to compile + * @return the BddTrait containing the compiled BDD and all context + */ + public static BddTrait from(Cfg cfg) { + ConditionData conditionData = cfg.getConditionData(); + List conditions = Arrays.asList(conditionData.getConditions()); + + // Compile the BDD + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + Bdd bdd = compiler.compile(); + + List results = extractResultsFromCfg(cfg, bdd); + Parameters parameters = cfg.getRuleSet().getParameters(); + return builder().parameters(parameters).conditions(conditions).results(results).bdd(bdd).build(); + } + + private static List extractResultsFromCfg(Cfg cfg, Bdd bdd) { + // The BddCompiler always puts NoMatchRule at index 0 + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); + + Set uniqueResults = new LinkedHashSet<>(); + for (CfgNode node : cfg) { + if (node instanceof ResultNode) { + Rule result = ((ResultNode) node).getResult(); + if (result != null && !(result instanceof NoMatchRule)) { + uniqueResults.add(result.withoutConditions()); + } + } + } + + results.addAll(uniqueResults); + + if (results.size() != bdd.getResultCount()) { + throw new IllegalStateException(String.format( + "Result count mismatch: found %d results in CFG but BDD expects %d", + results.size(), + bdd.getResultCount())); + } + + return results; + } + + /** + * Gets the parameters for the endpoint rules. + * + * @return the parameters + */ + public Parameters getParameters() { + return parameters; + } + + /** + * Gets the ordered list of conditions. + * + * @return the conditions in evaluation order + */ + public List getConditions() { + return conditions; + } + + /** + * Gets the ordered list of results. + * + * @return the results (index 0 is always NoMatchRule) + */ + public List getResults() { + return results; + } + + /** + * Gets the BDD structure. + * + * @return the BDD + */ + public Bdd getBdd() { + return bdd; + } + + @Override + protected Node createNode() { + ObjectNode.Builder builder = ObjectNode.builder(); + builder.withMember("parameters", parameters.toNode()); + + List conditionNodes = new ArrayList<>(); + for (Condition c : conditions) { + conditionNodes.add(c.toNode()); + } + builder.withMember("conditions", Node.fromNodes(conditionNodes)); + + // Results (skip NoMatchRule at index 0 for serialization) + List resultNodes = new ArrayList<>(); + if (!results.isEmpty() && !(results.get(0) instanceof NoMatchRule)) { + throw new IllegalStateException("BDD must always have a NoMatchRule as the first result"); + } + for (int i = 1; i < results.size(); i++) { + Rule result = results.get(i); + if (result instanceof NoMatchRule) { + throw new IllegalStateException("NoMatch rules can only appear at rule index 0. Found at index " + i); + } + resultNodes.add(result.toNode()); + } + builder.withMember("results", Node.fromNodes(resultNodes)); + + builder.withMember("root", bdd.getRootRef()); + builder.withMember("nodeCount", bdd.getNodeCount()); + builder.withMember("nodes", encodeNodes(bdd)); + + return builder.build(); + } + + /** + * Creates a BddTrait from a Node representation. + * + * @param node the node to parse + * @return the BddTrait + */ + public static BddTrait fromNode(Node node) { + ObjectNode obj = node.expectObjectNode(); + obj.warnIfAdditionalProperties(ALLOWED_PROPERTIES); + Parameters params = Parameters.fromNode(obj.expectObjectMember("parameters")); + List conditions = obj.expectArrayMember("conditions").getElementsAs(Condition::fromNode); + + List serializedResults = obj.expectArrayMember("results").getElementsAs(Rule::fromNode); + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); // Always add no-match at index 0 + results.addAll(serializedResults); + + String nodesBase64 = obj.expectStringMember("nodes").getValue(); + int nodeCount = obj.expectNumberMember("nodeCount").getValue().intValue(); + int rootRef = obj.expectNumberMember("root").getValue().intValue(); + + Bdd bdd = decodeBdd(nodesBase64, nodeCount, rootRef, conditions.size(), results.size()); + + BddTrait trait = builder() + .sourceLocation(node) + .parameters(params) + .conditions(conditions) + .results(results) + .bdd(bdd) + .build(); + trait.setNodeCache(node); + return trait; + } + + private static String encodeNodes(Bdd bdd) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataOutputStream dos = new DataOutputStream(baos)) { + bdd.getNodes((varIdx, high, low) -> { + try { + writeVarInt(dos, varIdx); + writeVarInt(dos, high); + writeVarInt(dos, low); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + dos.flush(); + return Base64.getEncoder().encodeToString(baos.toByteArray()); + } catch (IOException e) { + throw new RuntimeException("Failed to encode BDD nodes", e); + } catch (UncheckedIOException e) { + throw new RuntimeException("Failed to encode BDD nodes", e.getCause()); + } + } + + private static Bdd decodeBdd(String base64, int nodeCount, int rootRef, int conditionCount, int resultCount) { + // Special case for empty BDD with just terminal (should never happen, but just in case). + if (base64.isEmpty() || nodeCount == 0) { + return new Bdd(rootRef, conditionCount, resultCount, 1, consumer -> { + consumer.accept(-1, 1, -1); + }); + } + + byte[] data = Base64.getDecoder().decode(base64); + return new Bdd(rootRef, conditionCount, resultCount, nodeCount, consumer -> { + try (ByteArrayInputStream bais = new ByteArrayInputStream(data); + DataInputStream dis = new DataInputStream(bais)) { + for (int i = 0; i < nodeCount; i++) { + consumer.accept(readVarInt(dis), readVarInt(dis), readVarInt(dis)); + } + if (bais.available() > 0) { + throw new IllegalArgumentException("Extra data found after decoding " + nodeCount + + " nodes. " + bais.available() + " bytes remaining."); + } + } catch (IOException e) { + throw new RuntimeException("Failed to decode BDD nodes", e); + } + }); + } + + // Zig-zag + varint encode of a signed int + private static void writeVarInt(DataOutputStream dos, int value) throws IOException { + int zz = (value << 1) ^ (value >> 31); + while ((zz & ~0x7F) != 0) { + dos.writeByte((zz & 0x7F) | 0x80); + zz >>>= 7; + } + dos.writeByte(zz); + } + + // Decode a signed int from varint + zig-zag + private static int readVarInt(DataInputStream dis) throws IOException { + int shift = 0, result = 0; + while (true) { + byte b = dis.readByte(); + result |= (b & 0x7F) << shift; + if ((b & 0x80) == 0) { + break; + } + shift += 7; + } + // reverse zig-zag + return (result >>> 1) ^ -(result & 1); + } + + /** + * Creates a new builder for BddTrait. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); + } + + @Override + public Builder toBuilder() { + return builder() + .sourceLocation(getSourceLocation()) + .parameters(parameters) + .conditions(conditions) + .results(results) + .bdd(bdd); + } + + /** + * Builder for BddTrait. + */ + public static final class Builder extends AbstractTraitBuilder { + private Parameters parameters; + private List conditions; + private List results; + private Bdd bdd; + + private Builder() {} + + /** + * Sets the parameters. + * + * @param parameters the parameters + * @return this builder + */ + public Builder parameters(Parameters parameters) { + this.parameters = parameters; + return this; + } + + /** + * Sets the conditions. + * + * @param conditions the conditions in evaluation order + * @return this builder + */ + public Builder conditions(List conditions) { + this.conditions = conditions; + return this; + } + + /** + * Sets the results. + * + * @param results the results (must have NoMatchRule at index 0) + * @return this builder + */ + public Builder results(List results) { + this.results = results; + return this; + } + + /** + * Sets the BDD structure. + * + * @param bdd the BDD + * @return this builder + */ + public Builder bdd(Bdd bdd) { + this.bdd = bdd; + return this; + } + + @Override + public BddTrait build() { + return new BddTrait(this); + } + } + + public static final class Provider extends AbstractTrait.Provider { + public Provider() { + super(ID); + } + + @Override + public Trait createTrait(ShapeId target, Node value) { + BddTrait trait = BddTrait.fromNode(value); + trait.setNodeCache(value); + return trait; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java index 1131e0d3173..3aac6c56235 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java @@ -20,8 +20,7 @@ public final class NodeReversal implements Function { @Override public Bdd apply(Bdd bdd) { LOGGER.info("Starting BDD node reversal optimization"); - int[][] nodes = bdd.getNodes(); - int nodeCount = nodes.length; + int nodeCount = bdd.getNodeCount(); if (nodeCount <= 2) { return bdd; @@ -38,27 +37,22 @@ public Bdd apply(Bdd bdd) { oldToNew[oldIdx] = newIdx; } - // Create new node array with reversed order - int[][] newNodes = new int[nodeCount][]; - newNodes[0] = nodes[0].clone(); // Terminal stays at index 0 - - // Add nodes in reverse order, updating their references - int newIdx = 1; - for (int oldIdx = nodeCount - 1; oldIdx >= 1; oldIdx--) { - int[] oldNode = nodes[oldIdx]; - newNodes[newIdx++] = new int[] { - oldNode[0], // variable index stays the same - remapReference(oldNode[1], oldToNew), // remap high reference - remapReference(oldNode[2], oldToNew) // remap low reference - }; - } - // Remap the root reference int newRoot = remapReference(bdd.getRootRef(), oldToNew); - LOGGER.info("BDD node reversal complete"); - - return new Bdd(bdd.getParameters(), bdd.getConditions(), bdd.getResults(), newNodes, newRoot); + // Create reversed BDD using streaming constructor + return new Bdd(newRoot, bdd.getConditionCount(), bdd.getResultCount(), nodeCount, consumer -> { + // Terminal stays at index 0 + consumer.accept(bdd.getVariable(0), bdd.getHigh(0), bdd.getLow(0)); + + // Add nodes in reverse order, updating their references + for (int oldIdx = nodeCount - 1; oldIdx >= 1; oldIdx--) { + int var = bdd.getVariable(oldIdx); + int high = remapReference(bdd.getHigh(oldIdx), oldToNew); + int low = remapReference(bdd.getLow(oldIdx), oldToNew); + consumer.accept(var, high, low); + } + }); } /** @@ -69,21 +63,19 @@ public Bdd apply(Bdd bdd) { * @return the remapped reference */ private int remapReference(int ref, int[] oldToNew) { - // Handle special cases + // Return result references as-is. if (ref == 0) { - return 0; // Invalid reference stays invalid + return 0; } else if (ref == 1 || ref == -1) { - return ref; // TRUE/FALSE terminals unchanged + return ref; } else if (ref >= Bdd.RESULT_OFFSET) { - return ref; // Result references are not remapped + return ref; } // Handle regular node references (with possible complement) boolean isComplemented = ref < 0; int absRef = isComplemented ? -ref : ref; - - // Convert from reference to index (1-based to 0-based) - int oldIdx = absRef - 1; + int oldIdx = absRef - 1; // convert 1-based to 0-based if (oldIdx >= oldToNew.length) { throw new IllegalStateException("Invalid reference: " + ref); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index 8e4eb850512..22b203b110c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -123,7 +123,13 @@ private Bdd doApply(Bdd bdd) { // Pre-spin the ForkJoinPool for better first-pass performance ForkJoinPool.commonPool().submit(() -> {}).join(); - OptimizationState state = initializeOptimization(bdd); + + // Get the conditions from the CFG (since Bdd no longer stores them) + List bddConditions = cfg.getConditionData().getConditions() != null + ? Arrays.asList(cfg.getConditionData().getConditions()) + : new ArrayList<>(); + + OptimizationState state = initializeOptimization(bdd, bddConditions); LOGGER.info(String.format("Initial reordering: %d -> %d nodes", state.initialSize, state.currentSize)); state = runCoarseStage(state); @@ -144,25 +150,25 @@ private Bdd doApply(Bdd bdd) { return state.bestBdd; } - private OptimizationState initializeOptimization(Bdd bdd) { + private OptimizationState initializeOptimization(Bdd bdd, List bddConditions) { // Start with an intelligent initial ordering List initialOrder = DefaultOrderingStrategy.orderConditions( - bdd.getConditions().toArray(new Condition[0]), + bddConditions.toArray(new Condition[0]), conditionInfos); // Sanity check that ordering didn't lose conditions - if (initialOrder.size() != bdd.getConditions().size()) { + if (initialOrder.size() != bddConditions.size()) { throw new IllegalStateException("Initial ordering changed condition count: " + - bdd.getConditions().size() + " -> " + initialOrder.size()); + bddConditions.size() + " -> " + initialOrder.size()); } Condition[] order = initialOrder.toArray(new Condition[0]); List orderView = Arrays.asList(order); // Build initial BDD with better ordering - Bdd currentBest = Bdd.from(cfg, new BddBuilder(), ConditionOrderingStrategy.fixed(orderView)); - int currentSize = currentBest.getNodes().length - 1; // -1 for terminal - int initialSize = bdd.getNodes().length - 1; + Bdd currentBest = compileBdd(orderView); + int currentSize = currentBest.getNodeCount() - 1; // -1 for terminal + int initialSize = bdd.getNodeCount() - 1; return new OptimizationState(order, orderView, currentBest, currentSize, initialSize); } @@ -284,8 +290,8 @@ private OptimizationResult runOptimizationPass( // Move to best position and build BDD once move(order, varIdx, best.position); - Bdd newBdd = Bdd.from(cfg, new BddBuilder(), ConditionOrderingStrategy.fixed(orderView)); - int newSize = newBdd.getNodes().length - 1; + Bdd newBdd = compileBdd(orderView); + int newSize = newBdd.getNodeCount() - 1; if (newSize < bestSize) { bestBdd = newBdd; @@ -311,7 +317,7 @@ private OptimizationResult performAdjacentSwaps(Condition[] order, List ordering) { + BddBuilder builder = threadBuilder.get().reset(); + return new BddCompiler(cfg, ConditionOrderingStrategy.fixed(ordering), builder).compile(); + } + /** * Counts nodes for a given ordering without keeping the BDD. */ private int countNodes(List ordering) { - BddBuilder builder = threadBuilder.get().reset(); - return Bdd.from(cfg, builder, ConditionOrderingStrategy.fixed(ordering)).getNodes().length - 1; + Bdd bdd = compileBdd(ordering); + return bdd.getNodeCount() - 1; // -1 for terminal } // Position and its node count diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/BddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/BddTrait.java deleted file mode 100644 index 0386860e040..00000000000 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/BddTrait.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.traits; - -import software.amazon.smithy.model.node.Node; -import software.amazon.smithy.model.shapes.ShapeId; -import software.amazon.smithy.model.traits.AbstractTrait; -import software.amazon.smithy.model.traits.AbstractTraitBuilder; -import software.amazon.smithy.model.traits.Trait; -import software.amazon.smithy.rulesengine.logic.bdd.Bdd; -import software.amazon.smithy.utils.SmithyBuilder; -import software.amazon.smithy.utils.SmithyUnstableApi; -import software.amazon.smithy.utils.ToSmithyBuilder; - -/*** - * Defines an endpoint rule-set using a binary decision diagram (BDD) used to resolve the client's transport endpoint. - */ -@SmithyUnstableApi -public final class BddTrait extends AbstractTrait implements ToSmithyBuilder { - public static final ShapeId ID = ShapeId.from("smithy.rules#bdd"); - - private final Bdd bdd; - - private BddTrait(Builder builder) { - super(ID, builder.getSourceLocation()); - bdd = SmithyBuilder.requiredState("bdd", builder.bdd); - } - - public static Builder builder() { - return new Builder(); - } - - public Bdd getBdd() { - return bdd; - } - - @Override - protected Node createNode() { - return bdd.toNode(); - } - - @Override - public Builder toBuilder() { - return builder().sourceLocation(getSourceLocation()).bdd(bdd); - } - - public static final class Provider extends AbstractTrait.Provider { - public Provider() { - super(ID); - } - - @Override - public Trait createTrait(ShapeId target, Node value) { - BddTrait trait = builder().sourceLocation(value).bdd(Bdd.fromNode(value)).build(); - trait.setNodeCache(value); - return trait; - } - } - - public static final class Builder extends AbstractTraitBuilder { - private Bdd bdd; - - private Builder() {} - - public Builder bdd(Bdd bdd) { - this.bdd = bdd; - return this; - } - - @Override - public BddTrait build() { - return new BddTrait(this); - } - } -} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java index e8dc6fe0ada..d6c688b51d9 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java @@ -12,7 +12,7 @@ import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.logic.bdd.Bdd; -import software.amazon.smithy.rulesengine.traits.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; public final class BddTraitValidator extends AbstractValidator { @Override @@ -34,29 +34,40 @@ private void validateService(List events, ServiceShape service, // Validate root reference int rootRef = bdd.getRootRef(); - if (Bdd.isComplemented(rootRef)) { + if (Bdd.isComplemented(rootRef) && rootRef != -1) { events.add(error(service, trait, "Root reference cannot be complemented: " + rootRef)); } - validateReference(events, service, trait, "Root", rootRef, bdd); + validateReference(events, service, trait, "Root", rootRef, bdd, trait); - // Validate node references - int[][] nodes = bdd.getNodes(); - for (int i = 0; i < nodes.length; i++) { + // Validate that condition and result counts match what's in the trait + if (bdd.getConditionCount() != trait.getConditions().size()) { + events.add(error(service, + trait, + String.format("BDD condition count (%d) doesn't match trait conditions (%d)", + bdd.getConditionCount(), + trait.getConditions().size()))); + } + + if (bdd.getResultCount() != trait.getResults().size()) { + events.add(error(service, + trait, + String.format("BDD result count (%d) doesn't match trait results (%d)", + bdd.getResultCount(), + trait.getResults().size()))); + } + + // Validate nodes + int nodeCount = bdd.getNodeCount(); + + for (int i = 0; i < nodeCount; i++) { // Skip terminal node at index 0 if (i == 0) { continue; } - // Guard against malformed nodes array - if (nodes[i] == null || nodes[i].length != 3) { - events.add(error(service, trait, String.format("Node %d is malformed", i))); - continue; - } - - int[] node = nodes[i]; - int varIdx = node[0]; - int highRef = node[1]; - int lowRef = node[2]; + int varIdx = bdd.getVariable(i); + int highRef = bdd.getHigh(i); + int lowRef = bdd.getLow(i); if (varIdx < 0 || varIdx >= bdd.getConditionCount()) { events.add(error(service, @@ -68,8 +79,8 @@ private void validateService(List events, ServiceShape service, bdd.getConditionCount()))); } - validateReference(events, service, trait, String.format("Node %d high", i), highRef, bdd); - validateReference(events, service, trait, String.format("Node %d low", i), lowRef, bdd); + validateReference(events, service, trait, String.format("Node %d high", i), highRef, bdd, trait); + validateReference(events, service, trait, String.format("Node %d low", i), lowRef, bdd, trait); } } @@ -79,13 +90,15 @@ private void validateReference( BddTrait trait, String context, int ref, - Bdd bdd + Bdd bdd, + BddTrait bddTrait ) { if (ref == 0) { events.add(error(service, trait, String.format("%s has invalid reference: 0", context))); } else if (Bdd.isNodeReference(ref)) { int nodeIndex = Math.abs(ref) - 1; - if (nodeIndex >= bdd.getNodes().length) { + int nodeCount = bdd.getNodeCount(); + if (nodeIndex >= nodeCount) { events.add(error(service, trait, String.format( @@ -93,11 +106,11 @@ private void validateReference( context, ref, nodeIndex, - bdd.getNodes().length))); + nodeCount))); } } else if (Bdd.isResultReference(ref)) { int resultIndex = ref - Bdd.RESULT_OFFSET; - if (resultIndex >= bdd.getResults().size()) { + if (resultIndex >= bddTrait.getResults().size()) { events.add(error(service, trait, String.format( @@ -105,7 +118,7 @@ private void validateReference( context, ref, resultIndex, - bdd.getResults().size()))); + bddTrait.getResults().size()))); } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java index cfba5f764ab..99b6f9a983a 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java @@ -20,7 +20,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; -import software.amazon.smithy.rulesengine.traits.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput; @@ -53,7 +53,7 @@ public List validate(Model model) { }); serviceShape.getTrait(BddTrait.class).ifPresent(trait -> { - validateEndpointRuleSet(events, model, serviceShape, trait.getBdd().getParameters(), operationNameMap); + validateEndpointRuleSet(events, model, serviceShape, trait.getParameters(), operationNameMap); }); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java index 99fdae7b137..5a109005413 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java @@ -21,7 +21,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; -import software.amazon.smithy.rulesengine.traits.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; /** @@ -50,7 +50,7 @@ private void visitRuleset(List events, ServiceShape serviceShap private void visitBdd(List events, ServiceShape serviceShape, BddTrait trait) { if (trait != null) { - for (Rule result : trait.getBdd().getResults()) { + for (Rule result : trait.getResults()) { if (result instanceof EndpointRule) { visitEndpoint(events, serviceShape, (EndpointRule) result); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java index 12a5f7d473e..b0e75424ede 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java @@ -15,7 +15,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; -import software.amazon.smithy.rulesengine.traits.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput; @@ -30,7 +30,7 @@ public List validate(Model model) { List events = new ArrayList<>(); for (ServiceShape s : model.getServiceShapesWithTrait(BddTrait.class)) { - validateParams(events, s, s.expectTrait(BddTrait.class).getBdd().getParameters()); + validateParams(events, s, s.expectTrait(BddTrait.class).getParameters()); } for (ServiceShape s : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java index 1022925078a..333b0893d03 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java @@ -11,7 +11,7 @@ import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; -import software.amazon.smithy.rulesengine.traits.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; /** @@ -36,7 +36,7 @@ private void visitRuleset(List events, ServiceShape serviceShap private void visitBdd(List events, ServiceShape serviceShape, BddTrait trait) { if (trait != null) { - visitParams(events, serviceShape, trait.getBdd().getParameters()); + visitParams(events, serviceShape, trait.getParameters()); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java index 8eeaa16dbb2..0e6268fd0fe 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java @@ -25,7 +25,7 @@ import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; -import software.amazon.smithy.rulesengine.traits.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; import software.amazon.smithy.rulesengine.traits.ClientContextParamDefinition; import software.amazon.smithy.rulesengine.traits.ClientContextParamsTrait; import software.amazon.smithy.rulesengine.traits.ContextParamTrait; @@ -54,7 +54,7 @@ public List validate(Model model) { validate(model, topDownIndex, service, errors, epTrait, epTrait.getEndpointRuleSet().getParameters()); } if (bddTrait != null) { - validate(model, topDownIndex, service, errors, bddTrait, bddTrait.getBdd().getParameters()); + validate(model, topDownIndex, service, errors, bddTrait, bddTrait.getParameters()); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java index 157da6199c8..ae62a0322c4 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java @@ -12,8 +12,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; -import software.amazon.smithy.rulesengine.logic.bdd.Bdd; -import software.amazon.smithy.rulesengine.traits.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; @@ -50,10 +49,10 @@ private void validate(ServiceShape serviceShape, EndpointTestsTrait testsTrait, } private void validateBdd(ServiceShape serviceShape, EndpointTestsTrait testsTrait, List events) { - Bdd bdd = serviceShape.expectTrait(BddTrait.class).getBdd(); + BddTrait trait = serviceShape.expectTrait(BddTrait.class); for (EndpointTestCase endpointTestCase : testsTrait.getTestCases()) { try { - TestEvaluator.evaluate(bdd, endpointTestCase); + TestEvaluator.evaluate(trait, endpointTestCase); } catch (RuntimeException e) { events.add(error(serviceShape, endpointTestCase, e.getMessage())); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java index 12d0ef6a3c7..b5da5ff896e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java @@ -17,7 +17,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; -import software.amazon.smithy.rulesengine.traits.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -46,7 +46,7 @@ private void visitRuleset(List events, ServiceShape serviceShap private void visitBdd(List events, ServiceShape serviceShape, BddTrait trait) { if (trait != null) { - for (Rule result : trait.getBdd().getResults()) { + for (Rule result : trait.getResults()) { if (result instanceof EndpointRule) { visitEndpoint(events, serviceShape, (EndpointRule) result); } diff --git a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService index 02900906528..354eade0bdf 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService +++ b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService @@ -4,4 +4,4 @@ software.amazon.smithy.rulesengine.traits.StaticContextParamsTrait$Provider software.amazon.smithy.rulesengine.traits.OperationContextParamsTrait$Provider software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait$Provider software.amazon.smithy.rulesengine.traits.EndpointTestsTrait$Provider -software.amazon.smithy.rulesengine.traits.BddTrait$Provider +software.amazon.smithy.rulesengine.logic.bdd.BddTrait$Provider diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java index c94a9ec05a8..0ad1751d05c 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java @@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -41,8 +40,8 @@ void testNodeReduction() { int reduced = builder.makeNode(0, builder.makeTrue(), builder.makeTrue()); assertEquals(1, reduced); // Should return TRUE directly - // Verify no new node was created - assertEquals(1, builder.getNodes().size()); + // Verify no new node was created (only terminal at index 0) + assertEquals(3, builder.getNodesArray().length); // 3 ints for terminal node } @Test @@ -58,8 +57,8 @@ void testComplementCanonicalization() { assertEquals(-node1, node2); // Should be complement of first node - // Only one actual node should be created - assertEquals(2, builder.getNodes().size()); // terminal + 1 node + // Only one actual node should be created (plus terminal) + assertEquals(6, builder.getNodesArray().length); // 3 ints for terminal + 3 for one node } @Test @@ -71,7 +70,7 @@ void testNodeDeduplication() { int node2 = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); assertEquals(node1, node2); // Should return same reference - assertEquals(2, builder.getNodes().size()); // No duplicate created + assertEquals(6, builder.getNodesArray().length); // No duplicate created } @Test @@ -220,7 +219,7 @@ void testGetVariable() { // Test result references int result = builder.makeResult(0); - assertEquals(3, builder.getVariable(result)); // conditionCount + 0 + assertEquals(-1, builder.getVariable(result)); // results have no variable } @Test @@ -233,11 +232,11 @@ void testReduceSimpleBdd() { int b = builder.makeNode(1, a, builder.makeFalse()); int root = builder.makeNode(0, b, a); - int nodesBefore = builder.getNodes().size(); + int nodesBefore = builder.getNodesArray().length; builder.reduce(root); // Structure should be preserved if already optimal - assertEquals(nodesBefore, builder.getNodes().size()); + assertEquals(nodesBefore, builder.getNodesArray().length); } @Test @@ -248,11 +247,11 @@ void testReduceNoChange() { int right = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); int root = builder.makeNode(0, right, builder.makeFalse()); - int nodesBefore = builder.getNodes().size(); + int nodesBefore = builder.getNodesArray().length; builder.reduce(root); // No change expected - assertEquals(nodesBefore, builder.getNodes().size()); + assertEquals(nodesBefore, builder.getNodesArray().length); } @Test @@ -291,7 +290,7 @@ void testReduceWithComplement() { assertEquals(builder.negate(reduced), reducedComplement); // Verify the structure is preserved - assertTrue(builder.getNodes().size() > 1); + assertTrue(builder.getNodesArray().length > 3); } @Test @@ -327,10 +326,10 @@ void testReduceSharedSubgraphs() { builder.reduce(root); // Shared subgraph should remain shared after reduction - List nodes = builder.getNodes(); + int[] nodes = builder.getNodesArray(); // Verify structure is maintained - at least one node should exist - assertTrue(nodes.size() > 1); + assertTrue(nodes.length > 3); } @Test @@ -350,11 +349,14 @@ void testReducePreservesResultNodes() { boolean foundResult1 = false; // Check the nodes for result references - for (int[] node : builder.getNodes()) { - if (node[1] == result0 || node[2] == result0) { + int[] nodes = builder.getNodesArray(); + int nodeCount = nodes.length / 3; + for (int i = 0; i < nodeCount; i++) { + int baseIdx = i * 3; + if (nodes[baseIdx + 1] == result0 || nodes[baseIdx + 2] == result0) { foundResult0 = true; } - if (node[1] == result1 || node[2] == result1) { + if (nodes[baseIdx + 1] == result1 || nodes[baseIdx + 2] == result1) { foundResult1 = true; } } @@ -372,9 +374,9 @@ void testReduceActuallyReduces() { int middle = builder.makeNode(1, bottom, builder.makeFalse()); int root = builder.makeNode(0, middle, bottom); - int beforeSize = builder.getNodes().size(); + int beforeSize = builder.getNodesArray().length; builder.reduce(root); - int afterSize = builder.getNodes().size(); + int afterSize = builder.getNodesArray().length; // In this case, no reduction should occur since makeNode already optimized assertEquals(beforeSize, afterSize); @@ -426,7 +428,7 @@ void testCofactorRecursive() { assertTrue(cofactorTrue != cofactorFalse); // Verify structure is simplified - assertTrue(builder.getNodes().size() > 1); + assertTrue(builder.getNodesArray().length > 3); } @Test @@ -520,7 +522,7 @@ void testReset() { builder.reset(); // Verify state is cleared - assertEquals(1, builder.getNodes().size()); // Only terminal + assertEquals(3, builder.getNodesArray().length); // Only terminal (3 ints) assertThrows(IllegalStateException.class, () -> builder.makeResult(0)); // Can use builder again diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java index f34bc2a97e9..d9d385166a8 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java @@ -5,12 +5,11 @@ package software.amazon.smithy.rulesengine.logic.bdd; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; @@ -23,6 +22,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; import software.amazon.smithy.rulesengine.logic.TestHelpers; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionData; class BddCompilerTest { @@ -65,10 +65,9 @@ void testCompileSimpleEndpointRule() { Bdd bdd = compiler.compile(); assertNotNull(bdd); - assertEquals(1, bdd.getConditions().size()); - // Results include: endpoint when condition true, no match when false, - // and possibly a no match for the overall fallthrough - assertTrue(bdd.getResults().size() >= 2); + assertEquals(1, bdd.getConditionCount()); + // Results include: endpoint when condition true, no match when false + assertTrue(bdd.getResultCount() >= 2); assertTrue(bdd.getRootRef() > 0); } @@ -89,9 +88,9 @@ void testCompileErrorRule() { Bdd bdd = compiler.compile(); - assertEquals(1, bdd.getConditions().size()); + assertEquals(1, bdd.getConditionCount()); // Similar to endpoint rule - assertTrue(bdd.getResults().size() >= 2); + assertTrue(bdd.getResultCount() >= 2); } @Test @@ -111,8 +110,8 @@ void testCompileTreeRule() { Bdd bdd = compiler.compile(); - assertEquals(2, bdd.getConditions().size()); - assertTrue(bdd.getNodes().length > 2); // Should have multiple nodes + assertEquals(2, bdd.getConditionCount()); + assertTrue(bdd.getNodeCount() > 2); // Should have multiple nodes } @Test @@ -127,17 +126,35 @@ void testCompileWithCustomOrdering() { Cfg cfg = Cfg.from(ruleSet); + // Get the actual conditions from the CFG after SSA transform + ConditionData conditionData = cfg.getConditionData(); + List cfgConditions = Arrays.asList(conditionData.getConditions()); + + // Find the conditions that correspond to A and B + Condition condA = null; + Condition condB = null; + for (Condition c : cfgConditions) { + String condStr = c.toString(); + if (condStr.contains("isSet(A)")) { + condA = c; + } else if (condStr.contains("isSet(B)")) { + condB = c; + } + } + + assertNotNull(condA, "Could not find condition for A"); + assertNotNull(condB, "Could not find condition for B"); + // Use fixed ordering (B before A) - Condition condA = rule.getConditions().get(0); - Condition condB = rule.getConditions().get(1); ConditionOrderingStrategy customOrdering = ConditionOrderingStrategy.fixed(Arrays.asList(condB, condA)); BddCompiler compiler = new BddCompiler(cfg, customOrdering, new BddBuilder()); Bdd bdd = compiler.compile(); - // Verify ordering was applied - assertEquals(condB, bdd.getConditions().get(0)); - assertEquals(condA, bdd.getConditions().get(1)); + // Verify ordering was applied by checking the compiled BDD + // Since we don't have access to conditions from Bdd anymore, we just verify compilation succeeded + assertEquals(2, bdd.getConditionCount()); + assertNotNull(bdd); } @Test @@ -149,11 +166,11 @@ void testCompileEmptyRuleSet() { BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); Bdd bdd = compiler.compile(); - assertEquals(0, bdd.getConditions().size()); + assertEquals(0, bdd.getConditionCount()); // Even with no rules, there's still a result (no match) - assertFalse(bdd.getResults().isEmpty()); + assertTrue(bdd.getResultCount() > 0); // Should have at least terminal node - assertNotEquals(0, bdd.getNodes().length); + assertTrue(bdd.getNodeCount() > 0); } @Test @@ -185,6 +202,6 @@ void testCompileSameResultMultiplePaths() { // The BDD compiler might create separate result nodes even for same endpoint // depending on how the CFG is structured - assertEquals(3, bdd.getResults().size()); + assertEquals(3, bdd.getResultCount()); } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluatorTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluatorTest.java deleted file mode 100644 index dc9af6c90f4..00000000000 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEvaluatorTest.java +++ /dev/null @@ -1,215 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic.bdd; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -import org.junit.jupiter.api.Test; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; -import software.amazon.smithy.rulesengine.logic.TestHelpers; -import software.amazon.smithy.utils.ListUtils; - -class BddEvaluatorTest { - - private static final Parameters EMPTY = Parameters.builder().build(); - - @Test - void testEvaluateTerminalTrue() { - // BDD with just TRUE terminal - int[][] nodes = new int[][] {{-1, 1, -1}}; - Bdd bdd = new Bdd(EMPTY, ListUtils.of(), ListUtils.of(), nodes, 1); - - BddEvaluator evaluator = BddEvaluator.from(bdd); - int result = evaluator.evaluate(idx -> true); - - assertEquals(-1, result); // TRUE terminal returns -1 (TRUE isn't valid in our MTBDD) - } - - @Test - void testEvaluateTerminalFalse() { - // BDD with just FALSE terminal - int[][] nodes = new int[][] {{-1, 1, -1}}; - Bdd bdd = new Bdd(EMPTY, ListUtils.of(), ListUtils.of(), nodes, -1); - - BddEvaluator evaluator = BddEvaluator.from(bdd); - int result = evaluator.evaluate(idx -> true); - - assertEquals(-1, result); // FALSE terminal returns -1 (same as TRUE; FALSE isn't valid in our MTBDD). - } - - @Test - void testEvaluateSingleConditionTrue() { - Condition cond = Condition.builder().fn(TestHelpers.isSet("param")).build(); - Rule rule = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); - - // BDD: if condition then result1 else no-match - // With new encoding: result references are encoded as RESULT_OFFSET + resultIndex - int result1Ref = Bdd.RESULT_OFFSET + 1; - - int[][] nodes = new int[][] { - {-1, 1, -1}, // 0: terminal - {0, result1Ref, -1} // 1: condition node (high=result1, low=FALSE) - }; - Bdd bdd = new Bdd(EMPTY, ListUtils.of(cond), ListUtils.of(null, rule), nodes, 2); - - BddEvaluator evaluator = BddEvaluator.from(bdd); - - // When condition is true, should return result 1 - assertEquals(1, evaluator.evaluate(idx -> true)); - - // When condition is false, should return -1 (no match) - assertEquals(-1, evaluator.evaluate(idx -> false)); - } - - @Test - void testEvaluateComplementedNode() { - Condition cond1 = Condition.builder().fn(TestHelpers.isSet("param1")).build(); - Condition cond2 = Condition.builder().fn(TestHelpers.isSet("param2")).build(); - Rule rule = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); - - // BDD with a complemented reference to an internal node - // We want: if cond1 then NOT(cond2) else false - // Which means: if cond1 && !cond2 then result1 else no-match - int result1Ref = Bdd.RESULT_OFFSET + 1; - - int[][] nodes = new int[][] { - {-1, 1, -1}, // 0: terminal - {0, -3, -1}, // 1: cond1 node (high=-3 (complement of ref 3 = node 2), low=FALSE) - {1, -1, result1Ref} // 2: cond2 node (high=FALSE, low=result1) - }; - // Root is 2 (reference to node at index 1) - Bdd bdd = new Bdd(EMPTY, ListUtils.of(cond1, cond2), ListUtils.of(null, rule), nodes, 2); - - BddEvaluator evaluator = BddEvaluator.from(bdd); - - // When cond1=true, we follow high branch to -3 (complement of node 2) - // The complement flips the node's branch selection - // If cond2=true, with complement we take the "false" branch (which is low) -> result1 - // If cond2=false, with complement we take the "true" branch (which is high) -> FALSE - ConditionEvaluator bothTrue = idx -> true; - assertEquals(1, evaluator.evaluate(bothTrue)); // cond1=true, cond2=true -> result1 - - ConditionEvaluator firstTrueSecondFalse = idx -> idx == 0; - assertEquals(-1, evaluator.evaluate(firstTrueSecondFalse)); // cond1=true, cond2=false -> FALSE - - // When cond1=false, we follow low branch to FALSE - ConditionEvaluator firstFalse = idx -> false; - assertEquals(-1, evaluator.evaluate(firstFalse)); // cond1=false -> FALSE - } - - @Test - void testEvaluateMultipleConditions() { - Condition cond1 = Condition.builder().fn(TestHelpers.isSet("param1")).build(); - Condition cond2 = Condition.builder().fn(TestHelpers.isSet("param2")).build(); - Rule rule1 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://a.com")); - Rule rule2 = ErrorRule.builder().error("hi"); - - // BDD: if cond1 then (if cond2 then result1 else result2) else no-match - // Result references encoded with RESULT_OFFSET - int result1Ref = Bdd.RESULT_OFFSET + 1; - int result2Ref = Bdd.RESULT_OFFSET + 2; - - int[][] nodes = new int[][] { - {-1, 1, -1}, // 0: terminal - {0, 3, -1}, // 1: cond1 node (high=cond2 node, low=FALSE) - {1, result1Ref, result2Ref} // 2: cond2 node (high=result1, low=result2) - }; - Bdd bdd = new Bdd(EMPTY, ListUtils.of(cond1, cond2), ListUtils.of(null, rule1, rule2), nodes, 2); - - BddEvaluator evaluator = BddEvaluator.from(bdd); - ConditionEvaluator condEval = idx -> idx == 0; // only first condition is true - - int result = evaluator.evaluate(condEval); - assertEquals(2, result); // Should get result2 since cond2 is false - } - - @Test - void testEvaluateNoMatchResult() { - Condition cond = Condition.builder().fn(TestHelpers.isSet("param")).build(); - - // BDD with explicit no-match result (index 0) - // Result0 reference encoded with RESULT_OFFSET - int result0Ref = Bdd.RESULT_OFFSET; - - int[][] nodes = new int[][] { - {-1, 1, -1}, // 0: terminal - {0, -1, result0Ref} // 1: condition node - }; - Bdd bdd = new Bdd(EMPTY, ListUtils.of(cond), ListUtils.of((Rule) null), nodes, 2); - - BddEvaluator evaluator = BddEvaluator.from(bdd); - int result = evaluator.evaluate(idx -> false); - - assertEquals(-1, result); // Result index 0 is treated as no-match - } - - @Test - void testEvaluateWithLargeResultIndex() { - Condition cond = Condition.builder().fn(TestHelpers.isSet("param")).build(); - Rule rule = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); - - // Test with a larger result index to ensure offset works correctly - int result999Ref = Bdd.RESULT_OFFSET + 999; - - int[][] nodes = new int[][] { - {-1, 1, -1}, // 0: terminal - {0, result999Ref, -1} // 1: condition node - }; - - // Create a results list with 1000 entries (0-999) - Rule[] results = new Rule[1000]; - results[999] = rule; - - Bdd bdd = new Bdd(EMPTY, ListUtils.of(cond), ListUtils.of(results), nodes, 2); - - BddEvaluator evaluator = BddEvaluator.from(bdd); - - // When condition is true, should return result 999 - assertEquals(999, evaluator.evaluate(idx -> true)); - } - - @Test - void testEvaluateComplexBddWithMixedReferences() { - Condition cond1 = Condition.builder().fn(TestHelpers.isSet("param1")).build(); - Condition cond2 = Condition.builder().fn(TestHelpers.isSet("param2")).build(); - Condition cond3 = Condition.builder().fn(TestHelpers.isSet("param3")).build(); - Rule rule1 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://a.com")); - Rule rule2 = ErrorRule.builder().error("error"); - - // Complex BDD with multiple conditions, complement edges, and results - int result1Ref = Bdd.RESULT_OFFSET + 1; - int result2Ref = Bdd.RESULT_OFFSET + 2; - - int[][] nodes = new int[][] { - {-1, 1, -1}, // 0: terminal - {0, 3, 4}, // 1: cond1 node - {1, result1Ref, -1}, // 2: cond2 node - {2, result2Ref, -5} // 3: cond3 node (low has complement ref) - }; - - Bdd bdd = new Bdd(EMPTY, - ListUtils.of(cond1, cond2, cond3), - ListUtils.of(null, rule1, rule2), - nodes, - 2); - - BddEvaluator evaluator = BddEvaluator.from(bdd); - - // Test various paths through the BDD - ConditionEvaluator allTrue = idx -> true; - assertEquals(1, evaluator.evaluate(allTrue)); // cond1=T -> cond2=T -> result1 - - ConditionEvaluator firstTrueOnly = idx -> idx == 0; - assertEquals(-1, evaluator.evaluate(firstTrueOnly)); // cond1=T -> cond2=F -> FALSE - - ConditionEvaluator firstFalseThirdTrue = idx -> idx == 2; - assertEquals(2, evaluator.evaluate(firstFalseThirdTrue)); // cond1=F -> cond3=T -> result2 - } -} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java index 8493ae0ea25..e68b63a2d72 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java @@ -6,15 +6,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.ArrayList; -import java.util.List; import org.junit.jupiter.api.Test; -import software.amazon.smithy.model.node.Node; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; @@ -22,54 +17,47 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.logic.TestHelpers; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; -import software.amazon.smithy.utils.ListUtils; class BddTest { @Test void testConstructorValidation() { - Parameters params = Parameters.builder().build(); - int[][] nodes = new int[][] {{-1, 1, -1}}; - // Should reject complemented root (except -1 which is FALSE terminal) - assertThrows(IllegalArgumentException.class, () -> new Bdd(params, ListUtils.of(), ListUtils.of(), nodes, -2)); + assertThrows(IllegalArgumentException.class, + () -> new Bdd(-2, 0, 0, 1, consumer -> consumer.accept(-1, 1, -1))); // Should accept positive root - Bdd bdd = new Bdd(params, ListUtils.of(), ListUtils.of(), nodes, 1); + Bdd bdd = new Bdd(1, 0, 0, 1, consumer -> consumer.accept(-1, 1, -1)); assertEquals(1, bdd.getRootRef()); // Should accept FALSE terminal as root - Bdd bdd2 = new Bdd(params, ListUtils.of(), ListUtils.of(), nodes, -1); + Bdd bdd2 = new Bdd(-1, 0, 0, 1, consumer -> consumer.accept(-1, 1, -1)); assertEquals(-1, bdd2.getRootRef()); } @Test void testBasicAccessors() { - Parameters params = Parameters.builder() - .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) - .build(); - Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - Rule rule = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); - int[][] nodes = new int[][] { - {-1, 1, -1}, - {0, 3, -1}, - {1, 1, -1} - }; + Bdd bdd = new Bdd(2, 2, 1, 3, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 3, -1); // node 1: var 0, high 3, low -1 + consumer.accept(1, 1, -1); // node 2: var 1, high 1, low -1 + }); + + assertEquals(2, bdd.getConditionCount()); + assertEquals(1, bdd.getResultCount()); + assertEquals(3, bdd.getNodeCount()); + assertEquals(2, bdd.getRootRef()); - Bdd bdd = new Bdd(params, ListUtils.of(cond), ListUtils.of(rule), nodes, 2); + // Test node accessors + assertEquals(-1, bdd.getVariable(0)); + assertEquals(1, bdd.getHigh(0)); + assertEquals(-1, bdd.getLow(0)); - assertEquals(params, bdd.getParameters()); - assertEquals(1, bdd.getConditions().size()); - assertEquals(cond, bdd.getConditions().get(0)); - assertEquals(1, bdd.getConditionCount()); - assertEquals(1, bdd.getResults().size()); - assertEquals(rule, bdd.getResults().get(0)); - assertEquals(3, bdd.getNodes().length); - assertEquals(2, bdd.getRootRef()); + assertEquals(0, bdd.getVariable(1)); + assertEquals(3, bdd.getHigh(1)); + assertEquals(-1, bdd.getLow(1)); } @Test @@ -83,11 +71,12 @@ void testFromRuleSet() { .endpoint(TestHelpers.endpoint("https://example.com"))) .build(); - Bdd bdd = Bdd.from(ruleSet); + Cfg cfg = Cfg.from(ruleSet); + Bdd bdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); assertTrue(bdd.getConditionCount() > 0); - assertFalse(bdd.getResults().isEmpty()); - assertTrue(bdd.getNodes().length > 1); // At least terminal + one node + assertTrue(bdd.getResultCount() > 0); + assertTrue(bdd.getNodeCount() > 1); // At least terminal + one node } @Test @@ -98,34 +87,10 @@ void testFromCfg() { .build(); Cfg cfg = Cfg.from(ruleSet); - Bdd bdd = Bdd.from(cfg); + Bdd bdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); assertEquals(0, bdd.getConditionCount()); // No conditions - assertFalse(bdd.getResults().isEmpty()); - } - - @Test - void testEquals() { - Bdd bdd1 = createSimpleBdd(); - Bdd bdd2 = createSimpleBdd(); - - assertEquals(bdd1, bdd2); - assertEquals(bdd1.hashCode(), bdd2.hashCode()); - - // Different root - use a different value than what createSimpleBdd returns - Bdd bdd3 = new Bdd(bdd1.getParameters(), bdd1.getConditions(), bdd1.getResults(), bdd1.getNodes(), -1); - assertNotEquals(bdd1, bdd3); - - // Different conditions - Condition newCond = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); - Bdd bdd4 = new Bdd(bdd1 - .getParameters(), ListUtils.of(newCond), bdd1.getResults(), bdd1.getNodes(), bdd1.getRootRef()); - assertNotEquals(bdd1, bdd4); - - // Different nodes - int[][] newNodes = new int[][] {{-1, 1, -1}, {0, -1, Bdd.RESULT_OFFSET + 1}}; - Bdd bdd5 = new Bdd(bdd1.getParameters(), bdd1.getConditions(), bdd1.getResults(), newNodes, bdd1.getRootRef()); - assertNotEquals(bdd1, bdd5); + assertTrue(bdd.getResultCount() > 0); } @Test @@ -133,103 +98,67 @@ void testToString() { Bdd bdd = createSimpleBdd(); String str = bdd.toString(); - assertTrue(str.contains("Bdd{")); - assertTrue(str.contains("conditions")); - assertTrue(str.contains("results")); + assertTrue(str.contains("Bdd {")); + assertTrue(str.contains("conditions:")); + assertTrue(str.contains("results:")); assertTrue(str.contains("root:")); assertTrue(str.contains("nodes")); } - @Test - void testToNodeAndFromNode() { - Bdd original = createSimpleBdd(); - - Node node = original.toNode(); - assertTrue(node.isObjectNode()); - assertTrue(node.expectObjectNode().containsMember("conditions")); - assertTrue(node.expectObjectNode().containsMember("results")); - assertTrue(node.expectObjectNode().containsMember("nodes")); - assertTrue(node.expectObjectNode().containsMember("root")); - - // Original has 2 results: NoMatchRule at 0, endpoint at 1 - // Serialized should only have 1 result (the endpoint) - int serializedResultCount = node.expectObjectNode() - .expectArrayMember("results") - .getElements() - .size(); - assertEquals(1, serializedResultCount); - - Bdd restored = Bdd.fromNode(node); - assertEquals(original.getRootRef(), restored.getRootRef()); - assertEquals(original.getConditionCount(), restored.getConditionCount()); - assertEquals(original.getResults().size(), restored.getResults().size()); - assertEquals(original.getNodes().length, restored.getNodes().length); - - // Verify NoMatchRule was restored at index 0 - assertInstanceOf(NoMatchRule.class, restored.getResults().get(0)); - } - @Test void testToStringWithDifferentNodeTypes() { - Parameters params = Parameters.builder().build(); - - // Two conditions - Condition cond1 = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - Condition cond2 = Condition.builder().fn(TestHelpers.booleanEquals("UseFips", true)).build(); - - // Two endpoint results - Rule endpoint1 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); - Rule endpoint2 = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example-fips.com")); - - // NoMatchRule MUST be at index 0 - List results = new ArrayList<>(); - results.add(NoMatchRule.INSTANCE); // Index 0 - always NoMatch - results.add(endpoint1); // Index 1 - results.add(endpoint2); // Index 2 - // BDD structure referencing the correct indices - int[][] nodes = new int[][] { - {-1, 1, -1}, // 0: terminal node - {0, 2, -1}, // 1: if Region is set, go to node 2, else FALSE - {1, Bdd.RESULT_OFFSET + 2, Bdd.RESULT_OFFSET + 1} // 2: if UseFips, return result 2, else result 1 - }; + Bdd bdd = new Bdd(2, 2, 3, 3, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 2, -1); // node 1: if Region is set, go to node 2, else FALSE + consumer.accept(1, Bdd.RESULT_OFFSET + 2, Bdd.RESULT_OFFSET + 1); // node 2: if UseFips, return result 2, else result 1 + }); - Bdd bdd = new Bdd(params, ListUtils.of(cond1, cond2), results, nodes, 1); String str = bdd.toString(); - assertTrue(str.contains("Endpoint:")); + assertTrue(str.contains("conditions: 2")); + assertTrue(str.contains("results: 3")); assertTrue(str.contains("C0")); assertTrue(str.contains("C1")); - assertTrue(str.contains("R0")); assertTrue(str.contains("R1")); assertTrue(str.contains("R2")); - assertTrue(str.contains("NoMatchRule")); // R0 will show as NoMatchRule + } - // Test serialization doesn't include NoMatchRule - Node serialized = bdd.toNode(); - assertEquals(2, - serialized.expectObjectNode() - .expectArrayMember("results") - .getElements() - .size()); // Only the two endpoints, not NoMatch + @Test + void testReferenceHelperMethods() { + // Test isNodeReference + assertTrue(Bdd.isNodeReference(2)); + assertTrue(Bdd.isNodeReference(-2)); + assertFalse(Bdd.isNodeReference(0)); + assertFalse(Bdd.isNodeReference(1)); + assertFalse(Bdd.isNodeReference(-1)); + assertFalse(Bdd.isNodeReference(Bdd.RESULT_OFFSET)); + + // Test isResultReference + assertTrue(Bdd.isResultReference(Bdd.RESULT_OFFSET)); + assertTrue(Bdd.isResultReference(Bdd.RESULT_OFFSET + 1)); + assertFalse(Bdd.isResultReference(1)); + assertFalse(Bdd.isResultReference(-1)); + + // Test isTerminal + assertTrue(Bdd.isTerminal(1)); + assertTrue(Bdd.isTerminal(-1)); + assertFalse(Bdd.isTerminal(2)); + assertFalse(Bdd.isTerminal(Bdd.RESULT_OFFSET)); + + // Test isComplemented + assertTrue(Bdd.isComplemented(-2)); + assertTrue(Bdd.isComplemented(-3)); + assertFalse(Bdd.isComplemented(-1)); // FALSE terminal is not considered complemented + assertFalse(Bdd.isComplemented(1)); + assertFalse(Bdd.isComplemented(2)); } private Bdd createSimpleBdd() { - Parameters params = Parameters.builder().build(); - Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); - - // NoMatchRule MUST be at index 0 - List results = new ArrayList<>(); - results.add(NoMatchRule.INSTANCE); // Index 0 - always NoMatch - results.add(endpoint); // Index 1 - the actual endpoint - - int[][] nodes = new int[][] { - {-1, 1, -1}, // 0: terminal - {0, Bdd.RESULT_OFFSET + 1, -1} // 1: if cond true, return result 1 (endpoint), else FALSE - }; - - return new Bdd(params, ListUtils.of(cond), results, nodes, 1); + return new Bdd(2, 1, 2, 2, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, Bdd.RESULT_OFFSET + 1, -1); // node 1: if cond true, return result 1, else FALSE + }); } // Used to regenerate BDD test cases for errorfiles diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java new file mode 100644 index 00000000000..443898921f4 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java @@ -0,0 +1,78 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.utils.ListUtils; + +public class BddTraitTest { + @Test + void testBddTraitSerialization() { + // Create a BddTrait with full context + Parameters params = Parameters.builder().build(); + Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); + Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); + results.add(endpoint); + + Bdd bdd = createSimpleBdd(); + + BddTrait original = BddTrait.builder() + .parameters(params) + .conditions(ListUtils.of(cond)) + .results(results) + .bdd(bdd) + .build(); + + // Serialize to Node + Node node = original.toNode(); + assertTrue(node.isObjectNode()); + assertTrue(node.expectObjectNode().containsMember("parameters")); + assertTrue(node.expectObjectNode().containsMember("conditions")); + assertTrue(node.expectObjectNode().containsMember("results")); + + // Serialized should only have 1 result (the endpoint, not NoMatch) + int serializedResultCount = node.expectObjectNode() + .expectArrayMember("results") + .getElements() + .size(); + assertEquals(1, serializedResultCount); + + // Deserialize from Node + BddTrait restored = BddTrait.fromNode(node); + + assertEquals(original.getParameters(), restored.getParameters()); + assertEquals(original.getConditions().size(), restored.getConditions().size()); + assertEquals(original.getResults().size(), restored.getResults().size()); + assertEquals(original.getBdd().getRootRef(), restored.getBdd().getRootRef()); + assertEquals(original.getBdd().getConditionCount(), restored.getBdd().getConditionCount()); + assertEquals(original.getBdd().getResultCount(), restored.getBdd().getResultCount()); + + // Verify NoMatchRule was restored at index 0 + assertInstanceOf(NoMatchRule.class, restored.getResults().get(0)); + } + + private Bdd createSimpleBdd() { + return new Bdd(2, 1, 2, 2, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, Bdd.RESULT_OFFSET + 1, -1); // node 1: if cond true, return result 1, else FALSE + }); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java index 6c2cd3d5669..80464db0c42 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java @@ -4,65 +4,41 @@ */ package software.amazon.smithy.rulesengine.logic.bdd; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotSame; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; import org.junit.jupiter.api.Test; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.logic.TestHelpers; class NodeReversalTest { @Test void testSingleNodeBdd() { // BDD with just terminal node - int[][] nodes = new int[][] { - {-1, 1, -1} // terminal - }; - - Bdd original = new Bdd( - Parameters.builder().build(), - new ArrayList<>(), - new ArrayList<>(), - nodes, - 1 // root is TRUE - ); + Bdd original = new Bdd(1, 0, 0, 1, consumer -> { + consumer.accept(-1, 1, -1); // terminal + }); NodeReversal reversal = new NodeReversal(); Bdd reversed = reversal.apply(original); - // Should be unchanged (2 nodes returns as-is). - assertEquals(1, reversed.getNodes().length); + // Should be unchanged (only 1 node, reversal returns as-is for <= 2 nodes). + assertEquals(1, reversed.getNodeCount()); assertEquals(1, reversed.getRootRef()); - assertArrayEquals(new int[] {-1, 1, -1}, reversed.getNodes()[0]); + + // Check terminal node + assertEquals(-1, reversed.getVariable(0)); + assertEquals(1, reversed.getHigh(0)); + assertEquals(-1, reversed.getLow(0)); } @Test void testComplementEdges() { // BDD with complement edges - int[][] nodes = new int[][] { - {-1, 1, -1}, // terminal - {0, 3, -2}, // condition 0, high to node 2, low to complement of node 1 - {1, 1, -1} // condition 1 - }; - - Bdd original = new Bdd( - Parameters.builder().build(), - Arrays.asList( - Condition.builder().fn(TestHelpers.isSet("Region")).build(), - Condition.builder().fn(TestHelpers.isSet("Bucket")).build()), - new ArrayList<>(), - nodes, - 2 // root points to node 1 - ); + Bdd original = new Bdd(2, 2, 0, 3, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 3, -2); // node 1: condition 0, high to node 2, low to complement of node 1 + consumer.accept(1, 1, -1); // node 2: condition 1 + }); NodeReversal reversal = new NodeReversal(); Bdd reversed = reversal.apply(original); @@ -72,78 +48,58 @@ void testComplementEdges() { assertEquals(3, reversed.getRootRef()); // Check complement edge is properly remapped - int[] reversedNode2 = reversed.getNodes()[2]; - assertEquals(0, reversedNode2[0]); // condition index unchanged - assertEquals(2, reversedNode2[1]); // high ref 3 -> 2 - assertEquals(-3, reversedNode2[2]); // complement low ref -2 -> -3 + // Original node 1 is now at index 2 + assertEquals(0, reversed.getVariable(2)); // condition index unchanged + assertEquals(2, reversed.getHigh(2)); // high ref 3 -> 2 + assertEquals(-3, reversed.getLow(2)); // complement low ref -2 -> -3 } @Test void testResultNodes() { // BDD with result terminals - int[][] nodes = new int[][] { - {-1, 1, -1}, // terminal - {0, Bdd.RESULT_OFFSET + 1, Bdd.RESULT_OFFSET}, // condition 0 at index 1 - {2, 1, -1}, // result 0 at index 2 - {3, 1, -1} // result 1 at index 3 - }; - - List results = Arrays.asList( - EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")), - ErrorRule.builder().error("Error occurred")); - - Bdd original = new Bdd( - Parameters.builder().build(), - Arrays.asList( - Condition.builder().fn(TestHelpers.isSet("Region")).build(), - Condition.builder().fn(TestHelpers.isSet("Bucket")).build()), - results, - nodes, - 2 // root points to node 1 - ); + Bdd original = new Bdd(2, 4, 2, 4, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, Bdd.RESULT_OFFSET + 1, Bdd.RESULT_OFFSET); // node 1: condition 0 + consumer.accept(2, 1, -1); // node 2: result 0 + consumer.accept(3, 1, -1); // node 3: result 1 + }); NodeReversal reversal = new NodeReversal(); Bdd reversed = reversal.apply(original); - assertEquals(4, reversed.getNodes().length); + assertEquals(4, reversed.getNodeCount()); assertEquals(4, reversed.getRootRef()); // root was ref 2, now ref 4 // Terminal stays at 0 - assertArrayEquals(new int[] {-1, 1, -1}, reversed.getNodes()[0]); + assertEquals(-1, reversed.getVariable(0)); + assertEquals(1, reversed.getHigh(0)); + assertEquals(-1, reversed.getLow(0)); // Original node 3 now at index 1 - assertArrayEquals(new int[] {3, 1, -1}, reversed.getNodes()[1]); + assertEquals(3, reversed.getVariable(1)); + assertEquals(1, reversed.getHigh(1)); + assertEquals(-1, reversed.getLow(1)); // Original node 2 stays at index 2 - assertArrayEquals(new int[] {2, 1, -1}, reversed.getNodes()[2]); + assertEquals(2, reversed.getVariable(2)); + assertEquals(1, reversed.getHigh(2)); + assertEquals(-1, reversed.getLow(2)); // Original node 1 now at index 3 - int[] conditionNode = reversed.getNodes()[3]; - assertEquals(0, conditionNode[0]); // condition index unchanged - assertEquals(Bdd.RESULT_OFFSET + 1, conditionNode[1]); // result references unchanged - assertEquals(Bdd.RESULT_OFFSET, conditionNode[2]); // result references unchanged + assertEquals(0, reversed.getVariable(3)); // condition index unchanged + assertEquals(Bdd.RESULT_OFFSET + 1, reversed.getHigh(3)); // result references unchanged + assertEquals(Bdd.RESULT_OFFSET, reversed.getLow(3)); // result references unchanged } @Test void testFourNodeExample() { // Simple 4-node example to verify reference mapping - int[][] nodes = new int[][] { - {-1, 1, -1}, // 0: terminal - {0, 3, 4}, // 1: points to nodes 2 and 3 - {1, 1, -1}, // 2: - {2, 1, -1} // 3: - }; - - Bdd original = new Bdd( - Parameters.builder().build(), - Arrays.asList( - Condition.builder().fn(TestHelpers.isSet("A")).build(), - Condition.builder().fn(TestHelpers.isSet("B")).build(), - Condition.builder().fn(TestHelpers.isSet("C")).build()), - new ArrayList<>(), - nodes, - 2 // root points to node 1 - ); + Bdd original = new Bdd(2, 3, 0, 4, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 3, 4); // node 1: points to nodes 2 and 3 + consumer.accept(1, 1, -1); // node 2: + consumer.accept(2, 1, -1); // node 3: + }); NodeReversal reversal = new NodeReversal(); Bdd reversed = reversal.apply(original); @@ -152,45 +108,73 @@ void testFourNodeExample() { // Ref mapping: 2->4, 3->3, 4->2 assertEquals(4, reversed.getRootRef()); // root ref 2 -> 4 - int[] nodeAtIndex3 = reversed.getNodes()[3]; // original node 1 - assertEquals(0, nodeAtIndex3[0]); - assertEquals(3, nodeAtIndex3[1]); // ref 3 stays 3 - assertEquals(2, nodeAtIndex3[2]); // ref 4 -> 2 + // Check node at index 3 (originally node 1) + assertEquals(0, reversed.getVariable(3)); // original node 1's variable + assertEquals(3, reversed.getHigh(3)); // ref 3 stays 3 + assertEquals(2, reversed.getLow(3)); // ref 4 -> 2 } @Test void testImmutability() { // Ensure original BDD is not modified - int[][] originalNodes = new int[][] { - {-1, 1, -1}, - {0, 3, -1}, - {1, 1, -1} - }; - - Bdd original = new Bdd( - Parameters.builder().build(), - Arrays.asList( - Condition.builder().fn(TestHelpers.isSet("Region")).build(), - Condition.builder().fn(TestHelpers.isSet("Bucket")).build()), - new ArrayList<>(), - originalNodes, - 2); - - // Clone original node arrays for comparison - int[][] originalNodesCopy = new int[original.getNodes().length][]; - for (int i = 0; i < original.getNodes().length; i++) { - originalNodesCopy[i] = original.getNodes()[i].clone(); + Bdd original = new Bdd(2, 2, 0, 3, consumer -> { + consumer.accept(-1, 1, -1); // node 0 + consumer.accept(0, 3, -1); // node 1 + consumer.accept(1, 1, -1); // node 2 + }); + + // Get original values for comparison + int originalNodeCount = original.getNodeCount(); + int originalRootRef = original.getRootRef(); + + // Store original node values + int[] originalNodeValues = new int[original.getNodeCount() * 3]; + for (int i = 0; i < original.getNodeCount(); i++) { + originalNodeValues[i * 3] = original.getVariable(i); + originalNodeValues[i * 3 + 1] = original.getHigh(i); + originalNodeValues[i * 3 + 2] = original.getLow(i); } NodeReversal reversal = new NodeReversal(); Bdd reversed = reversal.apply(original); // Verify original is unchanged - assertEquals(originalNodesCopy.length, original.getNodes().length); - for (int i = 0; i < originalNodesCopy.length; i++) { - assertArrayEquals(originalNodesCopy[i], original.getNodes()[i]); + assertEquals(originalNodeCount, original.getNodeCount()); + assertEquals(originalRootRef, original.getRootRef()); + + // Check original node values haven't changed + for (int i = 0; i < original.getNodeCount(); i++) { + assertEquals(originalNodeValues[i * 3], original.getVariable(i)); + assertEquals(originalNodeValues[i * 3 + 1], original.getHigh(i)); + assertEquals(originalNodeValues[i * 3 + 2], original.getLow(i)); } - assertNotSame(original.getNodes(), reversed.getNodes()); + // Ensure reversed is a different object + assertNotSame(original, reversed); + } + + @Test + void testTwoNodeBdd() { + // Test edge case with exactly 2 nodes + Bdd original = new Bdd(2, 1, 0, 2, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 1, -1); // node 1: simple condition + }); + + NodeReversal reversal = new NodeReversal(); + Bdd reversed = reversal.apply(original); + + // Should be unchanged (reversal returns as-is for <= 2 nodes) + assertEquals(2, reversed.getNodeCount()); + assertEquals(2, reversed.getRootRef()); + + // Check nodes are unchanged + assertEquals(-1, reversed.getVariable(0)); + assertEquals(1, reversed.getHigh(0)); + assertEquals(-1, reversed.getLow(0)); + + assertEquals(0, reversed.getVariable(1)); + assertEquals(1, reversed.getHigh(1)); + assertEquals(-1, reversed.getLow(1)); } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java index b213d1f59d4..47f87d05439 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java @@ -7,6 +7,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; @@ -21,6 +23,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.logic.TestHelpers; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionData; // Does some basic checks, but doesn't get too specific so we can easily change the sifting algorithm. class SiftingOptimizationTest { @@ -45,17 +48,17 @@ void testBasicOptimization() { .build(); Cfg cfg = Cfg.from(ruleSet); - Bdd originalBdd = Bdd.from(cfg); + Bdd originalBdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); Bdd optimizedBdd = optimizer.apply(originalBdd); // Basic checks - assertEquals(originalBdd.getConditions().size(), optimizedBdd.getConditions().size()); - assertEquals(originalBdd.getResults().size(), optimizedBdd.getResults().size()); + assertEquals(originalBdd.getConditionCount(), optimizedBdd.getConditionCount()); + assertEquals(originalBdd.getResultCount(), optimizedBdd.getResultCount()); // Size should be same or smaller - assertTrue(optimizedBdd.getNodes().length <= originalBdd.getNodes().length); + assertTrue(optimizedBdd.getNodeCount() <= originalBdd.getNodeCount()); } @Test @@ -86,27 +89,23 @@ void testDependenciesPreserved() { .build(); Cfg cfg = Cfg.from(ruleSet); - Bdd originalBdd = Bdd.from(cfg); + Bdd originalBdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); Bdd optimizedBdd = optimizer.apply(originalBdd); - // Find the positions of dependent conditions - int hasInputPos = -1; - int usesInputPos = -1; - for (int i = 0; i < optimizedBdd.getConditions().size(); i++) { - Condition cond = optimizedBdd.getConditions().get(i); - if (cond.getResult().isPresent() && - cond.getResult().get().toString().equals("hasInput")) { - hasInputPos = i; - } else if (cond.getFunction().toString().contains("hasInput")) { - usesInputPos = i; - } - } - - // Verify dependency is preserved: definer comes before user - assertTrue(hasInputPos < usesInputPos, - "Condition defining hasInput must come before condition using it"); + // Get conditions from the CFG to verify ordering + ConditionData conditionData = cfg.getConditionData(); + List conditions = Arrays.asList(conditionData.getConditions()); + + // The optimizer may have reordered conditions, but we need to check + // if it created a valid BDD with the same number of conditions + assertEquals(originalBdd.getConditionCount(), optimizedBdd.getConditionCount()); + + // We can't directly check the ordering from the BDD anymore since it doesn't + // store conditions. The fact that the optimization completes successfully + // and produces a valid BDD means dependencies were preserved (otherwise + // the BddCompiler would have failed during the optimization process). } @Test @@ -122,13 +121,82 @@ void testSingleCondition() { EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(rule).build(); Cfg cfg = Cfg.from(ruleSet); - Bdd originalBdd = Bdd.from(cfg); + Bdd originalBdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); + + SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); + Bdd optimizedBdd = optimizer.apply(originalBdd); + + // Should be unchanged or very similar + assertEquals(originalBdd.getNodeCount(), optimizedBdd.getNodeCount()); + assertEquals(1, optimizedBdd.getConditionCount()); + } + + @Test + void testEmptyRuleSet() { + // Test empty ruleset edge case + Parameters params = Parameters.builder().build(); + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).build(); + + Cfg cfg = Cfg.from(ruleSet); + Bdd originalBdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); Bdd optimizedBdd = optimizer.apply(originalBdd); - // Should be unchanged - assertEquals(originalBdd.getNodes().length, optimizedBdd.getNodes().length); - assertEquals(1, optimizedBdd.getConditions().size()); + assertEquals(0, optimizedBdd.getConditionCount()); + assertEquals(originalBdd.getResultCount(), optimizedBdd.getResultCount()); + } + + @Test + void testLargeReduction() { + // Create a ruleset that should benefit from optimization + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("A").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("B").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("C").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("D").type(ParameterType.STRING).build()) + .build(); + + // Multiple rules with overlapping conditions + Rule rule1 = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("B")).build()) + .endpoint(TestHelpers.endpoint("https://ab.example.com")); + + Rule rule2 = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("C")).build()) + .endpoint(TestHelpers.endpoint("https://ac.example.com")); + + Rule rule3 = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("B")).build(), + Condition.builder().fn(TestHelpers.isSet("D")).build()) + .endpoint(TestHelpers.endpoint("https://bd.example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule1) + .addRule(rule2) + .addRule(rule3) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Bdd originalBdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); + + SiftingOptimization optimizer = SiftingOptimization.builder() + .cfg(cfg) + .granularEffort(100_000, 10) // Allow more aggressive optimization + .build(); + Bdd optimizedBdd = optimizer.apply(originalBdd); + + // Should maintain correctness + assertEquals(originalBdd.getConditionCount(), optimizedBdd.getConditionCount()); + assertEquals(originalBdd.getResultCount(), optimizedBdd.getResultCount()); + + // Often achieves some reduction + assertTrue(optimizedBdd.getNodeCount() <= originalBdd.getNodeCount()); } } From 73b44f2160ab957ed33cdf8fbef7c75f15a91dea Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Tue, 29 Jul 2025 14:29:09 -0500 Subject: [PATCH 05/23] Always serialize conditions... even when empty --- .../rulesengine/language/syntax/rule/Rule.java | 15 +++++++-------- .../rulesengine/language/minimal-ruleset.json | 1 + 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java index ddcf4a12c6c..6ef6943d306 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java @@ -171,18 +171,17 @@ public Type typeCheck(Scope scope) { public Node toNode() { ObjectNode.Builder builder = ObjectNode.builder(); - if (!conditions.isEmpty()) { - ArrayNode.Builder conditionsBuilder = ArrayNode.builder(); - for (Condition condition : conditions) { - conditionsBuilder.withValue(condition.toNode()); - } - builder.withMember(CONDITIONS, conditionsBuilder.build()); - } - if (documentation != null) { builder.withMember(DOCUMENTATION, documentation); } + // TODO: we should remove the requirement of serializing an empty array here + ArrayNode.Builder conditionsBuilder = ArrayNode.builder(); + for (Condition condition : conditions) { + conditionsBuilder.withValue(condition.toNode()); + } + builder.withMember(CONDITIONS, conditionsBuilder.build()); + withValueNode(builder); return builder.build(); } diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json index 24ed8bc5eb1..3e554bce10b 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/minimal-ruleset.json @@ -10,6 +10,7 @@ "rules": [ { "documentation": "base rule", + "conditions": [], "endpoint": { "url": "https://{Region}.amazonaws.com", "properties": { From 0d8abb99446f03a63763c078232f598dd7cf8188 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Tue, 29 Jul 2025 16:05:47 -0500 Subject: [PATCH 06/23] Fix BddTrait logic issue, using wrong conditions We were using the wrong condition ordering in BddTrait after compiling a Bdd from the CFG, leading to a totally broken BDD. Also adds some tests, fixes, and generalizes BddTrait transforms --- .../smithy/rulesengine/logic/bdd/Bdd.java | 130 +++++++----- .../rulesengine/logic/bdd/BddBuilder.java | 47 ++--- .../rulesengine/logic/bdd/BddCompiler.java | 10 +- .../rulesengine/logic/bdd/BddTrait.java | 56 ++---- .../rulesengine/logic/bdd/NodeReversal.java | 18 +- .../logic/bdd/SiftingOptimization.java | 125 +++++++----- .../rulesengine/logic/bdd/BddBuilderTest.java | 86 ++------ .../logic/bdd/BddCompilerTest.java | 3 +- .../logic/bdd/BddEquivalenceCheckerTest.java | 188 ++++++++++++++++++ .../smithy/rulesengine/logic/bdd/BddTest.java | 162 +++++++++++++++ .../rulesengine/logic/bdd/BddTraitTest.java | 20 ++ .../logic/bdd/NodeReversalTest.java | 47 +++-- .../logic/bdd/SiftingOptimizationTest.java | 95 ++++++--- 13 files changed, 705 insertions(+), 282 deletions(-) create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java index eab744c8cd5..27c52480528 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java @@ -10,8 +10,6 @@ import java.io.UncheckedIOException; import java.io.Writer; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.Objects; import java.util.function.Consumer; import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; @@ -44,6 +42,7 @@ public final class Bdd { private final int rootRef; private final int conditionCount; private final int resultCount; + private final int nodeCount; /** * Creates a BDD by streaming nodes directly into the structure. @@ -55,23 +54,68 @@ public final class Bdd { * @param nodeHandler a handler that will provide nodes via a consumer */ public Bdd(int rootRef, int conditionCount, int resultCount, int nodeCount, Consumer nodeHandler) { + validateCounts(conditionCount, resultCount, nodeCount); + validateRootReference(rootRef, nodeCount); + this.rootRef = rootRef; this.conditionCount = conditionCount; this.resultCount = resultCount; - - if (rootRef < 0 && rootRef != -1) { - throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef); - } + this.nodeCount = nodeCount; InputNodeConsumer consumer = new InputNodeConsumer(nodeCount); nodeHandler.accept(consumer); - this.variables = consumer.variables; this.highs = consumer.highs; this.lows = consumer.lows; if (consumer.index != nodeCount) { - throw new IllegalStateException("Expected " + nodeCount + " node, but got " + consumer.index); + throw new IllegalStateException("Expected " + nodeCount + " nodes, but got " + consumer.index); + } + } + + Bdd(int[] variables, int[] highs, int[] lows, int nodeCount, int rootRef, int conditionCount, int resultCount) { + validateArrays(variables, highs, lows, nodeCount); + validateCounts(conditionCount, resultCount, nodeCount); + validateRootReference(rootRef, nodeCount); + + this.variables = variables; + this.highs = highs; + this.lows = lows; + this.rootRef = rootRef; + this.conditionCount = conditionCount; + this.resultCount = resultCount; + this.nodeCount = nodeCount; + } + + private static void validateCounts(int conditionCount, int resultCount, int nodeCount) { + if (conditionCount < 0) { + throw new IllegalArgumentException("Condition count cannot be negative: " + conditionCount); + } else if (resultCount < 0) { + throw new IllegalArgumentException("Result count cannot be negative: " + resultCount); + } else if (nodeCount < 0) { + throw new IllegalArgumentException("Node count cannot be negative: " + nodeCount); + } + } + + private static void validateRootReference(int rootRef, int nodeCount) { + if (isComplemented(rootRef) && !isTerminal(rootRef)) { + throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef); + } else if (isNodeReference(rootRef)) { + int idx = Math.abs(rootRef) - 1; + if (idx >= nodeCount) { + throw new IllegalArgumentException("Root points to invalid BDD node: " + idx + + " (node count: " + nodeCount + ")"); + } + } + } + + private static void validateArrays(int[] variables, int[] highs, int[] lows, int nodeCount) { + if (variables.length != highs.length || variables.length != lows.length) { + throw new IllegalArgumentException("Array lengths must match: variables=" + variables.length + + ", highs=" + highs.length + ", lows=" + lows.length); + } else if (nodeCount > variables.length) { + throw new IllegalArgumentException("Node count (" + nodeCount + + ") exceeds array capacity (" + variables.length + ")"); } } @@ -96,23 +140,6 @@ public void accept(int var, int high, int low) { } } - Bdd(int[] variables, int[] highs, int[] lows, int rootRef, int conditionCount, int resultCount) { - this.variables = Objects.requireNonNull(variables, "variables is null"); - this.highs = Objects.requireNonNull(highs, "highs is null"); - this.lows = Objects.requireNonNull(lows, "lows is null"); - this.rootRef = rootRef; - this.conditionCount = conditionCount; - this.resultCount = resultCount; - - if (rootRef < 0 && rootRef != -1) { - throw new IllegalArgumentException("Root reference cannot be complemented: " + rootRef); - } - - if (variables.length != highs.length || variables.length != lows.length) { - throw new IllegalArgumentException("Array lengths must match"); - } - } - /** * Gets the number of conditions. * @@ -137,7 +164,7 @@ public int getResultCount() { * @return the node count */ public int getNodeCount() { - return variables.length; + return nodeCount; } /** @@ -156,9 +183,16 @@ public int getRootRef() { * @return the variable index */ public int getVariable(int nodeIndex) { + validateRange(nodeIndex); return variables[nodeIndex]; } + private void validateRange(int index) { + if (index < 0 || index >= nodeCount) { + throw new IndexOutOfBoundsException("Node index out of bounds: " + index + " (size: " + nodeCount + ")"); + } + } + /** * Gets the high (true) reference for a node. * @@ -166,6 +200,7 @@ public int getVariable(int nodeIndex) { * @return the high reference */ public int getHigh(int nodeIndex) { + validateRange(nodeIndex); return highs[nodeIndex]; } @@ -176,6 +211,7 @@ public int getHigh(int nodeIndex) { * @return the low reference */ public int getLow(int nodeIndex) { + validateRange(nodeIndex); return lows[nodeIndex]; } @@ -185,7 +221,7 @@ public int getLow(int nodeIndex) { * @param consumer the consumer to receive the integers */ public void getNodes(BddNodeConsumer consumer) { - for (int i = 0; i < variables.length; i++) { + for (int i = 0; i < nodeCount; i++) { consumer.accept(variables[i], highs[i], lows[i]); } } @@ -201,17 +237,14 @@ public int evaluate(ConditionEvaluator ev) { int[] vars = this.variables; int[] hi = this.highs; int[] lo = this.lows; - int off = RESULT_OFFSET; - // keep walking while ref is a non-terminal node - while ((ref > 1 && ref < off) || (ref < -1 && ref > -off)) { + while (isNodeReference(ref)) { int idx = ref > 0 ? ref - 1 : -ref - 1; // Math.abs // test ^ complement, pick hi or lo ref = (ev.test(vars[idx]) ^ (ref < 0)) ? hi[idx] : lo[idx]; } - // +1/-1 => no match - return (ref == 1 || ref == -1) ? -1 : (ref - off); + return isTerminal(ref) ? -1 : ref - RESULT_OFFSET; } /** @@ -221,10 +254,7 @@ public int evaluate(ConditionEvaluator ev) { * @return true if this is a node reference */ public static boolean isNodeReference(int ref) { - if (ref == 0 || isTerminal(ref)) { - return false; - } - return Math.abs(ref) < RESULT_OFFSET; + return (ref > 1 && ref < RESULT_OFFSET) || (ref < -1 && ref > -RESULT_OFFSET); } /** @@ -264,21 +294,31 @@ public boolean equals(Object obj) { } else if (!(obj instanceof Bdd)) { return false; } + Bdd other = (Bdd) obj; - return rootRef == other.rootRef - && conditionCount == other.conditionCount - && resultCount == other.resultCount - && Arrays.equals(variables, other.variables) - && Arrays.equals(highs, other.highs) - && Arrays.equals(lows, other.lows); + if (rootRef != other.rootRef + || conditionCount != other.conditionCount + || resultCount != other.resultCount + || nodeCount != other.nodeCount) { + return false; + } + + // Now check the views of arrays of each. + for (int i = 0; i < nodeCount; i++) { + if (variables[i] != other.variables[i] || highs[i] != other.highs[i] || lows[i] != other.lows[i]) { + return false; + } + } + + return true; } @Override public int hashCode() { - int hash = 31 * rootRef + variables.length; + int hash = 31 * rootRef + nodeCount; // Sample up to 16 nodes distributed across the BDD - int step = Math.max(1, variables.length / 16); - for (int i = 0; i < variables.length; i += step) { + int step = Math.max(1, nodeCount / 16); + for (int i = 0; i < nodeCount; i += step) { hash = 31 * hash + variables[i]; hash = 31 * hash + highs[i]; hash = 31 * hash + lows[i]; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java index 68af133e085..64660e2c0a8 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java @@ -35,9 +35,9 @@ final class BddBuilder { private final Map iteCache; // Node storage: three separate arrays - private int[] variables; - private int[] highs; - private int[] lows; + private int[] variables = new int[1024]; + private int[] highs = new int[1024]; + private int[] lows = new int[1024]; private int nodeCount; // Unique table for node deduplication @@ -51,9 +51,6 @@ final class BddBuilder { * Creates a new BDD engine. */ public BddBuilder() { - this.variables = new int[4]; - this.highs = new int[4]; - this.lows = new int[4]; this.nodeCount = 1; this.uniqueTable = new HashMap<>(); this.iteCache = new HashMap<>(); @@ -63,6 +60,10 @@ public BddBuilder() { lows[0] = FALSE_REF; } + int getNodeCount() { + return nodeCount; + } + /** * Sets the number of conditions. Must be called before creating result nodes. * @@ -127,6 +128,10 @@ public int makeResult(int resultIndex) { * @return reference to the BDD node */ public int makeNode(int var, int high, int low) { + if (conditionCount >= 0 && (var < 0 || var >= conditionCount)) { + throw new IllegalArgumentException("Variable out of bounds: " + var); + } + // Reduction rule: if both branches are identical, skip this test if (high == low) { return high; @@ -170,8 +175,8 @@ private int insertNode(int var, int high, int low, boolean flip) { private void ensureCapacity() { if (nodeCount >= variables.length) { - // Grow by 50% - int newCapacity = variables.length + (variables.length >> 1); + // Double the current capacity + int newCapacity = variables.length * 2; variables = Arrays.copyOf(variables, newCapacity); highs = Arrays.copyOf(highs, newCapacity); lows = Arrays.copyOf(lows, newCapacity); @@ -592,23 +597,6 @@ public BddBuilder reset() { return this; } - /** - * Get the nodes as a flat array. - * - * @return array of nodes, trimmed to actual size. - */ - public int[] getNodesArray() { - // Convert back to flat array for compatibility - int[] result = new int[nodeCount * 3]; - for (int i = 0; i < nodeCount; i++) { - int baseIdx = i * 3; - result[baseIdx] = variables[i]; - result[baseIdx + 1] = highs[i]; - result[baseIdx + 2] = lows[i]; - } - return result; - } - /** * Builds a BDD from the current state of the builder. * @@ -620,11 +608,10 @@ Bdd build(int rootRef, int resultCount) { throw new IllegalStateException("Condition count must be set before building BDD"); } - // Create trimmed copies of the arrays with only the used portion - int[] trimmedVariables = Arrays.copyOf(variables, nodeCount); - int[] trimmedHighs = Arrays.copyOf(highs, nodeCount); - int[] trimmedLows = Arrays.copyOf(lows, nodeCount); - return new Bdd(trimmedVariables, trimmedHighs, trimmedLows, rootRef, conditionCount, resultCount); + int[] v = Arrays.copyOf(variables, nodeCount); + int[] h = Arrays.copyOf(highs, nodeCount); + int[] l = Arrays.copyOf(lows, nodeCount); + return new Bdd(v, h, l, nodeCount, rootRef, conditionCount, resultCount); } private void validateBooleanOperands(int f, int g, String operation) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java index 7268fe8fb10..12950082290 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java @@ -62,8 +62,8 @@ Bdd compile() { noMatchIndex = getOrCreateResultIndex(NoMatchRule.INSTANCE); int rootRef = convertCfgToBdd(cfg.getRoot()); rootRef = bddBuilder.reduce(rootRef); - Bdd bdd = bddBuilder.build(rootRef, indexedResults.size()); + long elapsed = System.currentTimeMillis() - start; LOGGER.fine(String.format( "BDD compilation complete: %d conditions, %d results, %d BDD nodes in %dms", @@ -75,6 +75,14 @@ Bdd compile() { return bdd; } + List getIndexedResults() { + return indexedResults; + } + + List getOrderedConditions() { + return orderedConditions; + } + private int convertCfgToBdd(CfgNode cfgNode) { Integer cached = nodeCache.get(cfgNode); if (cached != null) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java index 4116837bbab..71fceba527a 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java @@ -11,11 +11,10 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Base64; -import java.util.LinkedHashSet; import java.util.List; import java.util.Set; +import java.util.function.Function; import software.amazon.smithy.model.node.Node; import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.model.shapes.ShapeId; @@ -27,9 +26,6 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; -import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; -import software.amazon.smithy.rulesengine.logic.cfg.ConditionData; -import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; import software.amazon.smithy.utils.SetUtils; import software.amazon.smithy.utils.SmithyBuilder; import software.amazon.smithy.utils.ToSmithyBuilder; @@ -68,43 +64,19 @@ private BddTrait(Builder builder) { * @return the BddTrait containing the compiled BDD and all context */ public static BddTrait from(Cfg cfg) { - ConditionData conditionData = cfg.getConditionData(); - List conditions = Arrays.asList(conditionData.getConditions()); - - // Compile the BDD BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); Bdd bdd = compiler.compile(); - List results = extractResultsFromCfg(cfg, bdd); - Parameters parameters = cfg.getRuleSet().getParameters(); - return builder().parameters(parameters).conditions(conditions).results(results).bdd(bdd).build(); - } - - private static List extractResultsFromCfg(Cfg cfg, Bdd bdd) { - // The BddCompiler always puts NoMatchRule at index 0 - List results = new ArrayList<>(); - results.add(NoMatchRule.INSTANCE); - - Set uniqueResults = new LinkedHashSet<>(); - for (CfgNode node : cfg) { - if (node instanceof ResultNode) { - Rule result = ((ResultNode) node).getResult(); - if (result != null && !(result instanceof NoMatchRule)) { - uniqueResults.add(result.withoutConditions()); - } - } - } - - results.addAll(uniqueResults); - - if (results.size() != bdd.getResultCount()) { - throw new IllegalStateException(String.format( - "Result count mismatch: found %d results in CFG but BDD expects %d", - results.size(), - bdd.getResultCount())); + if (compiler.getOrderedConditions().size() != bdd.getConditionCount()) { + throw new IllegalStateException("Mismatch between BDD var count and orderedConditions size"); } - return results; + return builder() + .parameters(cfg.getRuleSet().getParameters()) + .conditions(compiler.getOrderedConditions()) + .results(compiler.getIndexedResults()) + .bdd(bdd) + .build(); } /** @@ -143,6 +115,16 @@ public Bdd getBdd() { return bdd; } + /** + * Transform this BDD using the given function and return the updated BddTrait. + * + * @param transformer Transformer used to modify the trait. + * @return the updated trait. + */ + public BddTrait transform(Function transformer) { + return transformer.apply(this); + } + @Override protected Node createNode() { ObjectNode.Builder builder = ObjectNode.builder(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java index 3aac6c56235..ecd58249762 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java @@ -13,12 +13,24 @@ *

    This transformation reverses the node array (except the terminal at index 0) * and updates all references throughout the BDD to maintain correctness. */ -public final class NodeReversal implements Function { +public final class NodeReversal implements Function { private static final Logger LOGGER = Logger.getLogger(NodeReversal.class.getName()); @Override - public Bdd apply(Bdd bdd) { + public BddTrait apply(BddTrait trait) { + Bdd reversedBdd = reverse(trait.getBdd()); + // Only rebuild the trait if the BDD actually changed + return reversedBdd == trait.getBdd() ? trait : trait.toBuilder().bdd(reversedBdd).build(); + } + + /** + * Reverses the node ordering in a BDD. + * + * @param bdd the BDD to reverse + * @return the reversed BDD, or the original if too small to reverse + */ + public static Bdd reverse(Bdd bdd) { LOGGER.info("Starting BDD node reversal optimization"); int nodeCount = bdd.getNodeCount(); @@ -62,7 +74,7 @@ public Bdd apply(Bdd bdd) { * @param oldToNew the index mapping array * @return the remapped reference */ - private int remapReference(int ref, int[] oldToNew) { + private static int remapReference(int ref, int[] oldToNew) { // Return result references as-is. if (ref == 0) { return 0; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index 22b203b110c..fa8902e5de5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -14,6 +14,7 @@ import java.util.function.Function; import java.util.logging.Logger; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.logic.ConditionInfo; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; @@ -32,7 +33,7 @@ * *

    Each stage runs until reaching its target size or maximum passes. */ -public final class SiftingOptimization implements Function { +public final class SiftingOptimization implements Function { private static final Logger LOGGER = Logger.getLogger(SiftingOptimization.class.getName()); // Default thresholds and passes for each optimization level @@ -109,68 +110,68 @@ public static Builder builder() { } @Override - public Bdd apply(Bdd bdd) { + public BddTrait apply(BddTrait trait) { try { - return doApply(bdd); + return doApply(trait); } finally { threadBuilder.remove(); } } - private Bdd doApply(Bdd bdd) { + private BddTrait doApply(BddTrait trait) { LOGGER.info("Starting BDD sifting optimization"); long startTime = System.currentTimeMillis(); // Pre-spin the ForkJoinPool for better first-pass performance ForkJoinPool.commonPool().submit(() -> {}).join(); - // Get the conditions from the CFG (since Bdd no longer stores them) - List bddConditions = cfg.getConditionData().getConditions() != null - ? Arrays.asList(cfg.getConditionData().getConditions()) - : new ArrayList<>(); - - OptimizationState state = initializeOptimization(bdd, bddConditions); - LOGGER.info(String.format("Initial reordering: %d -> %d nodes", state.initialSize, state.currentSize)); + OptimizationState state = initializeOptimization(trait); + LOGGER.info(String.format("Initial size: %d nodes", state.initialSize)); state = runCoarseStage(state); state = runMediumStage(state); state = runGranularStage(state); double totalTimeInSeconds = (System.currentTimeMillis() - startTime) / 1000.0; - if (state.bestSize < state.initialSize) { - LOGGER.info(String.format("Optimization complete: %d -> %d nodes (%.1f%% total reduction) in %fs", - state.initialSize, - state.bestSize, - (1.0 - (double) state.bestSize / state.initialSize) * 100, - totalTimeInSeconds)); - } else { + + // Only rebuild if we found an improvement + if (state.bestSize >= state.initialSize) { LOGGER.info(String.format("No improvements found in %fs", totalTimeInSeconds)); + return trait; } - return state.bestBdd; + LOGGER.info(String.format("Optimization complete: %d -> %d nodes (%.1f%% total reduction) in %fs", + state.initialSize, + state.bestSize, + (1.0 - (double) state.bestSize / state.initialSize) * 100, + totalTimeInSeconds)); + + // Rebuild the BddTrait with the optimized ordering and BDD + return trait.toBuilder() + .conditions(state.orderView) + .results(state.results) + .bdd(state.bestBdd) + .build(); } - private OptimizationState initializeOptimization(Bdd bdd, List bddConditions) { - // Start with an intelligent initial ordering - List initialOrder = DefaultOrderingStrategy.orderConditions( - bddConditions.toArray(new Condition[0]), - conditionInfos); - - // Sanity check that ordering didn't lose conditions - if (initialOrder.size() != bddConditions.size()) { - throw new IllegalStateException("Initial ordering changed condition count: " + - bddConditions.size() + " -> " + initialOrder.size()); - } + private OptimizationState initializeOptimization(BddTrait trait) { + // Use the trait's existing ordering as the starting point + List initialOrder = new ArrayList<>(trait.getConditions()); Condition[] order = initialOrder.toArray(new Condition[0]); List orderView = Arrays.asList(order); - // Build initial BDD with better ordering - Bdd currentBest = compileBdd(orderView); - int currentSize = currentBest.getNodeCount() - 1; // -1 for terminal - int initialSize = bdd.getNodeCount() - 1; - - return new OptimizationState(order, orderView, currentBest, currentSize, initialSize); + // Get the initial size from the input BDD + Bdd bdd = trait.getBdd(); + int initialSize = bdd.getNodeCount() - 1; // -1 for terminal + + // No need to recompile just for results—use the trait's own results + return new OptimizationState(order, + orderView, + bdd, + initialSize, + initialSize, + trait.getResults()); } private OptimizationState runCoarseStage(OptimizationState state) { @@ -202,7 +203,7 @@ private OptimizationState runGranularStage(OptimizationState state) { OptimizationResult swapResult = performAdjacentSwaps(state.order, state.orderView, state.currentSize); if (swapResult.improved) { LOGGER.info(String.format("Adjacent swaps: %d -> %d nodes", state.currentSize, swapResult.size)); - return state.withResult(swapResult.bdd, swapResult.size); + return state.withResult(swapResult.bdd, swapResult.size, swapResult.results); } return state; @@ -241,7 +242,7 @@ private OptimizationState runOptimizationStage( LOGGER.fine(String.format("%s pass %d found no improvements", stageName, pass)); break; } else { - currentState = currentState.withResult(result.bdd, result.size); + currentState = currentState.withResult(result.bdd, result.size, result.results); double reduction = (1.0 - (double) result.size / passStartSize) * 100; LOGGER.fine(String.format("%s pass %d: %d -> %d nodes (%.1f%% reduction)", stageName, @@ -271,6 +272,7 @@ private OptimizationResult runOptimizationPass( OrderConstraints constraints = new OrderConstraints(dependencyGraph, orderView); Bdd bestBdd = null; int bestSize = currentSize; + List bestResults = null; // Sample variables based on effort level for (int varIdx = 0; varIdx < order.length; varIdx += effort.sampleRate) { @@ -290,12 +292,14 @@ private OptimizationResult runOptimizationPass( // Move to best position and build BDD once move(order, varIdx, best.position); - Bdd newBdd = compileBdd(orderView); + BddCompilationResult compilationResult = compileBddWithResults(orderView); + Bdd newBdd = compilationResult.bdd; int newSize = newBdd.getNodeCount() - 1; if (newSize < bestSize) { bestBdd = newBdd; bestSize = newSize; + bestResults = compilationResult.results; improvements++; // Update constraints after successful move @@ -303,13 +307,14 @@ private OptimizationResult runOptimizationPass( } } - return new OptimizationResult(bestBdd, bestSize, improvements > 0); + return new OptimizationResult(bestBdd, bestSize, improvements > 0, bestResults); } private OptimizationResult performAdjacentSwaps(Condition[] order, List orderView, int currentSize) { OrderConstraints constraints = new OrderConstraints(dependencyGraph, orderView); Bdd bestBdd = null; int bestSize = currentSize; + List bestResults = null; boolean improved = false; for (int i = 0; i < order.length - 1; i++) { @@ -317,8 +322,10 @@ private OptimizationResult performAdjacentSwaps(Condition[] order, List ordering) { + private BddCompilationResult compileBddWithResults(List ordering) { BddBuilder builder = threadBuilder.get().reset(); - return new BddCompiler(cfg, ConditionOrderingStrategy.fixed(ordering), builder).compile(); + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.fixed(ordering), builder); + Bdd bdd = compiler.compile(); + return new BddCompilationResult(bdd, compiler.getIndexedResults()); } /** * Counts nodes for a given ordering without keeping the BDD. */ private int countNodes(List ordering) { - Bdd bdd = compileBdd(ordering); + Bdd bdd = compileBddWithResults(ordering).bdd; return bdd.getNodeCount() - 1; // -1 for terminal } + // Container for BDD compilation results + private static final class BddCompilationResult { + final Bdd bdd; + final List results; + + BddCompilationResult(Bdd bdd, List results) { + this.bdd = bdd; + this.results = results; + } + } + // Position and its node count private static final class PositionCount { final int position; @@ -462,11 +482,13 @@ private static final class OptimizationResult { final Bdd bdd; final int size; final boolean improved; + final List results; - OptimizationResult(Bdd bdd, int size, boolean improved) { + OptimizationResult(Bdd bdd, int size, boolean improved, List results) { this.bdd = bdd; this.size = size; this.improved = improved; + this.results = results; } } @@ -478,13 +500,15 @@ private static final class OptimizationState { final int currentSize; final int bestSize; final int initialSize; + final List results; OptimizationState( Condition[] order, List orderView, Bdd bestBdd, int currentSize, - int initialSize + int initialSize, + List results ) { this.order = order; this.orderView = orderView; @@ -492,10 +516,11 @@ private static final class OptimizationState { this.currentSize = currentSize; this.bestSize = currentSize; this.initialSize = initialSize; + this.results = results; } - OptimizationState withResult(Bdd newBdd, int newSize) { - return new OptimizationState(order, orderView, newBdd, newSize, initialSize); + OptimizationState withResult(Bdd newBdd, int newSize, List newResults) { + return new OptimizationState(order, orderView, newBdd, newSize, initialSize, newResults); } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java index 0ad1751d05c..57c9d1f4931 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java @@ -41,7 +41,7 @@ void testNodeReduction() { assertEquals(1, reduced); // Should return TRUE directly // Verify no new node was created (only terminal at index 0) - assertEquals(3, builder.getNodesArray().length); // 3 ints for terminal node + assertEquals(1, builder.getNodeCount()); // Only terminal node exists } @Test @@ -58,7 +58,7 @@ void testComplementCanonicalization() { assertEquals(-node1, node2); // Should be complement of first node // Only one actual node should be created (plus terminal) - assertEquals(6, builder.getNodesArray().length); // 3 ints for terminal + 3 for one node + assertEquals(2, builder.getNodeCount()); // Terminal + one node } @Test @@ -70,7 +70,7 @@ void testNodeDeduplication() { int node2 = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); assertEquals(node1, node2); // Should return same reference - assertEquals(6, builder.getNodesArray().length); // No duplicate created + assertEquals(2, builder.getNodeCount()); // Terminal + one node (no duplicate) } @Test @@ -232,11 +232,11 @@ void testReduceSimpleBdd() { int b = builder.makeNode(1, a, builder.makeFalse()); int root = builder.makeNode(0, b, a); - int nodesBefore = builder.getNodesArray().length; - builder.reduce(root); + int nodesBefore = builder.getNodeCount(); + int reduced = builder.reduce(root); // Structure should be preserved if already optimal - assertEquals(nodesBefore, builder.getNodesArray().length); + assertEquals(nodesBefore, builder.getNodeCount()); } @Test @@ -247,11 +247,11 @@ void testReduceNoChange() { int right = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); int root = builder.makeNode(0, right, builder.makeFalse()); - int nodesBefore = builder.getNodesArray().length; - builder.reduce(root); + int nodesBefore = builder.getNodeCount(); + int reduced = builder.reduce(root); // No change expected - assertEquals(nodesBefore, builder.getNodesArray().length); + assertEquals(nodesBefore, builder.getNodeCount()); } @Test @@ -290,7 +290,7 @@ void testReduceWithComplement() { assertEquals(builder.negate(reduced), reducedComplement); // Verify the structure is preserved - assertTrue(builder.getNodesArray().length > 3); + assertTrue(builder.getNodeCount() > 1); // More than just terminal } @Test @@ -303,7 +303,7 @@ void testReduceClearsCache() { int ite1 = builder.ite(a, b, builder.makeFalse()); // Reduce - builder.reduce(ite1); + int reduced = builder.reduce(ite1); // Cache should be cleared, so same ITE creates new result // Recreate the nodes since reduce may have changed internal state @@ -313,58 +313,6 @@ void testReduceClearsCache() { assertTrue(ite2 != 0); // Should get a valid reference } - @Test - void testReduceSharedSubgraphs() { - builder.setConditionCount(3); - - // Create BDD with shared subgraphs - use only boolean nodes - int shared = builder.makeNode(2, builder.makeTrue(), builder.makeFalse()); - int left = builder.makeNode(1, shared, builder.makeFalse()); - int right = builder.makeNode(1, builder.makeTrue(), shared); - int root = builder.makeNode(0, left, right); - - builder.reduce(root); - - // Shared subgraph should remain shared after reduction - int[] nodes = builder.getNodesArray(); - - // Verify structure is maintained - at least one node should exist - assertTrue(nodes.length > 3); - } - - @Test - void testReducePreservesResultNodes() { - builder.setConditionCount(2); - - // Create BDD with result terminals - int result0 = builder.makeResult(0); - int result1 = builder.makeResult(1); - int cond = builder.makeNode(0, result0, result1); - int root = builder.makeNode(1, cond, builder.makeFalse()); - - int reduced = builder.reduce(root); - - // Result refs should be preserved in the hi/lo branches - boolean foundResult0 = false; - boolean foundResult1 = false; - - // Check the nodes for result references - int[] nodes = builder.getNodesArray(); - int nodeCount = nodes.length / 3; - for (int i = 0; i < nodeCount; i++) { - int baseIdx = i * 3; - if (nodes[baseIdx + 1] == result0 || nodes[baseIdx + 2] == result0) { - foundResult0 = true; - } - if (nodes[baseIdx + 1] == result1 || nodes[baseIdx + 2] == result1) { - foundResult1 = true; - } - } - - assertTrue(foundResult0); - assertTrue(foundResult1); - } - @Test void testReduceActuallyReduces() { builder.setConditionCount(3); @@ -374,9 +322,9 @@ void testReduceActuallyReduces() { int middle = builder.makeNode(1, bottom, builder.makeFalse()); int root = builder.makeNode(0, middle, bottom); - int beforeSize = builder.getNodesArray().length; - builder.reduce(root); - int afterSize = builder.getNodesArray().length; + int beforeSize = builder.getNodeCount(); + int reduced = builder.reduce(root); + int afterSize = builder.getNodeCount(); // In this case, no reduction should occur since makeNode already optimized assertEquals(beforeSize, afterSize); @@ -427,8 +375,8 @@ void testCofactorRecursive() { // The cofactors should be different assertTrue(cofactorTrue != cofactorFalse); - // Verify structure is simplified - assertTrue(builder.getNodesArray().length > 3); + // Verify structure exists + assertTrue(builder.getNodeCount() > 1); } @Test @@ -522,7 +470,7 @@ void testReset() { builder.reset(); // Verify state is cleared - assertEquals(3, builder.getNodesArray().length); // Only terminal (3 ints) + assertEquals(1, builder.getNodeCount()); // Only terminal assertThrows(IllegalStateException.class, () -> builder.makeResult(0)); // Can use builder again diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java index d9d385166a8..ddd8fe02d84 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java @@ -9,7 +9,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.Arrays; -import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; @@ -128,7 +127,7 @@ void testCompileWithCustomOrdering() { // Get the actual conditions from the CFG after SSA transform ConditionData conditionData = cfg.getConditionData(); - List cfgConditions = Arrays.asList(conditionData.getConditions()); + Condition[] cfgConditions = conditionData.getConditions(); // Find the conditions that correspond to A and B Condition condA = null; diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java new file mode 100644 index 00000000000..403b27bfc78 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java @@ -0,0 +1,188 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.utils.ListUtils; + +class BddEquivalenceCheckerTest { + + @Test + void testSimpleEquivalentBdd() { + // Create a simple ruleset + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .rules(ListUtils.of( + EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")), + // Default case + ErrorRule.builder() + .error(Literal.of("No region provided")))) + .build(); + + // Convert to CFG + Cfg cfg = Cfg.from(ruleSet); + + // Create BDD from CFG + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + Bdd bdd = compiler.compile(); + + // Create checker + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + // Should pass verification + assertDoesNotThrow(checker::verify); + } + + @Test + void testEmptyRulesetEquivalence() { + // Empty ruleset with a default endpoint + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .rules(ListUtils.of( + EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://default.com")))) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + assertDoesNotThrow(checker::verify); + } + + @Test + void testMultipleConditionsEquivalence() { + // Ruleset with multiple conditions (AND logic) + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build()) + .rules(ListUtils.of( + EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("Region")).build(), + Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")), + // Default case + ErrorRule.builder() + .error(Literal.of("Missing required parameters")))) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + assertDoesNotThrow(checker::verify); + } + + @Test + void testSetMaxSamples() { + // Create a simpler test with just 3 parameters to avoid ordering issues + Parameters.Builder paramsBuilder = Parameters.builder(); + List rules = new ArrayList<>(); + + // Add parameters with zero-padded names to ensure correct ordering + for (int i = 0; i < 3; i++) { + String paramName = String.format("Param%02d", i); // Param00, Param01, Param02 + paramsBuilder.addParameter(Parameter.builder().name(paramName).type(ParameterType.STRING).build()); + rules.add(EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet(paramName)).build()) + .endpoint(TestHelpers.endpoint("https://example" + i + ".com"))); + } + + // Add default case + rules.add(ErrorRule.builder() + .error(Literal.of("No parameters set"))); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(paramsBuilder.build()) + .rules(rules) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + // Set a small max samples to make test fast + checker.setMaxSamples(100); + + assertDoesNotThrow(checker::verify); + } + + @Test + void testSetMaxDuration() { + // Create a complex ruleset + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .rules(ListUtils.of( + EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")), + ErrorRule.builder() + .error(Literal.of("No region provided")))) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + // Set a short timeout + checker.setMaxDuration(Duration.ofMillis(100)); + + assertDoesNotThrow(checker::verify); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java index e68b63a2d72..96079a2c6cc 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java @@ -4,8 +4,10 @@ */ package software.amazon.smithy.rulesengine.logic.bdd; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -161,6 +163,166 @@ private Bdd createSimpleBdd() { }); } + @Test + void testStreamingConstructorValidation() { + // Valid construction + assertDoesNotThrow(() -> { + new Bdd(1, 1, 1, 1, consumer -> { + consumer.accept(-1, 1, -1); + }); + }); + + // Root cannot be complemented (except -1) + assertThrows(IllegalArgumentException.class, () -> { + new Bdd(-2, 1, 1, 1, consumer -> { + consumer.accept(-1, 1, -1); + }); + }); + + // Root -1 (FALSE) is allowed + assertDoesNotThrow(() -> { + new Bdd(-1, 1, 1, 1, consumer -> { + consumer.accept(-1, 1, -1); + }); + }); + + // Wrong node count + assertThrows(IllegalStateException.class, () -> { + new Bdd(1, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); // Only provides 1 node, but claims 2 + }); + }); + } + + @Test + void testArrayConstructorValidation() { + int[] vars = {-1}; + int[] highs = {1}; + int[] lows = {-1}; + + // Valid construction + assertDoesNotThrow(() -> { + new Bdd(vars, highs, lows, 1, 1, 1, 1); + }); + + // Null arrays + assertThrows(NullPointerException.class, () -> { + new Bdd(null, highs, lows, 1, 1, 1, 1); + }); + + assertThrows(NullPointerException.class, () -> { + new Bdd(vars, null, lows, 1, 1, 1, 1); + }); + + assertThrows(NullPointerException.class, () -> { + new Bdd(vars, highs, null, 1, 1, 1, 1); + }); + + // Mismatched array lengths + int[] shortArray = {}; + assertThrows(IllegalArgumentException.class, () -> { + new Bdd(shortArray, highs, lows, 1, 1, 1, 1); + }); + + // Node count exceeds array capacity + assertThrows(IllegalArgumentException.class, () -> { + new Bdd(vars, highs, lows, 2, 1, 1, 1); // nodeCount=2 but arrays have length 1 + }); + + // Root cannot be complemented (except -1) + assertThrows(IllegalArgumentException.class, () -> { + new Bdd(vars, highs, lows, 1, -2, 1, 1); + }); + + // Root -1 (FALSE) is allowed + assertDoesNotThrow(() -> { + new Bdd(vars, highs, lows, 1, -1, 1, 1); + }); + } + + @Test + void testGetterBoundsChecking() { + Bdd bdd = new Bdd(1, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + + // Valid indices + assertDoesNotThrow(() -> bdd.getVariable(0)); + assertDoesNotThrow(() -> bdd.getVariable(1)); + + // Out of bounds + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getVariable(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getVariable(2)); + + // Same for high/low + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getHigh(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getHigh(2)); + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getLow(-1)); + assertThrows(IndexOutOfBoundsException.class, () -> bdd.getLow(2)); + } + + @Test + void testEquals() { + // Create two identical BDDs + Bdd bdd1 = new Bdd(2, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + + Bdd bdd2 = new Bdd(2, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + + // Same content should be equal + assertEquals(bdd1, bdd2); + assertEquals(bdd1.hashCode(), bdd2.hashCode()); + + // Self equality + assertEquals(bdd1, bdd1); + + // Different root ref (use TRUE terminal) + Bdd bdd3 = new Bdd(1, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + assertNotEquals(bdd1, bdd3); + + // Different root ref (use FALSE terminal) + Bdd bdd4 = new Bdd(-1, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + assertNotEquals(bdd1, bdd4); + + // Different node count + Bdd bdd5 = new Bdd(2, 1, 1, 3, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + consumer.accept(0, -1, 1); + }); + assertNotEquals(bdd1, bdd5); + + // Different node content + Bdd bdd6 = new Bdd(2, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, -1, 1); // Different high/low + }); + assertNotEquals(bdd1, bdd6); + + // Different root ref (use result reference) + Bdd bdd7 = new Bdd(Bdd.RESULT_OFFSET + 0, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); + consumer.accept(0, 1, -1); + }); + assertNotEquals(bdd1, bdd7); + + // Null and different type + assertNotEquals(bdd1, null); + assertNotEquals(bdd1, "not a BDD"); + } + // Used to regenerate BDD test cases for errorfiles // @Test // void generateValidBddEncoding() { diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java index 443898921f4..1fbdc7c0ae7 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java @@ -75,4 +75,24 @@ private Bdd createSimpleBdd() { consumer.accept(0, Bdd.RESULT_OFFSET + 1, -1); // node 1: if cond true, return result 1, else FALSE }); } + + @Test + void testEmptyBddTrait() { + Parameters params = Parameters.builder().build(); + + Bdd bdd = new Bdd(-1, 0, 1, 1, consumer -> { + consumer.accept(-1, 1, -1); // terminal node only + }); + + BddTrait trait = BddTrait.builder() + .parameters(params) + .conditions(ListUtils.of()) + .results(ListUtils.of(NoMatchRule.INSTANCE)) + .bdd(bdd) + .build(); + + assertEquals(0, trait.getConditions().size()); + assertEquals(1, trait.getResults().size()); + assertEquals(-1, trait.getBdd().getRootRef()); // FALSE terminal + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java index 80464db0c42..b9d05b1f15e 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java @@ -6,8 +6,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import java.util.ArrayList; +import java.util.Collections; import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; class NodeReversalTest { @@ -18,8 +23,7 @@ void testSingleNodeBdd() { consumer.accept(-1, 1, -1); // terminal }); - NodeReversal reversal = new NodeReversal(); - Bdd reversed = reversal.apply(original); + Bdd reversed = NodeReversal.reverse(original); // Should be unchanged (only 1 node, reversal returns as-is for <= 2 nodes). assertEquals(1, reversed.getNodeCount()); @@ -40,8 +44,7 @@ void testComplementEdges() { consumer.accept(1, 1, -1); // node 2: condition 1 }); - NodeReversal reversal = new NodeReversal(); - Bdd reversed = reversal.apply(original); + Bdd reversed = NodeReversal.reverse(original); // Mapping: 0->0, 1->2, 2->1 // Ref mapping: 2->3, 3->2, -2->-3 @@ -64,8 +67,7 @@ void testResultNodes() { consumer.accept(3, 1, -1); // node 3: result 1 }); - NodeReversal reversal = new NodeReversal(); - Bdd reversed = reversal.apply(original); + Bdd reversed = NodeReversal.reverse(original); assertEquals(4, reversed.getNodeCount()); assertEquals(4, reversed.getRootRef()); // root was ref 2, now ref 4 @@ -101,8 +103,7 @@ void testFourNodeExample() { consumer.accept(2, 1, -1); // node 3: }); - NodeReversal reversal = new NodeReversal(); - Bdd reversed = reversal.apply(original); + Bdd reversed = NodeReversal.reverse(original); // Mapping: 0->0, 1->3, 2->2, 3->1 // Ref mapping: 2->4, 3->3, 4->2 @@ -135,8 +136,7 @@ void testImmutability() { originalNodeValues[i * 3 + 2] = original.getLow(i); } - NodeReversal reversal = new NodeReversal(); - Bdd reversed = reversal.apply(original); + Bdd reversed = NodeReversal.reverse(original); // Verify original is unchanged assertEquals(originalNodeCount, original.getNodeCount()); @@ -161,8 +161,7 @@ void testTwoNodeBdd() { consumer.accept(0, 1, -1); // node 1: simple condition }); - NodeReversal reversal = new NodeReversal(); - Bdd reversed = reversal.apply(original); + Bdd reversed = NodeReversal.reverse(original); // Should be unchanged (reversal returns as-is for <= 2 nodes) assertEquals(2, reversed.getNodeCount()); @@ -177,4 +176,28 @@ void testTwoNodeBdd() { assertEquals(1, reversed.getHigh(1)); assertEquals(-1, reversed.getLow(1)); } + + @Test + void testBddTraitReversalReturnsOriginalForSmallBdd() { + // Test that small BDDs return the original trait unchanged + NodeReversal reversal = new NodeReversal(); + + // Create a BddTrait with a 2-node BDD + Bdd bdd = new Bdd(2, 1, 1, 2, consumer -> { + consumer.accept(-1, 1, -1); // node 0: terminal + consumer.accept(0, 1, -1); // node 1: simple condition + }); + + BddTrait originalTrait = BddTrait.builder() + .parameters(Parameters.builder().build()) + .conditions(new ArrayList<>()) + .results(Collections.singletonList(NoMatchRule.INSTANCE)) + .bdd(bdd) + .build(); + + BddTrait reversedTrait = reversal.apply(originalTrait); + + // Should return the exact same trait object for small BDDs + assertSame(originalTrait, reversedTrait); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java index 47f87d05439..005985861a8 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java @@ -7,8 +7,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.Arrays; -import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; @@ -23,7 +21,6 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.logic.TestHelpers; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; -import software.amazon.smithy.rulesengine.logic.cfg.ConditionData; // Does some basic checks, but doesn't get too specific so we can easily change the sifting algorithm. class SiftingOptimizationTest { @@ -48,17 +45,19 @@ void testBasicOptimization() { .build(); Cfg cfg = Cfg.from(ruleSet); - Bdd originalBdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); + BddTrait originalTrait = BddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); - Bdd optimizedBdd = optimizer.apply(originalBdd); + BddTrait optimizedTrait = optimizer.apply(originalTrait); // Basic checks - assertEquals(originalBdd.getConditionCount(), optimizedBdd.getConditionCount()); - assertEquals(originalBdd.getResultCount(), optimizedBdd.getResultCount()); + assertEquals(originalTrait.getConditions().size(), optimizedTrait.getConditions().size()); + assertEquals(originalTrait.getResults().size(), optimizedTrait.getResults().size()); + assertEquals(originalTrait.getBdd().getConditionCount(), optimizedTrait.getBdd().getConditionCount()); + assertEquals(originalTrait.getBdd().getResultCount(), optimizedTrait.getBdd().getResultCount()); // Size should be same or smaller - assertTrue(optimizedBdd.getNodeCount() <= originalBdd.getNodeCount()); + assertTrue(optimizedTrait.getBdd().getNodeCount() <= originalTrait.getBdd().getNodeCount()); } @Test @@ -89,23 +88,21 @@ void testDependenciesPreserved() { .build(); Cfg cfg = Cfg.from(ruleSet); - Bdd originalBdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); + BddTrait originalTrait = BddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); - Bdd optimizedBdd = optimizer.apply(originalBdd); + BddTrait optimizedTrait = optimizer.apply(originalTrait); - // Get conditions from the CFG to verify ordering - ConditionData conditionData = cfg.getConditionData(); - List conditions = Arrays.asList(conditionData.getConditions()); + // Verify the optimizer preserved the number of conditions + assertEquals(originalTrait.getConditions().size(), optimizedTrait.getConditions().size()); + assertEquals(originalTrait.getBdd().getConditionCount(), optimizedTrait.getBdd().getConditionCount()); - // The optimizer may have reordered conditions, but we need to check - // if it created a valid BDD with the same number of conditions - assertEquals(originalBdd.getConditionCount(), optimizedBdd.getConditionCount()); + // The fact that the optimization completes successfully and produces a valid BDD + // means dependencies were preserved (otherwise the BddCompiler would have failed + // during the optimization process). - // We can't directly check the ordering from the BDD anymore since it doesn't - // store conditions. The fact that the optimization completes successfully - // and produces a valid BDD means dependencies were preserved (otherwise - // the BddCompiler would have failed during the optimization process). + // Also verify results are preserved + assertEquals(originalTrait.getResults(), optimizedTrait.getResults()); } @Test @@ -121,14 +118,15 @@ void testSingleCondition() { EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(rule).build(); Cfg cfg = Cfg.from(ruleSet); - Bdd originalBdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); + BddTrait originalTrait = BddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); - Bdd optimizedBdd = optimizer.apply(originalBdd); + BddTrait optimizedTrait = optimizer.apply(originalTrait); // Should be unchanged or very similar - assertEquals(originalBdd.getNodeCount(), optimizedBdd.getNodeCount()); - assertEquals(1, optimizedBdd.getConditionCount()); + assertEquals(originalTrait.getBdd().getNodeCount(), optimizedTrait.getBdd().getNodeCount()); + assertEquals(1, optimizedTrait.getBdd().getConditionCount()); + assertEquals(originalTrait.getConditions(), optimizedTrait.getConditions()); } @Test @@ -138,13 +136,14 @@ void testEmptyRuleSet() { EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).build(); Cfg cfg = Cfg.from(ruleSet); - Bdd originalBdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); + BddTrait originalTrait = BddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); - Bdd optimizedBdd = optimizer.apply(originalBdd); + BddTrait optimizedTrait = optimizer.apply(originalTrait); - assertEquals(0, optimizedBdd.getConditionCount()); - assertEquals(originalBdd.getResultCount(), optimizedBdd.getResultCount()); + assertEquals(0, optimizedTrait.getBdd().getConditionCount()); + assertEquals(originalTrait.getBdd().getResultCount(), optimizedTrait.getBdd().getResultCount()); + assertEquals(originalTrait.getResults(), optimizedTrait.getResults()); } @Test @@ -184,19 +183,49 @@ void testLargeReduction() { .build(); Cfg cfg = Cfg.from(ruleSet); - Bdd originalBdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); + BddTrait originalTrait = BddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder() .cfg(cfg) .granularEffort(100_000, 10) // Allow more aggressive optimization .build(); - Bdd optimizedBdd = optimizer.apply(originalBdd); + BddTrait optimizedTrait = optimizer.apply(originalTrait); // Should maintain correctness - assertEquals(originalBdd.getConditionCount(), optimizedBdd.getConditionCount()); - assertEquals(originalBdd.getResultCount(), optimizedBdd.getResultCount()); + assertEquals(originalTrait.getConditions().size(), optimizedTrait.getConditions().size()); + assertEquals(originalTrait.getBdd().getConditionCount(), optimizedTrait.getBdd().getConditionCount()); + assertEquals(originalTrait.getBdd().getResultCount(), optimizedTrait.getBdd().getResultCount()); + assertEquals(originalTrait.getResults(), optimizedTrait.getResults()); // Often achieves some reduction - assertTrue(optimizedBdd.getNodeCount() <= originalBdd.getNodeCount()); + assertTrue(optimizedTrait.getBdd().getNodeCount() <= originalTrait.getBdd().getNodeCount()); + } + + @Test + void testNoImprovementReturnsOriginal() { + // Test that when no improvement is found, the original trait is returned + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(rule).build(); + Cfg cfg = Cfg.from(ruleSet); + BddTrait originalTrait = BddTrait.from(cfg); + + SiftingOptimization optimizer = SiftingOptimization.builder() + .cfg(cfg) + .coarseEffort(1, 1) // Minimal effort to likely find no improvement + .build(); + + BddTrait optimizedTrait = optimizer.apply(originalTrait); + + // For simple cases with minimal optimization effort, should return the same trait object + if (optimizedTrait.getBdd().getNodeCount() == originalTrait.getBdd().getNodeCount()) { + assertTrue(optimizedTrait == originalTrait); + } } } From 3428992e55e731e724774a4f0647a6c8daaf6e38 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Thu, 31 Jul 2025 22:25:39 -0500 Subject: [PATCH 07/23] Create UniqueTable cache and cleanup BddBuilder --- .../rulesengine/logic/bdd/BddBuilder.java | 349 +++++++----------- .../rulesengine/logic/bdd/UniqueTable.java | 74 ++++ 2 files changed, 205 insertions(+), 218 deletions(-) create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/UniqueTable.java diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java index 64660e2c0a8..1c98a49e0c4 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java @@ -5,8 +5,6 @@ package software.amazon.smithy.rulesengine.logic.bdd; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; /** * Binary Decision Diagram (BDD) builder with complement edges and multi-terminal support. @@ -31,33 +29,27 @@ final class BddBuilder { private static final int TRUE_REF = 1; private static final int FALSE_REF = -1; - // ITE operation cache for memoization - private final Map iteCache; - // Node storage: three separate arrays private int[] variables = new int[1024]; private int[] highs = new int[1024]; private int[] lows = new int[1024]; private int nodeCount; - // Unique table for node deduplication - private Map uniqueTable; + // Unique tables for node deduplication and ITE caching + private final UniqueTable uniqueTable; + private final UniqueTable iteCache; + // Track the boundary between conditions and results private int conditionCount = -1; - private final TripleKey mutableKey = new TripleKey(0, 0, 0); - /** * Creates a new BDD engine. */ public BddBuilder() { this.nodeCount = 1; - this.uniqueTable = new HashMap<>(); - this.iteCache = new HashMap<>(); - // Initialize with terminal node at index 0 - variables[0] = -1; - highs[0] = TRUE_REF; - lows[0] = FALSE_REF; + this.uniqueTable = new UniqueTable(); + this.iteCache = new UniqueTable(4096); // Larger initial capacity for ITE cache + initializeTerminalNode(); } int getNodeCount() { @@ -130,33 +122,34 @@ public int makeResult(int resultIndex) { public int makeNode(int var, int high, int low) { if (conditionCount >= 0 && (var < 0 || var >= conditionCount)) { throw new IllegalArgumentException("Variable out of bounds: " + var); - } - - // Reduction rule: if both branches are identical, skip this test - if (high == low) { + } else if (high == low) { + // Reduction rule: if both branches are identical, skip this test return high; } // Complement edge canonicalization: ensure complement only on low branch. // Don't apply this to result nodes or when branches contain results - boolean flip = false; - if (isComplement(low) && !isResult(high) && !isResult(low)) { + boolean flip = shouldFlip(high, low); + if (flip) { high = negate(high); low = negate(low); - flip = true; } // Check if this node already exists - mutableKey.update(var, high, low); - Integer existing = uniqueTable.get(mutableKey); - + Integer existing = uniqueTable.get(var, high, low); if (existing != null) { - int ref = toReference(existing); - return flip ? negate(ref) : ref; + return applyFlip(flip, existing); + } else { + return insertNode(var, high, low, flip); } + } - // Create new node - return insertNode(var, high, low, flip); + private boolean shouldFlip(int high, int low) { + return isComplement(low) && !isResult(high) && !isResult(low); + } + + private int applyFlip(boolean flip, int idx) { + return flip ? negate(toReference(idx)) : toReference(idx); } private int insertNode(int var, int high, int low, boolean flip) { @@ -168,9 +161,8 @@ private int insertNode(int var, int high, int low, boolean flip) { lows[idx] = low; nodeCount++; - uniqueTable.put(new TripleKey(var, high, low), idx); - int ref = toReference(idx); - return flip ? negate(ref) : ref; + uniqueTable.put(var, high, low, idx); + return applyFlip(flip, idx); } private void ensureCapacity() { @@ -231,6 +223,16 @@ public boolean isResult(int ref) { return ref >= Bdd.RESULT_OFFSET; } + /** + * Checks if a reference is a leaf (terminal or result). + * + * @param ref the reference to check + * @return true if this is a leaf node + */ + private boolean isLeaf(int ref) { + return Math.abs(ref) == TRUE_REF || ref >= Bdd.RESULT_OFFSET; + } + /** * Gets the variable index for a BDD node. * @@ -238,17 +240,13 @@ public boolean isResult(int ref) { * @return the variable index, or -1 for terminals */ public int getVariable(int ref) { - if (isTerminal(ref)) { + if (isLeaf(ref)) { return -1; - } else if (isResult(ref)) { - return -1; // Result terminals are leaves and don't test variables - } else { - int nodeIndex = Math.abs(ref) - 1; - if (nodeIndex >= nodeCount || nodeIndex < 0) { - throw new IllegalStateException("Invalid node index: " + nodeIndex); - } - return variables[nodeIndex]; } + + int nodeIndex = Math.abs(ref) - 1; + validateNodeIndex(nodeIndex); + return variables[nodeIndex]; } /** @@ -261,15 +259,13 @@ public int getVariable(int ref) { */ public int cofactor(int bdd, int varIndex, boolean value) { // Terminals and results are unaffected by cofactoring - if (isTerminal(bdd) || isResult(bdd)) { + if (isLeaf(bdd)) { return bdd; } boolean complemented = isComplement(bdd); int nodeIndex = toNodeIndex(bdd); - if (nodeIndex >= nodeCount || nodeIndex < 0) { - throw new IllegalStateException("Invalid node index: " + nodeIndex); - } + validateNodeIndex(nodeIndex); int nodeVar = variables[nodeIndex]; @@ -329,36 +325,26 @@ public int or(int f, int g) { * @throws IllegalArgumentException if f is a result terminal */ public int ite(int f, int g, int h) { - // Normalize: if condition is complemented, swap branches - if (isComplement(f)) { - f = negate(f); + // Normalize complement edge on f + if (f < 0) { + f = -f; int tmp = g; g = h; h = tmp; } - // Terminal cases and validation. - if (isResult(f)) { - throw new IllegalArgumentException("Condition f must be boolean, not a result terminal"); - } else if (f == TRUE_REF) { - return g; - } else if (f == FALSE_REF) { - return h; - } else if (g == h) { + // Quick terminal cases + if (f == TRUE_REF || g == h) { return g; - } - - // Boolean-specific optimizations (don't apply to result terminals) - if (!(isResult(g) || isResult(h))) { - // Standard Boolean identities + } else if (isResult(f)) { + throw new IllegalArgumentException("Condition f must be boolean, not a result terminal"); + } else if (!isResult(g) && !isResult(h)) { + // Boolean-only identities if (g == TRUE_REF && h == FALSE_REF) { return f; } else if (g == FALSE_REF && h == TRUE_REF) { return negate(f); - } else if (isComplement(g) && isComplement(h) && !isResult(negate(g)) && !isResult(negate(h))) { - // Factor out common complement only if the negated values aren't results - return negate(ite(f, negate(g), negate(h))); - } else if (g == f) { // Simplifications when f appears in branches + } else if (g == f) { return or(f, h); } else if (h == f) { return and(f, g); @@ -366,39 +352,31 @@ public int ite(int f, int g, int h) { return and(negate(f), h); } else if (h == negate(f)) { return or(negate(f), g); + } else if (isComplement(g) && isComplement(h)) { + // Factor out common complement + return negate(ite(f, negate(g), negate(h))); } } - // Check cache using the mutable key. - Integer cached = iteCache.get(mutableKey.update(f, g, h)); + // Check cache + Integer cached = iteCache.get(f, g, h); if (cached != null) { return cached; } - // Create the actual key, and reserve cache slot to handle recursive calls - TripleKey key = new TripleKey(f, g, h); - iteCache.put(key, 0); // invalid place holder + // Reserve cache slot to handle recursive calls + iteCache.put(f, g, h, 0); // placeholder - // Shannon expansion: find the top variable + // Shannon expansion int v = getTopVariable(f, g, h); - - // Compute cofactors - int f0 = cofactor(f, v, false); - int f1 = cofactor(f, v, true); - int g0 = cofactor(g, v, false); - int g1 = cofactor(g, v, true); - int h0 = cofactor(h, v, false); - int h1 = cofactor(h, v, true); - - // Recursive ITE on cofactors - int r0 = ite(f0, g0, h0); - int r1 = ite(f1, g1, h1); + int r0 = ite(cofactor(f, v, false), cofactor(g, v, false), cofactor(h, v, false)); + int r1 = ite(cofactor(f, v, true), cofactor(g, v, true), cofactor(h, v, true)); // Build result node int result = makeNode(v, r1, r0); // Update cache with actual result - iteCache.put(key, result); + iteCache.put(f, g, h, result); return result; } @@ -409,41 +387,40 @@ public int ite(int f, int g, int h) { * @return the reduced BDD root */ public int reduce(int rootRef) { - // Quick exit for terminals/results - if (isTerminal(rootRef) || isResult(rootRef)) { + if (isLeaf(rootRef)) { return rootRef; } - boolean rootComp = isComplement(rootRef); + // Peel off complement on the root + boolean rootComp = rootRef < 0; int absRoot = rootComp ? negate(rootRef) : rootRef; - // Prep new storage + // Allocate new tables int[] newVariables = new int[nodeCount]; int[] newHighs = new int[nodeCount]; int[] newLows = new int[nodeCount]; - Map newUnique = new HashMap<>(nodeCount * 2); - // Initialize terminal node + // Clear and reuse the existing unique table + uniqueTable.clear(); + + // Initialize the terminal node newVariables[0] = -1; newHighs[0] = TRUE_REF; newLows[0] = FALSE_REF; - // Create a mutable counter to track nodes added - int[] newNodeCounter = new int[] {1}; // Start at 1 (terminal already added) - - // Mapping array + // Prepare the visitation map int[] oldToNew = new int[nodeCount]; Arrays.fill(oldToNew, -1); + int[] newCount = {1}; // start after terminal - // Recurse - int newRoot = reduceRec(absRoot, oldToNew, newVariables, newHighs, newLows, newNodeCounter, newUnique); + // Recursively rebuild + int newRoot = reduceRec(absRoot, oldToNew, newVariables, newHighs, newLows, newCount); - // Swap in - use the actual count of nodes created - this.variables = Arrays.copyOf(newVariables, newNodeCounter[0]); - this.highs = Arrays.copyOf(newHighs, newNodeCounter[0]); - this.lows = Arrays.copyOf(newLows, newNodeCounter[0]); - this.nodeCount = newNodeCounter[0]; - this.uniqueTable = newUnique; + // Swap in the new tables + this.variables = newVariables; + this.highs = newHighs; + this.lows = newLows; + this.nodeCount = newCount[0]; clearCaches(); return rootComp ? negate(newRoot) : newRoot; @@ -455,119 +432,79 @@ private int reduceRec( int[] newVariables, int[] newHighs, int[] newLows, - int[] newNodeCounter, - Map newUnique + int[] newCount ) { - // Handle terminals and results first - if (isTerminal(ref)) { - return ref; - } - - // Handle result references (not stored as nodes) - if (isResult(ref)) { + if (isLeaf(ref)) { return ref; } // Peel complement - boolean comp = isComplement(ref); + boolean comp = ref < 0; int abs = comp ? negate(ref) : ref; int idx = toNodeIndex(abs); - // Bounds check against nodeCount, not array length - if (idx >= nodeCount || idx < 0) { - throw new IllegalStateException("Invalid node index: " + idx + " (nodeCount=" + nodeCount + ")"); - } - - // Already processed? + // If already mapped, return it int mapped = oldToNew[idx]; if (mapped != -1) { return comp ? negate(mapped) : mapped; } - // Process children + // Recurse on children int var = variables[idx]; - int hiNew = reduceRec(highs[idx], oldToNew, newVariables, newHighs, newLows, newNodeCounter, newUnique); - int loNew = reduceRec(lows[idx], oldToNew, newVariables, newHighs, newLows, newNodeCounter, newUnique); + int hiNew = reduceRec(highs[idx], oldToNew, newVariables, newHighs, newLows, newCount); + int loNew = reduceRec(lows[idx], oldToNew, newVariables, newHighs, newLows, newCount); - // Apply reduction rule + // Reduction rule int resultAbs; if (hiNew == loNew) { resultAbs = hiNew; } else { - resultAbs = makeNodeInNew(var, hiNew, loNew, newVariables, newHighs, newLows, newNodeCounter, newUnique); - } - - oldToNew[idx] = resultAbs; - return comp ? negate(resultAbs) : resultAbs; - } - - private int makeNodeInNew( - int var, - int hi, - int lo, - int[] newVariables, - int[] newHighs, - int[] newLows, - int[] newNodeCounter, - Map newUnique - ) { - if (hi == lo) { - return hi; - } - - // Canonicalize complement edges (but not for result nodes) - boolean comp = false; - if (!isResult(hi) && !isResult(lo) && isComplement(lo)) { - hi = negate(hi); - lo = negate(lo); - comp = true; - } - - // Check if node already exists in new structure - Integer existing = newUnique.get(mutableKey.update(var, hi, lo)); - if (existing != null) { - int ref = toReference(existing); - return comp ? negate(ref) : ref; - } else { - int idx = newNodeCounter[0]; - - if (idx >= newVariables.length) { - throw new IllegalStateException("Insufficient space allocated for reduction"); + // Canonicalize complement edges on the low branch + boolean flip = shouldFlip(hiNew, loNew); + if (flip) { + hiNew = negate(hiNew); + loNew = negate(loNew); } - newVariables[idx] = var; - newHighs[idx] = hi; - newLows[idx] = lo; - - newUnique.put(new TripleKey(var, hi, lo), idx); - newNodeCounter[0]++; // Increment the node counter + // Lookup or create a new node + Integer existing = uniqueTable.get(var, hiNew, loNew); + if (existing != null) { + resultAbs = toReference(existing); + } else { + int nodeIdx = newCount[0]++; + newVariables[nodeIdx] = var; + newHighs[nodeIdx] = hiNew; + newLows[nodeIdx] = loNew; + uniqueTable.put(var, hiNew, loNew, nodeIdx); + resultAbs = toReference(nodeIdx); + } - int ref = toReference(idx); - return comp ? negate(ref) : ref; + if (flip) { + resultAbs = negate(resultAbs); + } } + + oldToNew[idx] = resultAbs; + return comp ? negate(resultAbs) : resultAbs; } /** * Finds the topmost variable among three BDDs. */ private int getTopVariable(int f, int g, int h) { - int varF = getVariable(f); - int varG = getVariable(g); - int varH = getVariable(h); - - // Filter out -1 (terminal marker) and find minimum - int min = Integer.MAX_VALUE; - if (varF >= 0 && varF < min) { - min = varF; - } - if (varG >= 0 && varG < min) { - min = varG; - } - if (varH >= 0 && varH < min) { - min = varH; - } + int minVar = Integer.MAX_VALUE; + minVar = updateMinVariable(minVar, f); + minVar = updateMinVariable(minVar, g); + minVar = updateMinVariable(minVar, h); + return (minVar == Integer.MAX_VALUE) ? -1 : minVar; + } - return min == Integer.MAX_VALUE ? -1 : min; + private int updateMinVariable(int currentMin, int ref) { + int absRef = Math.abs(ref); + if (absRef > 1 && absRef < Bdd.RESULT_OFFSET) { + return Math.min(currentMin, variables[absRef - 1]); + } + return currentMin; } /** @@ -589,10 +526,7 @@ public BddBuilder reset() { Arrays.fill(highs, 0, nodeCount, 0); Arrays.fill(lows, 0, nodeCount, 0); nodeCount = 1; - // Re-initialize terminal node - variables[0] = -1; - highs[0] = TRUE_REF; - lows[0] = FALSE_REF; + initializeTerminalNode(); conditionCount = -1; return this; } @@ -628,36 +562,15 @@ private int toReference(int nodeIndex) { return nodeIndex + 1; } - private static final class TripleKey { - private int a, b, c, hash; - - private TripleKey(int a, int b, int c) { - update(a, b, c); - } - - TripleKey update(int a, int b, int c) { - this.a = a; - this.b = b; - this.c = c; - int i = (a * 31 + b) * 31 + c; - this.hash = (i ^ (i >>> 16)); - return this; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } else if (!(o instanceof TripleKey)) { - return false; - } - TripleKey k = (TripleKey) o; - return a == k.a && b == k.b && c == k.c; + private void validateNodeIndex(int nodeIndex) { + if (nodeIndex >= nodeCount || nodeIndex < 0) { + throw new IllegalStateException("Invalid node index: " + nodeIndex); } + } - @Override - public int hashCode() { - return hash; - } + private void initializeTerminalNode() { + variables[0] = -1; + highs[0] = TRUE_REF; + lows[0] = FALSE_REF; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/UniqueTable.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/UniqueTable.java new file mode 100644 index 00000000000..d83e7843509 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/UniqueTable.java @@ -0,0 +1,74 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.HashMap; +import java.util.Map; + +/** + * A specialized hash table for BDD node deduplication using triple (var, high, low) keys. + */ +final class UniqueTable { + private final Map table; + private final TripleKey mutableKey = new TripleKey(0, 0, 0); + + public UniqueTable() { + this.table = new HashMap<>(); + } + + public UniqueTable(int initialCapacity) { + this.table = new HashMap<>(initialCapacity); + } + + public Integer get(int var, int high, int low) { + mutableKey.update(var, high, low); + return table.get(mutableKey); + } + + public void put(int var, int high, int low, int nodeIndex) { + table.put(new TripleKey(var, high, low), nodeIndex); + } + + public void clear() { + table.clear(); + } + + public int size() { + return table.size(); + } + + private static final class TripleKey { + private int a, b, c, hash; + + private TripleKey(int a, int b, int c) { + update(a, b, c); + } + + TripleKey update(int a, int b, int c) { + this.a = a; + this.b = b; + this.c = c; + int i = (a * 31 + b) * 31 + c; + this.hash = (i ^ (i >>> 16)); + return this; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (!(o instanceof TripleKey)) { + return false; + } + TripleKey k = (TripleKey) o; + return a == k.a && b == k.b && c == k.c; + } + + @Override + public int hashCode() { + return hash; + } + } +} From 493349f4a6e0b1f73df6c63700ef5cf0ac80265a Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Fri, 1 Aug 2025 10:47:39 -0500 Subject: [PATCH 08/23] Simplify condition handling in CFG and BDD This also revealed a bug in the BDD compilation process that was causing negated nodes to get added twice. --- .../syntax/expressions/Expression.java | 37 +++- .../syntax/expressions/Reference.java | 2 +- .../functions/LibraryFunction.java | 11 +- .../expressions/literal/RecordLiteral.java | 11 +- .../expressions/literal/StringLiteral.java | 25 ++- .../expressions/literal/TupleLiteral.java | 11 +- .../rulesengine/logic/ConditionInfo.java | 59 ------ .../rulesengine/logic/ConditionInfoImpl.java | 103 ----------- .../rulesengine/logic/ConditionReference.java | 51 ++++-- .../rulesengine/logic/bdd/BddCompiler.java | 9 +- .../logic/bdd/BddEquivalenceChecker.java | 5 +- .../logic/bdd/ConditionDependencyGraph.java | 31 ++-- .../logic/bdd/ConditionOrderingStrategy.java | 7 +- .../logic/bdd/DefaultOrderingStrategy.java | 27 ++- .../logic/bdd/SiftingOptimization.java | 18 +- .../smithy/rulesengine/logic/cfg/Cfg.java | 69 ++++++- .../rulesengine/logic/cfg/CfgBuilder.java | 9 +- .../rulesengine/logic/cfg/ConditionData.java | 64 ------- .../logic/ConditionInfoImplTest.java | 126 ------------- .../logic/ConditionReferenceTest.java | 138 -------------- .../logic/bdd/BddCompilerTest.java | 4 +- .../bdd/ConditionDependencyGraphTest.java | 45 +++-- .../bdd/DefaultOrderingStrategyTest.java | 41 ++--- .../logic/bdd/OrderConstraintsTest.java | 31 +--- .../smithy/rulesengine/logic/cfg/CfgTest.java | 25 --- .../logic/cfg/ConditionDataTest.java | 170 ------------------ 26 files changed, 258 insertions(+), 871 deletions(-) delete mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfo.java delete mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImpl.java delete mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionData.java delete mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImplTest.java delete mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionReferenceTest.java delete mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDataTest.java diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java index b4157244604..7f8398dfa87 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java @@ -36,6 +36,8 @@ @SmithyUnstableApi public abstract class Expression extends SyntaxElement implements FromSourceLocation, ToNode, TypeCheck { private final SourceLocation sourceLocation; + private Integer cachedComplexity; + private Set cachedReferences; private Type cachedType; public Expression(SourceLocation sourceLocation) { @@ -147,10 +149,43 @@ public static Literal getLiteral(StringNode node) { * * @return variable references by name. */ - public Set getReferences() { + public final Set getReferences() { + if (cachedReferences == null) { + cachedReferences = Collections.unmodifiableSet(calculateReferences()); + } + return cachedReferences; + } + + /** + * Computes the references of an expression. + * + * @return the computed references. + */ + protected Set calculateReferences() { return Collections.emptySet(); } + /** + * Get the complexity heuristic of the expression, based on functions, references, etc. + * + * @return the complexity heuristic. + */ + public final int getComplexity() { + if (cachedComplexity == null) { + cachedComplexity = calculateComplexity(); + } + return cachedComplexity; + } + + /** + * Calculates the complexity of the expression, to be overridden by implementations. + * + * @return complexity estimate. + */ + protected int calculateComplexity() { + return 1; + } + /** * Invoke the {@link ExpressionVisitor} functions for this expression. * diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java index 6db377dc8dc..e7341dfe745 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Reference.java @@ -50,7 +50,7 @@ public String template() { } @Override - public Set getReferences() { + protected Set calculateReferences() { return Collections.singleton(getName().toString()); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java index 5d8f5166d7a..3c9bcfddbe5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java @@ -48,7 +48,7 @@ public String getName() { } @Override - public Set getReferences() { + protected Set calculateReferences() { Set references = new LinkedHashSet<>(); for (Expression arg : getArguments()) { references.addAll(arg.getReferences()); @@ -245,4 +245,13 @@ private static boolean isReference(Expression arg) { } return false; } + + @Override + protected int calculateComplexity() { + int complexity = getFunctionDefinition().getCostHeuristic(); + for (Expression arg : getArguments()) { + complexity += arg.getComplexity(); + } + return complexity; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java index 5d70dff6e0c..92ecd99baf8 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java @@ -76,11 +76,20 @@ public Node toNode() { } @Override - public Set getReferences() { + protected Set calculateReferences() { Set references = new LinkedHashSet<>(); for (Literal value : members().values()) { references.addAll(value.getReferences()); } return references; } + + @Override + protected int calculateComplexity() { + int complexity = 1; + for (Literal value : members().values()) { + complexity += value.getComplexity(); + } + return complexity; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java index 8adb24f487b..f8597c51387 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java @@ -71,7 +71,7 @@ public Node toNode() { } @Override - public Set getReferences() { + protected Set calculateReferences() { Template template = value(); if (template.isStatic()) { return Collections.emptySet(); @@ -86,4 +86,27 @@ public Set getReferences() { return references; } + + @Override + protected int calculateComplexity() { + Template template = value(); + if (template.isStatic()) { + return 1; + } + + int complexity = 1; + if (template.getParts().size() > 1) { + // Multiple parts are expensive + complexity += 8; + } + + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + Template.Dynamic dynamic = (Template.Dynamic) part; + complexity += dynamic.toExpression().getComplexity(); + } + } + + return complexity; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java index 867f5ec6e4c..9c269228559 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java @@ -82,11 +82,20 @@ public Node toNode() { } @Override - public Set getReferences() { + protected Set calculateReferences() { Set references = new LinkedHashSet<>(); for (Literal member : members()) { references.addAll(member.getReferences()); } return references; } + + @Override + protected int calculateComplexity() { + int complexity = 1; + for (Literal member : members()) { + complexity += member.getComplexity(); + } + return complexity; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfo.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfo.java deleted file mode 100644 index 9b8d653356e..00000000000 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfo.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic; - -import java.util.Collections; -import java.util.Set; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; - -/** - * Information about a condition. - */ -public interface ConditionInfo { - /** - * Create a new ConditionInfo from the given condition. - * - * @param condition Condition to compute. - * @return the created ConditionInfo. - */ - static ConditionInfo from(Condition condition) { - return new ConditionInfoImpl(condition); - } - - /** - * Get the underlying condition. - * - * @return condition. - */ - Condition getCondition(); - - /** - * Get the complexity of the condition. - * - * @return the complexity. - */ - default int getComplexity() { - return 1; - } - - /** - * Get the references used by the condition. - * - * @return the references. - */ - default Set getReferences() { - return Collections.emptySet(); - } - - /** - * Get the name of the variable this condition defines, if any, or null. - * - * @return the defined variable name or null. - */ - default String getReturnVariable() { - return getCondition().getResult().map(Identifier::toString).orElse(null); - } -} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImpl.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImpl.java deleted file mode 100644 index d493ae1c72a..00000000000 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImpl.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic; - -import java.util.Set; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; - -/** - * Default implementation of {@link ConditionInfo} that computes condition metadata. - */ -final class ConditionInfoImpl implements ConditionInfo { - - private final Condition condition; - private final int complexity; - private final Set references; - - ConditionInfoImpl(Condition condition) { - this.condition = condition; - this.complexity = calculateComplexity(condition.getFunction()); - this.references = condition.getFunction().getReferences(); - } - - @Override - public Condition getCondition() { - return condition; - } - - @Override - public int getComplexity() { - return complexity; - } - - @Override - public Set getReferences() { - return references; - } - - @Override - public String getReturnVariable() { - return condition.getResult().map(Identifier::toString).orElse(null); - } - - @Override - public boolean equals(Object object) { - if (this == object) { - return true; - } else if (object == null || getClass() != object.getClass()) { - return false; - } else { - return condition.equals(((ConditionInfoImpl) object).condition); - } - } - - @Override - public int hashCode() { - return condition.hashCode(); - } - - @Override - public String toString() { - return condition.toString(); - } - - private static int calculateComplexity(Expression e) { - // Base complexity for this node - int complexity = 1; - - if (e instanceof StringLiteral) { - Template template = ((StringLiteral) e).value(); - if (!template.isStatic()) { - if (template.getParts().size() > 1) { - // Single dynamic part is cheap, but multiple parts are expensive - complexity += 8; - } - for (Template.Part part : template.getParts()) { - // Add complexity from dynamic parts - if (part instanceof Template.Dynamic) { - Template.Dynamic dynamic = (Template.Dynamic) part; - complexity += calculateComplexity(dynamic.toExpression()); - } - } - } - } else if (e instanceof GetAttr) { - complexity += calculateComplexity(((GetAttr) e).getTarget()) + 2; - } else if (e instanceof LibraryFunction) { - LibraryFunction l = (LibraryFunction) e; - complexity += l.getFunctionDefinition().getCostHeuristic(); - for (Expression arg : l.getArguments()) { - complexity += calculateComplexity(arg); - } - } - - return complexity; - } -} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java index 94f1a34da75..1b2fbd7d248 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java @@ -5,19 +5,22 @@ package software.amazon.smithy.rulesengine.logic; import java.util.Set; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; /** * A reference to a condition and whether it is negated. */ -public final class ConditionReference implements ConditionInfo { +public final class ConditionReference { - private final ConditionInfo delegate; + private final Condition condition; private final boolean negated; + private final String returnVar; - public ConditionReference(ConditionInfo delegate, boolean negated) { - this.delegate = delegate; + public ConditionReference(Condition condition, boolean negated) { + this.condition = condition; this.negated = negated; + this.returnVar = condition.getResult().map(Identifier::toString).orElse(null); } /** @@ -35,32 +38,48 @@ public boolean isNegated() { * @return returns the negated reference. */ public ConditionReference negate() { - return new ConditionReference(delegate, !negated); + return new ConditionReference(condition, !negated); } - @Override + /** + * Get the underlying condition. + * + * @return condition. + */ public Condition getCondition() { - return delegate.getCondition(); + return condition; } - @Override + /** + * Get the complexity of the condition. + * + * @return the complexity. + */ public int getComplexity() { - return delegate.getComplexity(); + return condition.getFunction().getComplexity(); } - @Override + /** + * Get the references used by the condition. + * + * @return the references. + */ public Set getReferences() { - return delegate.getReferences(); + return condition.getFunction().getReferences(); } - @Override + /** + * Get the name of the variable this condition defines, if any, or null. + * + * @return the defined variable name or null. + */ public String getReturnVariable() { - return delegate.getReturnVariable(); + return returnVar; } @Override public String toString() { - return (negated ? "!" : "") + delegate.toString(); + return (negated ? "!" : "") + condition.toString(); } @Override @@ -71,11 +90,11 @@ public boolean equals(Object object) { return false; } ConditionReference that = (ConditionReference) object; - return negated == that.negated && delegate.equals(that.delegate); + return negated == that.negated && condition.equals(that.condition); } @Override public int hashCode() { - return delegate.hashCode() ^ (negated ? 0x80000000 : 0); + return condition.hashCode() ^ (negated ? 0x80000000 : 0); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java index 12950082290..f9d97377654 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java @@ -14,11 +14,9 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; import software.amazon.smithy.rulesengine.logic.ConditionReference; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; -import software.amazon.smithy.rulesengine.logic.cfg.ConditionData; import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; @@ -133,12 +131,7 @@ private int getOrCreateResultIndex(Rule rule) { } private void extractAndOrderConditions() { - // Extract conditions from CFG and order them. - ConditionData data = cfg.getConditionData(); - Map infos = data.getConditionInfos(); - orderedConditions = orderingStrategy.orderConditions(data.getConditions(), infos); - - // Build index map + orderedConditions = orderingStrategy.orderConditions(cfg.getConditions()); conditionToIndex = new LinkedHashMap<>(); for (int i = 0; i < orderedConditions.size(); i++) { conditionToIndex.put(orderedConditions.get(i), i); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java index e83dd0c176e..32bc8778aa3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java @@ -19,7 +19,6 @@ import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; -import software.amazon.smithy.rulesengine.logic.cfg.ConditionData; import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; @@ -266,10 +265,8 @@ private void verifyCase(long mask) { } private Rule evaluateCfgWithMask(ConditionEvaluator maskEvaluator) { - // Get the condition data from CFG - ConditionData conditionData = cfg.getConditionData(); Map cfgConditionToIndex = new HashMap<>(); - Condition[] cfgConditions = conditionData.getConditions(); + Condition[] cfgConditions = cfg.getConditions(); for (int i = 0; i < cfgConditions.length; i++) { cfgConditionToIndex.put(cfgConditions[i], i); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java index dcbba6596ad..d4a678cce99 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java @@ -4,14 +4,15 @@ */ package software.amazon.smithy.rulesengine.logic.bdd; +import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashMap; import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; /** * Immutable graph of dependencies between conditions. @@ -24,7 +25,7 @@ * */ final class ConditionDependencyGraph { - private final Map conditionInfos; + private final List conditions; private final Map> dependencies; private final Map> variableDefiners; private final Map> isSetConditions; @@ -32,27 +33,24 @@ final class ConditionDependencyGraph { /** * Creates a dependency graph by analyzing the given conditions. * - * @param conditionInfos metadata about each condition + * @param conditions the conditions to analyze */ - public ConditionDependencyGraph(Map conditionInfos) { - this.conditionInfos = Collections.unmodifiableMap(new LinkedHashMap<>(conditionInfos)); + public ConditionDependencyGraph(List conditions) { + this.conditions = Collections.unmodifiableList(new ArrayList<>(conditions)); this.variableDefiners = new LinkedHashMap<>(); this.isSetConditions = new LinkedHashMap<>(); // Categorize all conditions - for (Map.Entry entry : conditionInfos.entrySet()) { - Condition cond = entry.getKey(); - ConditionInfo info = entry.getValue(); - + for (Condition cond : conditions) { // Track variable definition - String definedVar = info.getReturnVariable(); - if (definedVar != null) { + if (cond.getResult().isPresent()) { + String definedVar = cond.getResult().get().toString(); variableDefiners.computeIfAbsent(definedVar, k -> new LinkedHashSet<>()).add(cond); } // Track isSet conditions if (isIsset(cond)) { - for (String var : info.getReferences()) { + for (String var : cond.getFunction().getReferences()) { isSetConditions.computeIfAbsent(var, k -> new LinkedHashSet<>()).add(cond); } } @@ -60,13 +58,10 @@ public ConditionDependencyGraph(Map conditionInfos) { // Compute dependencies Map> deps = new LinkedHashMap<>(); - for (Map.Entry entry : conditionInfos.entrySet()) { - Condition cond = entry.getKey(); - ConditionInfo info = entry.getValue(); - + for (Condition cond : conditions) { Set condDeps = new LinkedHashSet<>(); - for (String usedVar : info.getReferences()) { + for (String usedVar : cond.getFunction().getReferences()) { // Must come after any condition that defines this variable condDeps.addAll(variableDefiners.getOrDefault(usedVar, Collections.emptySet())); @@ -101,7 +96,7 @@ public Set getDependencies(Condition condition) { * @return the number of conditions */ public int size() { - return conditionInfos.size(); + return conditions.size(); } private static boolean isIsset(Condition cond) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java index c0cd4fe900d..06f9e48cd59 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java @@ -5,9 +5,7 @@ package software.amazon.smithy.rulesengine.logic.bdd; import java.util.List; -import java.util.Map; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; /** * Strategy interface for ordering conditions in a BDD. @@ -18,10 +16,9 @@ interface ConditionOrderingStrategy { * Orders the given conditions for BDD construction. * * @param conditions array of conditions to order - * @param conditionInfos metadata about each condition * @return ordered list of conditions */ - List orderConditions(Condition[] conditions, Map conditionInfos); + List orderConditions(Condition[] conditions); /** * Default ordering strategy that uses the existing ConditionOrderer. @@ -38,6 +35,6 @@ static ConditionOrderingStrategy defaultOrdering() { * @return a fixed ordering strategy. */ static ConditionOrderingStrategy fixed(List ordering) { - return (conditions, infos) -> ordering; + return conditions -> ordering; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java index cd174a81b54..b9cc2653399 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java @@ -5,15 +5,14 @@ package software.amazon.smithy.rulesengine.logic.bdd; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Set; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; /** * Orders conditions for BDD construction while respecting variable dependencies. @@ -26,17 +25,16 @@ * */ final class DefaultOrderingStrategy { - private DefaultOrderingStrategy() {} - static List orderConditions(Condition[] conditions, Map conditionInfos) { - return sort(conditions, new ConditionDependencyGraph(conditionInfos), conditionInfos); + static List orderConditions(Condition[] conditions) { + ConditionDependencyGraph deps = new ConditionDependencyGraph(Arrays.asList(conditions)); + return sort(conditions, deps); } private static List sort( Condition[] conditions, - ConditionDependencyGraph deps, - Map infos + ConditionDependencyGraph deps ) { List result = new ArrayList<>(); Set visited = new HashSet<>(); @@ -52,16 +50,18 @@ private static List sort( // isSet() before everything else .thenComparingInt(c -> c.getFunction().getFunctionDefinition() == IsSet.getDefinition() ? 0 : 1) // variable-defining conditions first - .thenComparingInt(c -> infos.get(c).getReturnVariable() != null ? 0 : 1) + .thenComparingInt(c -> c.getResult().isPresent() ? 0 : 1) // fewer references first - .thenComparingInt(c -> infos.get(c).getReferences().size()) + .thenComparingInt(c -> c.getFunction().getReferences().size()) + // lower complexity first + .thenComparingInt(c -> c.getFunction().getComplexity()) // stable tie-breaker .thenComparing(Condition::toString)); // Visit in priority order for (Condition cond : queue) { if (!visited.contains(cond)) { - visit(cond, deps, visited, visiting, result, infos); + visit(cond, deps, visited, visiting, result); } } @@ -73,8 +73,7 @@ private static void visit( ConditionDependencyGraph depGraph, Set visited, Set visiting, - List result, - Map infos + List result ) { if (visiting.contains(cond)) { throw new IllegalStateException("Circular dependency detected involving: " + cond); @@ -90,10 +89,10 @@ private static void visit( Set deps = depGraph.getDependencies(cond); if (!deps.isEmpty()) { List sortedDeps = new ArrayList<>(deps); - sortedDeps.sort(Comparator.comparingInt(c -> infos.get(c).getComplexity())); + sortedDeps.sort(Comparator.comparingInt(c -> c.getFunction().getComplexity())); for (Condition dep : sortedDeps) { - visit(dep, depGraph, visited, visiting, result, infos); + visit(dep, depGraph, visited, visiting, result); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index fa8902e5de5..d71eef0a349 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -7,18 +7,13 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.concurrent.ForkJoinPool; import java.util.function.Function; import java.util.logging.Logger; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; -import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; -import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; import software.amazon.smithy.utils.SmithyBuilder; /** @@ -52,7 +47,6 @@ public final class SiftingOptimization implements Function { private final Cfg cfg; private final ConditionDependencyGraph dependencyGraph; - private final Map conditionInfos; // Tiered optimization settings private final int coarseMinNodes; @@ -87,17 +81,7 @@ private SiftingOptimization(Builder builder) { this.mediumMaxPasses = builder.mediumMaxPasses; this.granularMaxNodes = builder.granularMaxNodes; this.granularMaxPasses = builder.granularMaxPasses; - - // Extract condition infos from CFG - this.conditionInfos = new LinkedHashMap<>(); - for (CfgNode node : cfg) { - if (node instanceof ConditionNode) { - ConditionInfo info = ((ConditionNode) node).getCondition(); - conditionInfos.put(info.getCondition(), info); - } - } - - this.dependencyGraph = new ConditionDependencyGraph(conditionInfos); + this.dependencyGraph = new ConditionDependencyGraph(Arrays.asList(cfg.getConditions())); } /** diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java index 4d3df58ccfe..14d0af25c2b 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java @@ -11,6 +11,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; @@ -39,7 +40,10 @@ public final class Cfg implements Iterable { private final EndpointRuleSet ruleSet; private final CfgNode root; - private ConditionData data; + + // Lazily computed condition data + private Condition[] conditions; + private Map conditionToIndex; Cfg(EndpointRuleSet ruleSet, CfgNode root) { this.ruleSet = ruleSet; @@ -61,17 +65,64 @@ public static Cfg from(EndpointRuleSet ruleSet) { } /** - * Get the condition data of the CFG. + * Gets all unique conditions in the CFG, in the order they were discovered. + * + * @return array of conditions + */ + public Condition[] getConditions() { + ensureConditionsExtracted(); + return conditions; + } + + /** + * Gets the index of a condition in the conditions array. * - * @return the lazily created and cached prepared condition data. + * @param condition the condition to look up + * @return the index, or null if not found */ - public ConditionData getConditionData() { - ConditionData result = data; - if (result == null) { - result = ConditionData.from(this); - data = result; + public Integer getConditionIndex(Condition condition) { + ensureConditionsExtracted(); + return conditionToIndex.get(condition); + } + + /** + * Gets the number of unique conditions in the CFG. + * + * @return the condition count + */ + public int getConditionCount() { + ensureConditionsExtracted(); + return conditions.length; + } + + private void ensureConditionsExtracted() { + if (conditions == null) { + extractConditions(); + } + } + + private synchronized void extractConditions() { + if (conditions != null) { + return; } - return result; + + List conditionList = new ArrayList<>(); + Map indexMap = new LinkedHashMap<>(); + + for (CfgNode node : this) { + if (node instanceof ConditionNode) { + ConditionNode condNode = (ConditionNode) node; + Condition condition = condNode.getCondition().getCondition(); + + if (!indexMap.containsKey(condition)) { + indexMap.put(condition, conditionList.size()); + conditionList.add(condition); + } + } + } + + this.conditions = conditionList.toArray(new Condition[0]); + this.conditionToIndex = indexMap; } public EndpointRuleSet getRuleSet() { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java index 3275c060f44..9bedc35478b 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -20,7 +20,6 @@ import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; import software.amazon.smithy.rulesengine.logic.ConditionReference; /** @@ -36,7 +35,6 @@ public final class CfgBuilder { private final Map nodeCache = new HashMap<>(); // Condition and result canonicalization - private final Map conditionToInfo = new HashMap<>(); private final Map conditionToReference = new HashMap<>(); private final Map resultCache = new HashMap<>(); private final Map resultNodeCache = new HashMap<>(); @@ -114,7 +112,7 @@ public ConditionReference createConditionReference(Condition condition) { // Check if we already have the non-negated version ConditionReference existing = conditionToReference.get(canonical); if (existing != null) { - // Reuse the existing ConditionInfo, just negate the reference + // Reuse the existing Condition, just negate the reference ConditionReference negatedReference = existing.negate(); conditionToReference.put(condition, negatedReference); return negatedReference; @@ -132,11 +130,8 @@ public ConditionReference createConditionReference(Condition condition) { negated = !negated; } - // Get or create the ConditionInfo - ConditionInfo info = conditionToInfo.computeIfAbsent(canonical, ConditionInfo::from); - // Create the reference (possibly negated) - ConditionReference reference = new ConditionReference(info, negated); + ConditionReference reference = new ConditionReference(canonical, negated); // Cache the reference under the original key conditionToReference.put(condition, reference); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionData.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionData.java deleted file mode 100644 index 4a8ee1e7a3b..00000000000 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionData.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic.cfg; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; - -/** - * Extracts and indexes condition data from a CFG. - */ -public final class ConditionData { - private final Condition[] conditions; - private final Map conditionToIndex; - private final Map conditionInfos; - - private ConditionData(Condition[] conditions, Map index, Map infos) { - this.conditions = conditions; - this.conditionToIndex = index; - this.conditionInfos = infos; - } - - /** - * Extracts and indexes all conditions from a CFG. - * - * @param cfg the control flow graph to process - * @return ConditionData containing indexed conditions - */ - public static ConditionData from(Cfg cfg) { - List conditionList = new ArrayList<>(); - Map indexMap = new LinkedHashMap<>(); - Map infoMap = new HashMap<>(); - - for (CfgNode node : cfg) { - if (node instanceof ConditionNode) { - ConditionNode condNode = (ConditionNode) node; - ConditionInfo info = condNode.getCondition(); - Condition condition = info.getCondition(); - - if (!indexMap.containsKey(condition)) { - indexMap.put(condition, conditionList.size()); - conditionList.add(condition); - infoMap.put(condition, info); - } - } - } - - return new ConditionData(conditionList.toArray(new Condition[0]), indexMap, infoMap); - } - - public Condition[] getConditions() { - return conditions; - } - - public Map getConditionInfos() { - return conditionInfos; - } -} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImplTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImplTest.java deleted file mode 100644 index 95c94170d0f..00000000000 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionInfoImplTest.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import org.junit.jupiter.api.Test; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.BooleanLiteral; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; - -public class ConditionInfoImplTest { - - @Test - void testSimpleIsSetCondition() { - Condition condition = Condition.builder() - .fn(IsSet.ofExpressions(Literal.of("{Region}"))) - .build(); - - ConditionInfo info = ConditionInfo.from(condition); - - assertEquals(condition, info.getCondition()); - - assertEquals(4, info.getComplexity()); - assertEquals(1, info.getReferences().size()); - assertTrue(info.getReferences().contains("Region")); - assertNull(info.getReturnVariable()); - } - - @Test - void testConditionWithVariableBinding() { - Condition condition = Condition.builder() - .fn(IsSet.ofExpressions(Literal.of("{Region}"))) - .result(Identifier.of("RegionExists")) - .build(); - - ConditionInfo info = ConditionInfo.from(condition); - - assertEquals("RegionExists", info.getReturnVariable()); - } - - @Test - void testComplexNestedCondition() { - // Test nested function calls - Condition condition = Condition.builder() - .fn(Not.ofExpressions( - BooleanEquals.ofExpressions( - IsSet.ofExpressions(Literal.of("{Region}")), - BooleanLiteral.of(true)))) - .build(); - - ConditionInfo info = ConditionInfo.from(condition); - - assertEquals(11, info.getComplexity()); - assertEquals(1, info.getReferences().size()); - assertTrue(info.getReferences().contains("Region")); - } - - @Test - void testTemplateStringComplexity() { - Condition condition = Condition.builder() - .fn(StringEquals.ofExpressions( - Literal.of("{Endpoint}"), - StringLiteral.of("https://{Service}.{Region}.amazonaws.com"))) - .build(); - - ConditionInfo info = ConditionInfo.from(condition); - - assertTrue(info.getComplexity() > StringEquals.getDefinition().getCostHeuristic()); - assertEquals(3, info.getReferences().size()); - assertTrue(info.getReferences().contains("Endpoint")); - assertTrue(info.getReferences().contains("Service")); - assertTrue(info.getReferences().contains("Region")); - } - - @Test - void testGetAttrNestedComplexity() { - Condition condition = Condition.builder() - .fn(GetAttr.ofExpressions(GetAttr.ofExpressions(Literal.of("{ComplexObject}"), "nested"), "value")) - .build(); - - ConditionInfo info = ConditionInfo.from(condition); - - assertEquals(8, info.getComplexity()); - assertEquals(1, info.getReferences().size()); - assertTrue(info.getReferences().contains("ComplexObject")); - } - - @Test - void testEquals() { - Condition condition1 = Condition.builder().fn(IsSet.ofExpressions(Literal.of("{Region}"))).build(); - Condition condition2 = Condition.builder().fn(IsSet.ofExpressions(Literal.of("{Region}"))).build(); - - ConditionInfo info1 = ConditionInfo.from(condition1); - ConditionInfo info2 = ConditionInfo.from(condition2); - - assertEquals(info1, info2); - assertEquals(info1.hashCode(), info2.hashCode()); - } - - @Test - void testToString() { - Condition condition = Condition.builder() - .fn(IsSet.ofExpressions(Literal.of("{Region}"))) - .result(Identifier.of("RegionExists")) - .build(); - - ConditionInfo info = ConditionInfo.from(condition); - - String str = info.toString(); - assertTrue(str.contains("isSet")); - assertTrue(str.contains("Region")); - assertTrue(str.contains("RegionExists")); - } -} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionReferenceTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionReferenceTest.java deleted file mode 100644 index 1430c15b599..00000000000 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/ConditionReferenceTest.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; - -public class ConditionReferenceTest { - - private ConditionInfo baseConditionInfo; - private Condition simpleCondition; - - @BeforeEach - void setUp() { - simpleCondition = Condition.builder().fn(IsSet.ofExpressions(Literal.of("{Region}"))).build(); - baseConditionInfo = ConditionInfo.from(simpleCondition); - } - - @Test - void testBasicConstruction() { - ConditionReference ref = new ConditionReference(baseConditionInfo, false); - - assertFalse(ref.isNegated()); - assertEquals(simpleCondition, ref.getCondition()); - } - - @Test - void testNegatedConstruction() { - ConditionReference ref = new ConditionReference(baseConditionInfo, true); - - assertTrue(ref.isNegated()); - assertEquals(simpleCondition, ref.getCondition()); - } - - @Test - void testNegateMethod() { - ConditionReference ref = new ConditionReference(baseConditionInfo, false); - ConditionReference negated = ref.negate(); - - assertFalse(ref.isNegated()); - assertTrue(negated.isNegated()); - assertEquals(ref.getCondition(), negated.getCondition()); - } - - @Test - void testDoubleNegation() { - ConditionReference ref = new ConditionReference(baseConditionInfo, false); - ConditionReference doubleNegated = ref.negate().negate(); - - assertFalse(doubleNegated.isNegated()); - assertEquals(ref.getCondition(), doubleNegated.getCondition()); - } - - @Test - void testGetReturnVariable() { - Condition condWithVariable = Condition.builder() - .fn(IsSet.ofExpressions(Literal.of("{Region}"))) - .result(Identifier.of("RegionSet")) - .build(); - - ConditionInfo info = ConditionInfo.from(condWithVariable); - ConditionReference ref = new ConditionReference(info, false); - - assertEquals("RegionSet", ref.getReturnVariable()); - } - - @Test - void testEquals() { - ConditionReference ref1 = new ConditionReference(baseConditionInfo, false); - ConditionReference ref2 = new ConditionReference(baseConditionInfo, false); - - assertEquals(ref1, ref2); - } - - @Test - void testNotEqualsWithDifferentNegation() { - ConditionReference ref1 = new ConditionReference(baseConditionInfo, false); - ConditionReference ref2 = new ConditionReference(baseConditionInfo, true); - - assertNotEquals(ref1, ref2); - } - - @Test - void testNotEqualsWithDifferentCondition() { - Condition otherCondition = Condition.builder().fn(IsSet.ofExpressions(Literal.of("{Bucket}"))).build(); - ConditionInfo otherInfo = ConditionInfo.from(otherCondition); - ConditionReference ref1 = new ConditionReference(baseConditionInfo, false); - ConditionReference ref2 = new ConditionReference(otherInfo, false); - - assertNotEquals(ref1, ref2); - } - - @Test - void testHashCode() { - ConditionReference ref1 = new ConditionReference(baseConditionInfo, false); - ConditionReference ref2 = new ConditionReference(baseConditionInfo, false); - - assertEquals(ref1.hashCode(), ref2.hashCode()); - } - - @Test - void testHashCodeDifferentForNegated() { - ConditionReference ref1 = new ConditionReference(baseConditionInfo, false); - ConditionReference ref2 = new ConditionReference(baseConditionInfo, true); - - // Hash codes should be different for negated vs non-negated - assertNotEquals(ref1.hashCode(), ref2.hashCode()); - } - - @Test - void testToString() { - ConditionReference ref = new ConditionReference(baseConditionInfo, false); - String str = ref.toString(); - - assertFalse(str.startsWith("!")); - assertTrue(str.contains("isSet")); - } - - @Test - void testToStringNegated() { - ConditionReference ref = new ConditionReference(baseConditionInfo, true); - String str = ref.toString(); - - assertTrue(str.startsWith("!")); - assertTrue(str.contains("isSet")); - } -} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java index ddd8fe02d84..0c6e57eac7b 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java @@ -21,7 +21,6 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; import software.amazon.smithy.rulesengine.logic.TestHelpers; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; -import software.amazon.smithy.rulesengine.logic.cfg.ConditionData; class BddCompilerTest { @@ -126,8 +125,7 @@ void testCompileWithCustomOrdering() { Cfg cfg = Cfg.from(ruleSet); // Get the actual conditions from the CFG after SSA transform - ConditionData conditionData = cfg.getConditionData(); - Condition[] cfgConditions = conditionData.getConditions(); + Condition[] cfgConditions = cfg.getConditions(); // Find the conditions that correspond to A and B Condition condA = null; diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java index 90c92a4dd96..0932dec64c7 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java @@ -7,8 +7,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.HashMap; -import java.util.Map; +import java.util.Arrays; +import java.util.List; import java.util.Set; import org.junit.jupiter.api.Test; import software.amazon.smithy.rulesengine.language.syntax.Identifier; @@ -16,7 +16,6 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; import software.amazon.smithy.rulesengine.logic.TestHelpers; class ConditionDependencyGraphTest { @@ -34,11 +33,8 @@ void testBasicVariableDependency() { .fn(BooleanEquals.ofExpressions(Expression.of("{hasRegion}"), Expression.of(true))) .build(); - Map conditionInfos = new HashMap<>(); - conditionInfos.put(definer, ConditionInfo.from(definer)); - conditionInfos.put(user, ConditionInfo.from(user)); - - ConditionDependencyGraph graph = new ConditionDependencyGraph(conditionInfos); + List conditions = Arrays.asList(definer, user); + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); // Definer has no dependencies assertTrue(graph.getDependencies(definer).isEmpty()); @@ -58,11 +54,8 @@ void testIsSetDependencyForNonIsSetCondition() { .fn(StringEquals.ofExpressions(Expression.of("{Region}"), Expression.of("us-east-1"))) .build(); - Map conditionInfos = new HashMap<>(); - conditionInfos.put(isSetCondition, ConditionInfo.from(isSetCondition)); - conditionInfos.put(userCondition, ConditionInfo.from(userCondition)); - - ConditionDependencyGraph graph = new ConditionDependencyGraph(conditionInfos); + List conditions = Arrays.asList(isSetCondition, userCondition); + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); // Non-isSet condition depends on isSet for undefined variables Set userDeps = graph.getDependencies(userCondition); @@ -90,12 +83,8 @@ void testMultipleDependencies() { BooleanEquals.ofExpressions(Expression.of("{hasBucket}"), Expression.of(true)))) .build(); - Map conditionInfos = new HashMap<>(); - conditionInfos.put(definer1, ConditionInfo.from(definer1)); - conditionInfos.put(definer2, ConditionInfo.from(definer2)); - conditionInfos.put(user, ConditionInfo.from(user)); - - ConditionDependencyGraph graph = new ConditionDependencyGraph(conditionInfos); + List conditions = Arrays.asList(definer1, definer2, user); + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); // User depends on both definers Set userDeps = graph.getDependencies(user); @@ -109,12 +98,22 @@ void testUnknownConditionReturnsEmptyDependencies() { Condition known = Condition.builder().fn(TestHelpers.isSet("Region")).build(); Condition unknown = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); - Map conditionInfos = new HashMap<>(); - conditionInfos.put(known, ConditionInfo.from(known)); - - ConditionDependencyGraph graph = new ConditionDependencyGraph(conditionInfos); + List conditions = Arrays.asList(known); + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); // Getting dependencies for unknown condition returns empty set assertTrue(graph.getDependencies(unknown).isEmpty()); } + + @Test + void testGraphSize() { + Condition cond1 = Condition.builder().fn(TestHelpers.isSet("A")).build(); + Condition cond2 = Condition.builder().fn(TestHelpers.isSet("B")).build(); + Condition cond3 = Condition.builder().fn(TestHelpers.isSet("C")).build(); + + List conditions = Arrays.asList(cond1, cond2, cond3); + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); + + assertEquals(3, graph.size()); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java index 1f2d7a89958..f04b26fef19 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java @@ -8,9 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.junit.jupiter.api.Test; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; @@ -18,7 +16,6 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; import software.amazon.smithy.rulesengine.logic.TestHelpers; class DefaultOrderingStrategyTest { @@ -32,9 +29,8 @@ void testIsSetComesFirst() { .build(); Condition[] conditions = {stringEqualsCond, isSetCond}; - Map infos = createInfoMap(conditions); - List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + List ordered = DefaultOrderingStrategy.orderConditions(conditions); // isSet should come first assertEquals(isSetCond, ordered.get(0)); @@ -52,9 +48,8 @@ void testVariableDefiningConditionsFirst() { Condition nonDefiner = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); Condition[] conditions = {nonDefiner, definer}; - Map infos = createInfoMap(conditions); - List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + List ordered = DefaultOrderingStrategy.orderConditions(conditions); // Variable-defining condition should come first assertEquals(definer, ordered.get(0)); @@ -75,9 +70,8 @@ void testDependencyOrdering() { .build(); Condition[] conditions = {user, definer}; - Map infos = createInfoMap(conditions); - List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + List ordered = DefaultOrderingStrategy.orderConditions(conditions); // Definer must come before user assertEquals(definer, ordered.get(0)); @@ -93,9 +87,8 @@ void testComplexityOrdering() { Condition complex = Condition.builder().fn(ParseUrl.ofExpressions(Literal.of("https://example.com"))).build(); Condition[] conditions = {complex, simple}; - Map infos = createInfoMap(conditions); - List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + List ordered = DefaultOrderingStrategy.orderConditions(conditions); // Simple should come before complex assertEquals(simple, ordered.get(0)); @@ -117,10 +110,9 @@ void testCircularDependencyDetection() { .build(); Condition[] conditions = {cond1, cond2}; - Map infos = createInfoMap(conditions); assertThrows(IllegalStateException.class, - () -> DefaultOrderingStrategy.orderConditions(conditions, infos)); + () -> DefaultOrderingStrategy.orderConditions(conditions)); } @Test @@ -142,9 +134,8 @@ void testMultiLevelDependencies() { // Mix up the order Condition[] conditions = {condC, condA, condB}; - Map infos = createInfoMap(conditions); - List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + List ordered = DefaultOrderingStrategy.orderConditions(conditions); assertEquals(condA, ordered.get(0)); assertEquals(condB, ordered.get(1)); @@ -158,9 +149,8 @@ void testStableSortForEqualPriority() { Condition cond2 = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); Condition[] conditions = {cond1, cond2}; - Map infos = createInfoMap(conditions); - List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + List ordered = DefaultOrderingStrategy.orderConditions(conditions); // Order should be deterministic based on toString assertEquals(2, ordered.size()); @@ -171,9 +161,8 @@ void testStableSortForEqualPriority() { @Test void testEmptyConditions() { Condition[] conditions = new Condition[0]; - Map infos = new HashMap<>(); - List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + List ordered = DefaultOrderingStrategy.orderConditions(conditions); assertEquals(0, ordered.size()); } @@ -183,9 +172,8 @@ void testSingleCondition() { Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); Condition[] conditions = {cond}; - Map infos = createInfoMap(conditions); - List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + List ordered = DefaultOrderingStrategy.orderConditions(conditions); assertEquals(1, ordered.size()); assertEquals(cond, ordered.get(0)); @@ -202,20 +190,11 @@ void testIsSetDependencyForSameVariable() { // Put value check first to test ordering Condition[] conditions = {valueCheck, isSet}; - Map infos = createInfoMap(conditions); - List ordered = DefaultOrderingStrategy.orderConditions(conditions, infos); + List ordered = DefaultOrderingStrategy.orderConditions(conditions); // isSet must come before value check assertEquals(isSet, ordered.get(0)); assertEquals(valueCheck, ordered.get(1)); } - - private Map createInfoMap(Condition... conditions) { - Map map = new HashMap<>(); - for (Condition cond : conditions) { - map.put(cond, ConditionInfo.from(cond)); - } - return map; - } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java index 9cf4714c493..ea872c6dd7f 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java @@ -11,15 +11,12 @@ import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.junit.jupiter.api.Test; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; import software.amazon.smithy.rulesengine.logic.TestHelpers; class OrderConstraintsTest { @@ -30,10 +27,8 @@ void testIndependentConditions() { Condition cond1 = Condition.builder().fn(TestHelpers.isSet("Region")).build(); Condition cond2 = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); - Map infos = createInfoMap(cond1, cond2); - ConditionDependencyGraph graph = new ConditionDependencyGraph(infos); List conditions = Arrays.asList(cond1, cond2); - + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); OrderConstraints constraints = new OrderConstraints(graph, conditions); // Both conditions can be placed anywhere @@ -58,10 +53,8 @@ void testDependentConditions() { .fn(BooleanEquals.ofExpressions(Literal.of("{hasRegion}"), Literal.of(true))) .build(); - Map infos = createInfoMap(cond1, cond2); - ConditionDependencyGraph graph = new ConditionDependencyGraph(infos); List conditions = Arrays.asList(cond1, cond2); - + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); OrderConstraints constraints = new OrderConstraints(graph, conditions); // cond1 can only stay in place (cannot move past its dependent) @@ -90,10 +83,8 @@ void testChainedDependencies() { .fn(BooleanEquals.ofExpressions(Literal.of("{var2}"), Literal.of(false))) .build(); - Map infos = createInfoMap(condA, condB, condC); - ConditionDependencyGraph graph = new ConditionDependencyGraph(infos); List conditions = Arrays.asList(condA, condB, condC); - + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); OrderConstraints constraints = new OrderConstraints(graph, conditions); // A can only be at position 0 @@ -118,10 +109,8 @@ void testChainedDependencies() { @Test void testCanMoveToSamePosition() { Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - Map infos = createInfoMap(cond); - ConditionDependencyGraph graph = new ConditionDependencyGraph(infos); List conditions = Collections.singletonList(cond); - + ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); OrderConstraints constraints = new OrderConstraints(graph, conditions); // Moving to same position is always allowed @@ -131,8 +120,8 @@ void testCanMoveToSamePosition() { @Test void testMismatchedSizes() { Condition cond1 = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - Map infos = createInfoMap(cond1); - ConditionDependencyGraph graph = new ConditionDependencyGraph(infos); + List graphConditions = Collections.singletonList(cond1); + ConditionDependencyGraph graph = new ConditionDependencyGraph(graphConditions); // Try to create constraints with more conditions than in graph List conditions = Arrays.asList( @@ -141,12 +130,4 @@ void testMismatchedSizes() { assertThrows(IllegalArgumentException.class, () -> new OrderConstraints(graph, conditions)); } - - private Map createInfoMap(Condition... conditions) { - Map map = new HashMap<>(); - for (Condition cond : conditions) { - map.put(cond, ConditionInfo.from(cond)); - } - return map; - } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java index ba2921510d0..583ca4bac7c 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java @@ -122,31 +122,6 @@ void fromHandlesMultipleRules() { assertInstanceOf(ConditionNode.class, cfg.getRoot()); } - @Test - void getConditionDataCachesResult() { - Parameters params = Parameters.builder() - .addParameter(Parameter.builder().name("region").type(ParameterType.STRING).build()) - .build(); - - EndpointRule rule = EndpointRule.builder() - .condition(Condition.builder().fn(TestHelpers.isSet("region")).build()) - .endpoint(TestHelpers.endpoint("https://example.com")); - - EndpointRuleSet ruleSet = EndpointRuleSet.builder() - .parameters(params) - .addRule(rule) - .build(); - - Cfg cfg = Cfg.from(ruleSet); - - ConditionData data1 = cfg.getConditionData(); - ConditionData data2 = cfg.getConditionData(); - - assertNotNull(data1); - assertSame(data1, data2); - assertEquals(1, data1.getConditions().length); - } - @Test void iteratorVisitsAllNodes() { Parameters params = Parameters.builder() diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDataTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDataTest.java deleted file mode 100644 index 886c5eabf91..00000000000 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDataTest.java +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic.cfg; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertSame; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.util.Map; -import org.junit.jupiter.api.Test; -import software.amazon.smithy.rulesengine.language.EndpointRuleSet; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; -import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; -import software.amazon.smithy.rulesengine.logic.ConditionInfo; -import software.amazon.smithy.rulesengine.logic.TestHelpers; - -class ConditionDataTest { - - @Test - void extractsConditionsFromSimpleCfg() { - // Build a simple ruleset with two conditions - Condition cond1 = Condition.builder() - .fn(TestHelpers.isSet("param1")) - .build(); - - // For stringEquals, we need to ensure param2 is set first - Rule rule = TreeRule.builder() - .condition(cond1) - .treeRule( - TreeRule.builder() - .condition(Condition.builder().fn(TestHelpers.isSet("param2")).build()) - .treeRule( - EndpointRule.builder() - .condition(Condition.builder() - .fn(TestHelpers.stringEquals("param2", "value")) - .build()) - .endpoint(TestHelpers.endpoint("https://example.com")))); - - Parameters params = Parameters.builder() - .addParameter(Parameter.builder().name("param1").type(ParameterType.STRING).build()) - .addParameter(Parameter.builder().name("param2").type(ParameterType.STRING).build()) - .build(); - - EndpointRuleSet ruleSet = EndpointRuleSet.builder() - .addRule(rule) - .parameters(params) - .build(); - - Cfg cfg = Cfg.from(ruleSet); - ConditionData data = ConditionData.from(cfg); - - // Verify condition extraction - Condition[] conditions = data.getConditions(); - assertEquals(3, conditions.length); // isSet(param1), isSet(param2), stringEquals - - // Verify condition infos - Map infos = data.getConditionInfos(); - assertEquals(3, infos.size()); - } - - @Test - void deduplicatesIdenticalConditions() { - // Create identical conditions used in different rules - Condition cond = Condition.builder() - .fn(TestHelpers.isSet("param")) - .build(); - - Rule rule1 = EndpointRule.builder().conditions(cond).endpoint(TestHelpers.endpoint("https://endpoint1.com")); - Rule rule2 = EndpointRule.builder().conditions(cond).endpoint(TestHelpers.endpoint("https://endpoint2.com")); - - Parameters params = Parameters.builder() - .addParameter(Parameter.builder().name("param").type(ParameterType.STRING).build()) - .build(); - - EndpointRuleSet ruleSet = EndpointRuleSet.builder() - .parameters(params) - .addRule(rule1) - .addRule(rule2) - .build(); - - Cfg cfg = Cfg.from(ruleSet); - ConditionData data = ConditionData.from(cfg); - - // Should only have one condition despite being used twice - assertEquals(1, data.getConditions().length); - assertEquals(cond, data.getConditions()[0]); - } - - @Test - void handlesNestedTreeRules() { - Condition cond1 = Condition.builder().fn(TestHelpers.isSet("param1")).build(); - Condition cond2 = Condition.builder().fn(TestHelpers.isSet("param2")).build(); - Condition cond3 = Condition.builder().fn(TestHelpers.isSet("param3")).build(); - - Rule innerRule = TreeRule.builder() - .conditions(cond2) - .treeRule(EndpointRule.builder() - .condition(cond3) - .endpoint(TestHelpers.endpoint("https://example.com"))); - - Rule outerRule = TreeRule.builder().condition(cond1).treeRule(innerRule); - - Parameters params = Parameters.builder() - .addParameter(Parameter.builder().name("param1").type(ParameterType.STRING).build()) - .addParameter(Parameter.builder().name("param2").type(ParameterType.STRING).build()) - .addParameter(Parameter.builder().name("param3").type(ParameterType.STRING).build()) - .build(); - - EndpointRuleSet ruleSet = EndpointRuleSet.builder() - .addRule(outerRule) - .parameters(params) - .build(); - - Cfg cfg = Cfg.from(ruleSet); - ConditionData data = ConditionData.from(cfg); - - // Should extract all three conditions - assertEquals(3, data.getConditions().length); - assertEquals(3, data.getConditionInfos().size()); - } - - @Test - void handlesCfgWithOnlyResults() { - // Rule with no conditions, just a result - Rule rule = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://default.com")); - EndpointRuleSet ruleSet = EndpointRuleSet.builder() - .parameters(Parameters.builder().build()) - .addRule(rule) - .build(); - - Cfg cfg = Cfg.from(ruleSet); - ConditionData data = ConditionData.from(cfg); - - // Should have no conditions - assertEquals(0, data.getConditions().length); - assertTrue(data.getConditionInfos().isEmpty()); - } - - @Test - void cachesResultOnCfg() { - Parameters params = Parameters.builder() - .addParameter(Parameter.builder().name("param").type(ParameterType.STRING).build()) - .build(); - - EndpointRule rule = EndpointRule.builder() - .condition(Condition.builder().fn(TestHelpers.isSet("param")).build()) - .endpoint(TestHelpers.endpoint("https://example.com")); - EndpointRuleSet ruleSet = EndpointRuleSet.builder() - .parameters(params) - .addRule(rule) - .build(); - Cfg cfg = Cfg.from(ruleSet); - - // First call should create the data - ConditionData data1 = cfg.getConditionData(); - assertNotNull(data1); - - // Second call should return the same instance - ConditionData data2 = cfg.getConditionData(); - assertSame(data1, data2); - } -} From 777203df5934e3ace8ffc36c6f34bc3d7ffb5769 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Fri, 1 Aug 2025 11:56:08 -0500 Subject: [PATCH 09/23] Remove unused methods from ConditionReference --- .../rulesengine/logic/ConditionReference.java | 31 ------------------- .../rulesengine/logic/cfg/CfgBuilderTest.java | 4 +-- 2 files changed, 2 insertions(+), 33 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java index 1b2fbd7d248..c41c1d76e71 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/ConditionReference.java @@ -4,8 +4,6 @@ */ package software.amazon.smithy.rulesengine.logic; -import java.util.Set; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; /** @@ -15,12 +13,10 @@ public final class ConditionReference { private final Condition condition; private final boolean negated; - private final String returnVar; public ConditionReference(Condition condition, boolean negated) { this.condition = condition; this.negated = negated; - this.returnVar = condition.getResult().map(Identifier::toString).orElse(null); } /** @@ -50,33 +46,6 @@ public Condition getCondition() { return condition; } - /** - * Get the complexity of the condition. - * - * @return the complexity. - */ - public int getComplexity() { - return condition.getFunction().getComplexity(); - } - - /** - * Get the references used by the condition. - * - * @return the references. - */ - public Set getReferences() { - return condition.getFunction().getReferences(); - } - - /** - * Get the name of the variable this condition defines, if any, or null. - * - * @return the defined variable name or null. - */ - public String getReturnVariable() { - return returnVar; - } - @Override public String toString() { return (negated ? "!" : "") + condition.toString(); diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java index 2099989d8f1..225feabd34e 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java @@ -234,7 +234,7 @@ void createConditionReferenceHandlesVariableBinding() { ConditionReference ref = builder.createConditionReference(cond); assertNotNull(ref); - assertEquals("parsedUrl", ref.getReturnVariable()); + assertEquals(cond, ref.getCondition()); } @Test @@ -275,6 +275,6 @@ void createConditionReferenceIgnoresNegationWithVariableBinding() { // Should not be treated as simple negation due to variable binding assertFalse(ref.isNegated()); assertInstanceOf(Not.class, ref.getCondition().getFunction()); - assertEquals("notRegionSet", ref.getReturnVariable()); + assertEquals(negatedWithBinding, ref.getCondition()); } } From 8afec2dd46d47f637a11e8011bfd7cda852cb09b Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Fri, 1 Aug 2025 19:40:41 -0500 Subject: [PATCH 10/23] Remove varint encoding and address PR feedback The varint encoding does help compact the binary node array, but adds maybe a bit to much decoding complexity for only a 20-30% size reduction, and most of the size comes from conditions and results. --- .../smithy/rulesengine/logic/bdd/Bdd.java | 89 ++++++----------- .../rulesengine/logic/bdd/BddBuilder.java | 96 ++++++++----------- .../rulesengine/logic/bdd/BddCompiler.java | 2 - .../logic/bdd/BddEquivalenceChecker.java | 5 +- .../rulesengine/logic/bdd/BddTrait.java | 75 ++++----------- .../META-INF/smithy/smithy.rules.smithy | 3 + .../rulesengine/logic/bdd/BddBuilderTest.java | 20 ++-- .../smithy/rulesengine/logic/bdd/BddTest.java | 38 ++------ .../rulesengine/logic/bdd/BddTraitTest.java | 19 ++-- .../bdd/bdd-invalid-node-data.errors | 2 +- .../bdd/bdd-invalid-root-reference.smithy | 2 +- .../bdd/bdd-node-count-mismatch.errors | 1 - .../bdd/bdd-node-count-mismatch.smithy | 51 ---------- .../traits/errorfiles/bdd/bdd-valid.smithy | 6 +- 14 files changed, 132 insertions(+), 277 deletions(-) delete mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.errors delete mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.smithy diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java index 27c52480528..fe2313a8e24 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/Bdd.java @@ -10,6 +10,7 @@ import java.io.UncheckedIOException; import java.io.Writer; import java.nio.charset.StandardCharsets; +import java.util.Arrays; import java.util.function.Consumer; import software.amazon.smithy.rulesengine.logic.ConditionEvaluator; @@ -36,9 +37,7 @@ public final class Bdd { */ public static final int RESULT_OFFSET = 100_000_000; - private final int[] variables; - private final int[] highs; - private final int[] lows; + private final int[] nodes; // Flat array: [var0, high0, low0, var1, high1, low1, ...] private final int rootRef; private final int conditionCount; private final int resultCount; @@ -64,27 +63,29 @@ public Bdd(int rootRef, int conditionCount, int resultCount, int nodeCount, Cons InputNodeConsumer consumer = new InputNodeConsumer(nodeCount); nodeHandler.accept(consumer); - this.variables = consumer.variables; - this.highs = consumer.highs; - this.lows = consumer.lows; + this.nodes = consumer.nodes; - if (consumer.index != nodeCount) { - throw new IllegalStateException("Expected " + nodeCount + " nodes, but got " + consumer.index); + if (consumer.index != nodeCount * 3) { + throw new IllegalStateException("Expected " + nodeCount + " nodes, but got " + (consumer.index / 3)); } } - Bdd(int[] variables, int[] highs, int[] lows, int nodeCount, int rootRef, int conditionCount, int resultCount) { - validateArrays(variables, highs, lows, nodeCount); + /** + * Package-private constructor for direct array initialization (used by BddTrait). + */ + Bdd(int rootRef, int conditionCount, int resultCount, int nodeCount, int[] nodes) { validateCounts(conditionCount, resultCount, nodeCount); validateRootReference(rootRef, nodeCount); - this.variables = variables; - this.highs = highs; - this.lows = lows; + if (nodes.length != nodeCount * 3) { + throw new IllegalArgumentException("Nodes array length must be nodeCount * 3"); + } + this.rootRef = rootRef; this.conditionCount = conditionCount; this.resultCount = resultCount; this.nodeCount = nodeCount; + this.nodes = nodes; } private static void validateCounts(int conditionCount, int resultCount, int nodeCount) { @@ -109,34 +110,19 @@ private static void validateRootReference(int rootRef, int nodeCount) { } } - private static void validateArrays(int[] variables, int[] highs, int[] lows, int nodeCount) { - if (variables.length != highs.length || variables.length != lows.length) { - throw new IllegalArgumentException("Array lengths must match: variables=" + variables.length + - ", highs=" + highs.length + ", lows=" + lows.length); - } else if (nodeCount > variables.length) { - throw new IllegalArgumentException("Node count (" + nodeCount + - ") exceeds array capacity (" + variables.length + ")"); - } - } - private static final class InputNodeConsumer implements BddNodeConsumer { private int index = 0; - private final int[] variables; - private final int[] highs; - private final int[] lows; + private final int[] nodes; private InputNodeConsumer(int nodeCount) { - this.variables = new int[nodeCount]; - this.highs = new int[nodeCount]; - this.lows = new int[nodeCount]; + this.nodes = new int[nodeCount * 3]; } @Override public void accept(int var, int high, int low) { - variables[index] = var; - highs[index] = high; - lows[index] = low; - index++; + nodes[index++] = var; + nodes[index++] = high; + nodes[index++] = low; } } @@ -184,7 +170,7 @@ public int getRootRef() { */ public int getVariable(int nodeIndex) { validateRange(nodeIndex); - return variables[nodeIndex]; + return nodes[nodeIndex * 3]; } private void validateRange(int index) { @@ -201,7 +187,7 @@ private void validateRange(int index) { */ public int getHigh(int nodeIndex) { validateRange(nodeIndex); - return highs[nodeIndex]; + return nodes[nodeIndex * 3 + 1]; } /** @@ -212,7 +198,7 @@ public int getHigh(int nodeIndex) { */ public int getLow(int nodeIndex) { validateRange(nodeIndex); - return lows[nodeIndex]; + return nodes[nodeIndex * 3 + 2]; } /** @@ -222,7 +208,8 @@ public int getLow(int nodeIndex) { */ public void getNodes(BddNodeConsumer consumer) { for (int i = 0; i < nodeCount; i++) { - consumer.accept(variables[i], highs[i], lows[i]); + int base = i * 3; + consumer.accept(nodes[base], nodes[base + 1], nodes[base + 2]); } } @@ -234,14 +221,13 @@ public void getNodes(BddNodeConsumer consumer) { */ public int evaluate(ConditionEvaluator ev) { int ref = rootRef; - int[] vars = this.variables; - int[] hi = this.highs; - int[] lo = this.lows; + int[] n = this.nodes; while (isNodeReference(ref)) { - int idx = ref > 0 ? ref - 1 : -ref - 1; // Math.abs + int idx = ref > 0 ? ref - 1 : -ref - 1; + int base = idx * 3; // test ^ complement, pick hi or lo - ref = (ev.test(vars[idx]) ^ (ref < 0)) ? hi[idx] : lo[idx]; + ref = (ev.test(n[base]) ^ (ref < 0)) ? n[base + 1] : n[base + 2]; } return isTerminal(ref) ? -1 : ref - RESULT_OFFSET; @@ -303,27 +289,12 @@ public boolean equals(Object obj) { return false; } - // Now check the views of arrays of each. - for (int i = 0; i < nodeCount; i++) { - if (variables[i] != other.variables[i] || highs[i] != other.highs[i] || lows[i] != other.lows[i]) { - return false; - } - } - - return true; + return Arrays.equals(nodes, other.nodes); } @Override public int hashCode() { - int hash = 31 * rootRef + nodeCount; - // Sample up to 16 nodes distributed across the BDD - int step = Math.max(1, nodeCount / 16); - for (int i = 0; i < nodeCount; i += step) { - hash = 31 * hash + variables[i]; - hash = 31 * hash + highs[i]; - hash = 31 * hash + lows[i]; - } - return hash; + return 31 * rootRef + nodeCount + Arrays.hashCode(nodes); } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java index 1c98a49e0c4..a3d4398d728 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilder.java @@ -29,10 +29,9 @@ final class BddBuilder { private static final int TRUE_REF = 1; private static final int FALSE_REF = -1; - // Node storage: three separate arrays - private int[] variables = new int[1024]; - private int[] highs = new int[1024]; - private int[] lows = new int[1024]; + // Node storage: flat array [var0, high0, low0, var1, high1, low1, ...] + private static final int INITIAL_SIZE = 256 * 3; + private int[] nodes = new int[INITIAL_SIZE]; private int nodeCount; // Unique tables for node deduplication and ITE caching @@ -48,7 +47,7 @@ final class BddBuilder { public BddBuilder() { this.nodeCount = 1; this.uniqueTable = new UniqueTable(); - this.iteCache = new UniqueTable(4096); // Larger initial capacity for ITE cache + this.iteCache = new UniqueTable(1024); initializeTerminalNode(); } @@ -156,9 +155,10 @@ private int insertNode(int var, int high, int low, boolean flip) { ensureCapacity(); int idx = nodeCount; - variables[idx] = var; - highs[idx] = high; - lows[idx] = low; + int base = idx * 3; + nodes[base] = var; + nodes[base + 1] = high; + nodes[base + 2] = low; nodeCount++; uniqueTable.put(var, high, low, idx); @@ -166,12 +166,10 @@ private int insertNode(int var, int high, int low, boolean flip) { } private void ensureCapacity() { - if (nodeCount >= variables.length) { + if (nodeCount * 3 >= nodes.length) { // Double the current capacity - int newCapacity = variables.length * 2; - variables = Arrays.copyOf(variables, newCapacity); - highs = Arrays.copyOf(highs, newCapacity); - lows = Arrays.copyOf(lows, newCapacity); + int newCapacity = nodes.length * 2; + nodes = Arrays.copyOf(nodes, newCapacity); } } @@ -246,7 +244,7 @@ public int getVariable(int ref) { int nodeIndex = Math.abs(ref) - 1; validateNodeIndex(nodeIndex); - return variables[nodeIndex]; + return nodes[nodeIndex * 3]; } /** @@ -267,11 +265,12 @@ public int cofactor(int bdd, int varIndex, boolean value) { int nodeIndex = toNodeIndex(bdd); validateNodeIndex(nodeIndex); - int nodeVar = variables[nodeIndex]; + int base = nodeIndex * 3; + int nodeVar = nodes[base]; if (nodeVar == varIndex) { // This node tests our variable, so take the appropriate branch - int child = value ? highs[nodeIndex] : lows[nodeIndex]; + int child = value ? nodes[base + 1] : nodes[base + 2]; // Only negate if child is not a result return (complemented && !isResult(child)) ? negate(child) : child; } else if (nodeVar > varIndex) { @@ -279,8 +278,8 @@ public int cofactor(int bdd, int varIndex, boolean value) { return bdd; } else { // Variable appears deeper, so recurse on both branches - int high = cofactor(highs[nodeIndex], varIndex, value); - int low = cofactor(lows[nodeIndex], varIndex, value); + int high = cofactor(nodes[base + 1], varIndex, value); + int low = cofactor(nodes[base + 2], varIndex, value); int result = makeNode(nodeVar, high, low); return (complemented && !isResult(result)) ? negate(result) : result; } @@ -358,7 +357,6 @@ public int ite(int f, int g, int h) { } } - // Check cache Integer cached = iteCache.get(f, g, h); if (cached != null) { return cached; @@ -372,10 +370,8 @@ public int ite(int f, int g, int h) { int r0 = ite(cofactor(f, v, false), cofactor(g, v, false), cofactor(h, v, false)); int r1 = ite(cofactor(f, v, true), cofactor(g, v, true), cofactor(h, v, true)); - // Build result node + // Build result node and cache it int result = makeNode(v, r1, r0); - - // Update cache with actual result iteCache.put(f, g, h, result); return result; } @@ -395,18 +391,16 @@ public int reduce(int rootRef) { boolean rootComp = rootRef < 0; int absRoot = rootComp ? negate(rootRef) : rootRef; - // Allocate new tables - int[] newVariables = new int[nodeCount]; - int[] newHighs = new int[nodeCount]; - int[] newLows = new int[nodeCount]; + // Allocate new nodes array + int[] newNodes = new int[nodeCount * 3]; // Clear and reuse the existing unique table uniqueTable.clear(); // Initialize the terminal node - newVariables[0] = -1; - newHighs[0] = TRUE_REF; - newLows[0] = FALSE_REF; + newNodes[0] = -1; + newNodes[1] = TRUE_REF; + newNodes[2] = FALSE_REF; // Prepare the visitation map int[] oldToNew = new int[nodeCount]; @@ -414,12 +408,10 @@ public int reduce(int rootRef) { int[] newCount = {1}; // start after terminal // Recursively rebuild - int newRoot = reduceRec(absRoot, oldToNew, newVariables, newHighs, newLows, newCount); + int newRoot = reduceRec(absRoot, oldToNew, newNodes, newCount); - // Swap in the new tables - this.variables = newVariables; - this.highs = newHighs; - this.lows = newLows; + // Swap in the new nodes array (trimmed to actual size) + this.nodes = Arrays.copyOf(newNodes, newCount[0] * 3); this.nodeCount = newCount[0]; clearCaches(); @@ -429,9 +421,7 @@ public int reduce(int rootRef) { private int reduceRec( int ref, int[] oldToNew, - int[] newVariables, - int[] newHighs, - int[] newLows, + int[] newNodes, int[] newCount ) { if (isLeaf(ref)) { @@ -450,9 +440,10 @@ private int reduceRec( } // Recurse on children - int var = variables[idx]; - int hiNew = reduceRec(highs[idx], oldToNew, newVariables, newHighs, newLows, newCount); - int loNew = reduceRec(lows[idx], oldToNew, newVariables, newHighs, newLows, newCount); + int base = idx * 3; + int var = nodes[base]; + int hiNew = reduceRec(nodes[base + 1], oldToNew, newNodes, newCount); + int loNew = reduceRec(nodes[base + 2], oldToNew, newNodes, newCount); // Reduction rule int resultAbs; @@ -472,9 +463,10 @@ private int reduceRec( resultAbs = toReference(existing); } else { int nodeIdx = newCount[0]++; - newVariables[nodeIdx] = var; - newHighs[nodeIdx] = hiNew; - newLows[nodeIdx] = loNew; + int newBase = nodeIdx * 3; + newNodes[newBase] = var; + newNodes[newBase + 1] = hiNew; + newNodes[newBase + 2] = loNew; uniqueTable.put(var, hiNew, loNew, nodeIdx); resultAbs = toReference(nodeIdx); } @@ -502,7 +494,7 @@ private int getTopVariable(int f, int g, int h) { private int updateMinVariable(int currentMin, int ref) { int absRef = Math.abs(ref); if (absRef > 1 && absRef < Bdd.RESULT_OFFSET) { - return Math.min(currentMin, variables[absRef - 1]); + return Math.min(currentMin, nodes[(absRef - 1) * 3]); } return currentMin; } @@ -522,9 +514,7 @@ public void clearCaches() { public BddBuilder reset() { clearCaches(); uniqueTable.clear(); - Arrays.fill(variables, 0, nodeCount, 0); - Arrays.fill(highs, 0, nodeCount, 0); - Arrays.fill(lows, 0, nodeCount, 0); + Arrays.fill(nodes, 0, nodeCount * 3, 0); nodeCount = 1; initializeTerminalNode(); conditionCount = -1; @@ -542,10 +532,8 @@ Bdd build(int rootRef, int resultCount) { throw new IllegalStateException("Condition count must be set before building BDD"); } - int[] v = Arrays.copyOf(variables, nodeCount); - int[] h = Arrays.copyOf(highs, nodeCount); - int[] l = Arrays.copyOf(lows, nodeCount); - return new Bdd(v, h, l, nodeCount, rootRef, conditionCount, resultCount); + int[] n = Arrays.copyOf(nodes, nodeCount * 3); + return new Bdd(rootRef, conditionCount, resultCount, nodeCount, n); } private void validateBooleanOperands(int f, int g, String operation) { @@ -569,8 +557,8 @@ private void validateNodeIndex(int nodeIndex) { } private void initializeTerminalNode() { - variables[0] = -1; - highs[0] = TRUE_REF; - lows[0] = FALSE_REF; + nodes[0] = -1; + nodes[1] = TRUE_REF; + nodes[2] = FALSE_REF; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java index f9d97377654..3d8741d6890 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java @@ -90,11 +90,9 @@ private int convertCfgToBdd(CfgNode cfgNode) { int result; if (cfgNode == null) { result = bddBuilder.makeResult(noMatchIndex); - } else if (cfgNode instanceof ResultNode) { Rule rule = ((ResultNode) cfgNode).getResult(); result = bddBuilder.makeResult(getOrCreateResultIndex(rule)); - } else { ConditionNode cn = (ConditionNode) cfgNode; ConditionReference ref = cn.getCondition(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java index 32bc8778aa3..b95d234b432 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java @@ -27,7 +27,7 @@ * *

    This verifier uses structural equivalence checking to ensure that both representations produce the same result. * When the BDD has fewer than 20 conditions, the checking is exhaustive. When there are more, random samples are - * checked up the earlier of max samples being reached, or the max duration being reached. + * checked up to the earlier of max samples being reached or the max duration being reached. */ public final class BddEquivalenceChecker { @@ -210,11 +210,12 @@ private void verifyCriticalCases() { verifyCase(allTrue ^ (1L << i)); } - // Alternating patterns + // Alternating patterns: 0101... (even conditions false, odd true) if (!hasEitherLimitBeenExceeded()) { verifyCase(0x5555555555555555L & ((1L << bdd.getConditionCount()) - 1)); } + // Pattern: 1010... (even conditions true, odd false) if (!hasEitherLimitBeenExceeded()) { verifyCase(0xAAAAAAAAAAAAAAAAL & ((1L << bdd.getConditionCount()) - 1)); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java index 71fceba527a..ca20c33a7ef 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java @@ -4,17 +4,18 @@ */ package software.amazon.smithy.rulesengine.logic.bdd; -import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; -import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.ArrayList; import java.util.Base64; import java.util.List; import java.util.Set; import java.util.function.Function; +import software.amazon.smithy.model.node.ArrayNode; import software.amazon.smithy.model.node.Node; import software.amazon.smithy.model.node.ObjectNode; import software.amazon.smithy.model.shapes.ShapeId; @@ -130,14 +131,14 @@ protected Node createNode() { ObjectNode.Builder builder = ObjectNode.builder(); builder.withMember("parameters", parameters.toNode()); - List conditionNodes = new ArrayList<>(); + ArrayNode.Builder conditionBuilder = ArrayNode.builder(); for (Condition c : conditions) { - conditionNodes.add(c.toNode()); + conditionBuilder.withValue(c); } - builder.withMember("conditions", Node.fromNodes(conditionNodes)); + builder.withMember("conditions", conditionBuilder.build()); // Results (skip NoMatchRule at index 0 for serialization) - List resultNodes = new ArrayList<>(); + ArrayNode.Builder resultBuilder = ArrayNode.builder(); if (!results.isEmpty() && !(results.get(0) instanceof NoMatchRule)) { throw new IllegalStateException("BDD must always have a NoMatchRule as the first result"); } @@ -146,9 +147,9 @@ protected Node createNode() { if (result instanceof NoMatchRule) { throw new IllegalStateException("NoMatch rules can only appear at rule index 0. Found at index " + i); } - resultNodes.add(result.toNode()); + resultBuilder.withValue(result); } - builder.withMember("results", Node.fromNodes(resultNodes)); + builder.withMember("results", resultBuilder.build()); builder.withMember("root", bdd.getRootRef()); builder.withMember("nodeCount", bdd.getNodeCount()); @@ -196,9 +197,9 @@ private static String encodeNodes(Bdd bdd) { DataOutputStream dos = new DataOutputStream(baos)) { bdd.getNodes((varIdx, high, low) -> { try { - writeVarInt(dos, varIdx); - writeVarInt(dos, high); - writeVarInt(dos, low); + dos.writeInt(varIdx); + dos.writeInt(high); + dos.writeInt(low); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -213,53 +214,19 @@ private static String encodeNodes(Bdd bdd) { } private static Bdd decodeBdd(String base64, int nodeCount, int rootRef, int conditionCount, int resultCount) { - // Special case for empty BDD with just terminal (should never happen, but just in case). - if (base64.isEmpty() || nodeCount == 0) { - return new Bdd(rootRef, conditionCount, resultCount, 1, consumer -> { - consumer.accept(-1, 1, -1); - }); - } - byte[] data = Base64.getDecoder().decode(base64); - return new Bdd(rootRef, conditionCount, resultCount, nodeCount, consumer -> { - try (ByteArrayInputStream bais = new ByteArrayInputStream(data); - DataInputStream dis = new DataInputStream(bais)) { - for (int i = 0; i < nodeCount; i++) { - consumer.accept(readVarInt(dis), readVarInt(dis), readVarInt(dis)); - } - if (bais.available() > 0) { - throw new IllegalArgumentException("Extra data found after decoding " + nodeCount + - " nodes. " + bais.available() + " bytes remaining."); - } - } catch (IOException e) { - throw new RuntimeException("Failed to decode BDD nodes", e); - } - }); - } - - // Zig-zag + varint encode of a signed int - private static void writeVarInt(DataOutputStream dos, int value) throws IOException { - int zz = (value << 1) ^ (value >> 31); - while ((zz & ~0x7F) != 0) { - dos.writeByte((zz & 0x7F) | 0x80); - zz >>>= 7; + if (data.length != nodeCount * 12) { + throw new IllegalArgumentException("Expected " + (nodeCount * 12) + " bytes for " + nodeCount + + " nodes, but got " + data.length); } - dos.writeByte(zz); - } - // Decode a signed int from varint + zig-zag - private static int readVarInt(DataInputStream dis) throws IOException { - int shift = 0, result = 0; - while (true) { - byte b = dis.readByte(); - result |= (b & 0x7F) << shift; - if ((b & 0x80) == 0) { - break; - } - shift += 7; + int[] nodes = new int[nodeCount * 3]; + ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.BIG_ENDIAN); + for (int i = 0; i < nodes.length; i++) { + nodes[i] = buffer.getInt(); } - // reverse zig-zag - return (result >>> 1) ^ -(result & 1); + + return new Bdd(rootRef, conditionCount, resultCount, nodeCount, nodes); } /** diff --git a/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy b/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy index 5926d473029..69fd058c0c1 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy +++ b/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy @@ -165,6 +165,9 @@ structure Result { /// Provided if type is "endpoint". endpoint: EndpointObject + + /// Conditions for the result (only used with decision tree rules). + conditions: Conditions } @private diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java index 57c9d1f4931..d0428d22194 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddBuilderTest.java @@ -202,7 +202,6 @@ void testResultInIte() { @Test void testSetConditionCountRequired() { - // Cannot create result without setting condition count assertThrows(IllegalStateException.class, () -> builder.makeResult(0)); } @@ -217,7 +216,6 @@ void testGetVariable() { assertEquals(1, builder.getVariable(node)); assertEquals(1, builder.getVariable(Math.abs(node))); // Use absolute value for complement - // Test result references int result = builder.makeResult(0); assertEquals(-1, builder.getVariable(result)); // results have no variable } @@ -233,7 +231,7 @@ void testReduceSimpleBdd() { int root = builder.makeNode(0, b, a); int nodesBefore = builder.getNodeCount(); - int reduced = builder.reduce(root); + builder.reduce(root); // Structure should be preserved if already optimal assertEquals(nodesBefore, builder.getNodeCount()); @@ -248,9 +246,8 @@ void testReduceNoChange() { int root = builder.makeNode(0, right, builder.makeFalse()); int nodesBefore = builder.getNodeCount(); - int reduced = builder.reduce(root); + builder.reduce(root); - // No change expected assertEquals(nodesBefore, builder.getNodeCount()); } @@ -279,7 +276,6 @@ void testReduceWithComplement() { int b = builder.makeNode(1, a, builder.negate(a)); int root = builder.makeNode(0, b, builder.makeFalse()); - // Reduce without complement first int reduced = builder.reduce(root); // Now test reducing with complemented root @@ -302,8 +298,7 @@ void testReduceClearsCache() { int b = builder.makeNode(1, builder.makeTrue(), builder.makeFalse()); int ite1 = builder.ite(a, b, builder.makeFalse()); - // Reduce - int reduced = builder.reduce(ite1); + builder.reduce(ite1); // Cache should be cleared, so same ITE creates new result // Recreate the nodes since reduce may have changed internal state @@ -323,7 +318,7 @@ void testReduceActuallyReduces() { int root = builder.makeNode(0, middle, bottom); int beforeSize = builder.getNodeCount(); - int reduced = builder.reduce(root); + builder.reduce(root); int afterSize = builder.getNodeCount(); // In this case, no reduction should occur since makeNode already optimized @@ -438,7 +433,6 @@ void testResultMaskNoCollisions() { int node2 = builder.makeNode(1, node1, builder.makeFalse()); int node3 = builder.makeNode(2, node2, node1); - // Create results int result0 = builder.makeResult(0); int result1 = builder.makeResult(1); @@ -463,13 +457,11 @@ void testReset() { builder.setConditionCount(2); // Create some state - int node = builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); - int result = builder.makeResult(0); + builder.makeNode(0, builder.makeTrue(), builder.makeFalse()); + builder.makeResult(0); - // Reset builder.reset(); - // Verify state is cleared assertEquals(1, builder.getNodeCount()); // Only terminal assertThrows(IllegalStateException.class, () -> builder.makeResult(0)); diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java index 96079a2c6cc..9f320f50635 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java @@ -196,47 +196,32 @@ void testStreamingConstructorValidation() { @Test void testArrayConstructorValidation() { - int[] vars = {-1}; - int[] highs = {1}; - int[] lows = {-1}; + int[] nodes = {-1, 1, -1}; // Valid construction assertDoesNotThrow(() -> { - new Bdd(vars, highs, lows, 1, 1, 1, 1); + new Bdd(1, 1, 1, 1, nodes); }); - // Null arrays - assertThrows(NullPointerException.class, () -> { - new Bdd(null, highs, lows, 1, 1, 1, 1); - }); - - assertThrows(NullPointerException.class, () -> { - new Bdd(vars, null, lows, 1, 1, 1, 1); - }); - - assertThrows(NullPointerException.class, () -> { - new Bdd(vars, highs, null, 1, 1, 1, 1); - }); - - // Mismatched array lengths - int[] shortArray = {}; + // Wrong array length (not multiple of 3) + int[] wrongLength = {-1, 1, -1, 0}; // 4 elements, not divisible by 3 assertThrows(IllegalArgumentException.class, () -> { - new Bdd(shortArray, highs, lows, 1, 1, 1, 1); + new Bdd(1, 1, 1, 1, wrongLength); }); - // Node count exceeds array capacity + // Array length doesn't match nodeCount assertThrows(IllegalArgumentException.class, () -> { - new Bdd(vars, highs, lows, 2, 1, 1, 1); // nodeCount=2 but arrays have length 1 + new Bdd(1, 1, 1, 2, nodes); // nodeCount=2 but array has 3 elements (1 node) }); // Root cannot be complemented (except -1) assertThrows(IllegalArgumentException.class, () -> { - new Bdd(vars, highs, lows, 1, -2, 1, 1); + new Bdd(-2, 1, 1, 1, nodes); }); // Root -1 (FALSE) is allowed assertDoesNotThrow(() -> { - new Bdd(vars, highs, lows, 1, -1, 1, 1); + new Bdd(-1, 1, 1, 1, nodes); }); } @@ -374,16 +359,13 @@ void testEquals() { // .build(); // // Cfg cfg = Cfg.from(ruleSet); - // Bdd bdd = Bdd.from(cfg); // - // BddTrait trait = BddTrait.builder().bdd(bdd).build(); + // BddTrait trait = BddTrait.from(cfg); // BddTraitValidator validator = new BddTraitValidator(); // ServiceShape service = ServiceShape.builder().id("foo#Bar").addTrait(trait).build(); // Model model = Model.builder().addShape(service).build(); // System.out.println(validator.validate(model)); // - // System.out.println(bdd); - // // // Get the base64 encoded nodes // System.out.println(Node.prettyPrintJson(trait.toNode())); // } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java index 1fbdc7c0ae7..fe004d1155d 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java @@ -70,19 +70,22 @@ void testBddTraitSerialization() { } private Bdd createSimpleBdd() { - return new Bdd(2, 1, 2, 2, consumer -> { - consumer.accept(-1, 1, -1); // node 0: terminal - consumer.accept(0, Bdd.RESULT_OFFSET + 1, -1); // node 1: if cond true, return result 1, else FALSE - }); + int[] nodes = new int[] { + -1, + 1, + -1, // node 0: terminal + 0, + Bdd.RESULT_OFFSET + 1, + -1 // node 1 + }; + return new Bdd(2, 1, 2, 2, nodes); } @Test void testEmptyBddTrait() { Parameters params = Parameters.builder().build(); - - Bdd bdd = new Bdd(-1, 0, 1, 1, consumer -> { - consumer.accept(-1, 1, -1); // terminal node only - }); + int[] nodes = new int[] {-1, 1, -1}; + Bdd bdd = new Bdd(-1, 0, 1, 1, nodes); BddTrait trait = BddTrait.builder() .parameters(params) diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors index a94a00b4490..13b553277b7 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors @@ -1 +1 @@ -[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#bdd`: Failed to decode BDD nodes | Model +[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#bdd`: Expected 36 bytes for 3 nodes, but got 2 | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy index a22a1ea2a45..5479643975c 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy @@ -8,7 +8,7 @@ use smithy.rules#bdd parameters: {} conditions: [] results: [] - nodes: "AAAA" // Base64 encoded empty node array + nodes: "" // Base64 encoded empty node array root: -5 // Invalid negative root reference (only -1 is allowed for FALSE) nodeCount: 0 }) diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.errors deleted file mode 100644 index a4999d204d9..00000000000 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.errors +++ /dev/null @@ -1 +0,0 @@ -[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#bdd`: Extra data found after decoding 2 nodes. 9 bytes remaining. | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.smithy deleted file mode 100644 index c22a4416804..00000000000 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-node-count-mismatch.smithy +++ /dev/null @@ -1,51 +0,0 @@ -$version: "2.0" - -namespace smithy.example - -use smithy.rules#bdd - -@bdd({ - parameters: { - Region: { - type: "string" - required: true - documentation: "The AWS region" - } - UseFips: { - type: "boolean" - required: true - default: false - documentation: "Use FIPS endpoints" - } - } - conditions: [ - { - fn: "isSet" - argv: [{ref: "Region"}] - } - { - fn: "booleanEquals" - argv: [{ref: "UseFips"}, true] - } - ] - results: [ - { - type: "endpoint" - endpoint: { - url: "https://service.{Region}.amazonaws.com" - } - } - { - type: "endpoint" - endpoint: { - url: "https://service-fips.{Region}.amazonaws.com" - } - } - ] - nodes: "AQIBAAQBAoKS9AGAkvQB" - nodeCount: 2 - root: 1 -}) -service ValidBddService { - version: "2022-01-01" -} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy index 19b79f90b41..43f221e8014 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy @@ -36,6 +36,7 @@ use smithy.rules#clientContextParams ], "results": [ { + "conditions": [], "endpoint": { "url": "https://service-fips.{Region}.amazonaws.com", "properties": {}, @@ -44,6 +45,7 @@ use smithy.rules#clientContextParams "type": "endpoint" }, { + "conditions": [], "endpoint": { "url": "https://service.{Region}.amazonaws.com", "properties": {}, @@ -53,8 +55,8 @@ use smithy.rules#clientContextParams } ], "root": 2, - "nodes": "AQIBAIKEr1+EhK9f", - "nodeCount": 2 + "nodeCount": 2, + "nodes": "/////wAAAAH/////AAAAAAX14QEF9eEC" }) service ValidBddService { version: "2022-01-01" From 79b1cb3ee9833487ad5ee7ee04d78448fafb6d47 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Thu, 7 Aug 2025 22:48:13 -0500 Subject: [PATCH 11/23] Add coalesce function --- .../rulesengine/language/CoreExtension.java | 2 + .../language/evaluation/RuleEvaluator.java | 6 + .../language/evaluation/type/AnyType.java | 3 + .../language/evaluation/type/BooleanType.java | 3 + .../language/evaluation/type/EmptyType.java | 3 + .../evaluation/type/EndpointType.java | 3 + .../language/evaluation/type/IntegerType.java | 3 + .../language/evaluation/type/StringType.java | 3 + .../language/evaluation/type/Type.java | 30 ++-- .../syntax/expressions/ExpressionVisitor.java | 18 +++ .../expressions/functions/Coalesce.java | 146 ++++++++++++++++++ 11 files changed, 205 insertions(+), 15 deletions(-) create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java index 240422aa4c4..ba7d5113c12 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java @@ -6,6 +6,7 @@ import java.util.List; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionDefinition; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; @@ -36,6 +37,7 @@ public List getLibraryFunctions() { IsSet.getDefinition(), IsValidHostLabel.getDefinition(), Not.getDefinition(), + Coalesce.getDefinition(), ParseUrl.getDefinition(), StringEquals.getDefinition(), Substring.getDefinition(), diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index e22635e6f53..08dfcffda68 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -183,6 +183,12 @@ public Value visitIsSet(Expression fn) { return Value.booleanValue(!fn.accept(this).isEmpty()); } + @Override + public Value visitCoalesce(Expression left, Expression right) { + Value leftValue = left.accept(this); + return leftValue.isEmpty() ? right.accept(this) : leftValue; + } + @Override public Value visitNot(Expression not) { return Value.booleanValue(!not.accept(this).expectBooleanValue().getValue()); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/AnyType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/AnyType.java index d23096dd3d9..799d15a3453 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/AnyType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/AnyType.java @@ -10,6 +10,9 @@ * The "any" type, which matches all other types. */ public final class AnyType extends AbstractType { + + static final AnyType INSTANCE = new AnyType(); + AnyType() {} @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java index 8c58f670224..68f11544649 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java @@ -8,6 +8,9 @@ * The "boolean" type. */ public final class BooleanType extends AbstractType { + + static final BooleanType INSTANCE = new BooleanType(); + BooleanType() {} @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EmptyType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EmptyType.java index 86ec7490935..7e9fec629de 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EmptyType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EmptyType.java @@ -10,6 +10,9 @@ * The "empty" type. */ public final class EmptyType extends AbstractType { + + static final EmptyType INSTANCE = new EmptyType(); + EmptyType() {} @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EndpointType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EndpointType.java index b10d57f667d..65a25d4ca97 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EndpointType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/EndpointType.java @@ -10,6 +10,9 @@ * The "endpoint" type, representing a valid client endpoint. */ public final class EndpointType extends AbstractType { + + static final EndpointType INSTANCE = new EndpointType(); + EndpointType() {} @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java index 33a71e2a58c..dc7834e2871 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java @@ -8,6 +8,9 @@ * The "integer" type. */ public final class IntegerType extends AbstractType { + + static final IntegerType INSTANCE = new IntegerType(); + IntegerType() {} @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java index 9738a474e4f..7765ce7a2e5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java @@ -8,6 +8,9 @@ * The "string" type. */ public final class StringType extends AbstractType { + + static final StringType INSTANCE = new StringType(); + StringType() {} @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java index 15738d40b83..e472f3f44a6 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java @@ -38,20 +38,20 @@ default Type provenTruthy() { } static Type fromParameterType(ParameterType parameterType) { - if (parameterType == ParameterType.STRING) { - return stringType(); + switch (parameterType) { + case STRING: + return stringType(); + case BOOLEAN: + return booleanType(); + case STRING_ARRAY: + return arrayType(stringType()); + default: + throw new IllegalArgumentException("Unexpected parameter type: " + parameterType); } - if (parameterType == ParameterType.BOOLEAN) { - return booleanType(); - } - if (parameterType == ParameterType.STRING_ARRAY) { - return arrayType(stringType()); - } - throw new IllegalArgumentException("Unexpected parameter type: " + parameterType); } static AnyType anyType() { - return new AnyType(); + return AnyType.INSTANCE; } static ArrayType arrayType(Type inner) { @@ -59,19 +59,19 @@ static ArrayType arrayType(Type inner) { } static BooleanType booleanType() { - return new BooleanType(); + return BooleanType.INSTANCE; } static EmptyType emptyType() { - return new EmptyType(); + return EmptyType.INSTANCE; } static EndpointType endpointType() { - return new EndpointType(); + return EndpointType.INSTANCE; } static IntegerType integerType() { - return new IntegerType(); + return IntegerType.INSTANCE; } static OptionalType optionalType(Type type) { @@ -83,7 +83,7 @@ static RecordType recordType(Map inner) { } static StringType stringType() { - return new StringType(); + return StringType.INSTANCE; } static TupleType tupleType(List members) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java index 8d9da440bf0..2086426a76e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java @@ -5,9 +5,11 @@ package software.amazon.smithy.rulesengine.language.syntax.expressions; import java.util.List; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionDefinition; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.utils.ListUtils; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -49,6 +51,17 @@ public interface ExpressionVisitor { */ R visitIsSet(Expression fn); + /** + * Visits a coalesce function. + * + * @param left the first value to check. + * @param right the second value to check. + * @return the value from the visitor. + */ + default R visitCoalesce(Expression left, Expression right) { + return visitLibraryFunction(Coalesce.getDefinition(), ListUtils.of(left, right)); + } + /** * Visits a not function. * @@ -107,6 +120,11 @@ public R visitIsSet(Expression fn) { return getDefault(); } + @Override + public R visitCoalesce(Expression left, Expression right) { + return getDefault(); + } + @Override public R visitNot(Expression not) { return getDefault(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java new file mode 100644 index 00000000000..1abefd9032b --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -0,0 +1,146 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; + +import java.util.Arrays; +import java.util.List; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.ToExpression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.ExpressionVisitor; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * A coalesce function that returns the first non-empty value, with type-safe fallback handling. + * At runtime, returns the left value unless it's EmptyValue, in which case returns the right value. + * + *

    Type checking rules: + *

      + *
    • {@code coalesce(T, T) => T} (same types)
    • + *
    • {@code coalesce(T, S) => S} (if T.isA(S), i.e., S is more general)
    • + *
    • {@code coalesce(T, S) => T} (if S.isA(T), i.e., T is more general)
    • + *
    • {@code coalesce(Optional, S) => common_type(T, S)} (unwraps optional)
    • + *
    • {@code coalesce(T, Optional) => common_type(T, S)} (unwraps optional)
    • + *
    • {@code coalesce(Optional, Optional) => Optional}
    • + *
    + * + *

    Supports chaining: + * {@code coalesce(opt1, coalesce(opt2, coalesce(opt3, default)))} + */ +@SmithyUnstableApi +public final class Coalesce extends LibraryFunction { + public static final String ID = "coalesce"; + private static final Definition DEFINITION = new Definition(); + + private Coalesce(FunctionNode functionNode) { + super(DEFINITION, functionNode); + } + + /** + * Gets the {@link FunctionDefinition} implementation. + * + * @return the function definition. + */ + public static Definition getDefinition() { + return DEFINITION; + } + + /** + * Creates a {@link Coalesce} function from the given expressions. + * + * @param arg1 the first expression, typically optional. + * @param arg2 the second expression, used as fallback. + * @return The resulting {@link Coalesce} function. + */ + public static Coalesce ofExpressions(ToExpression arg1, ToExpression arg2) { + return DEFINITION.createFunction(FunctionNode.ofExpressions(ID, arg1, arg2)); + } + + @Override + public R accept(ExpressionVisitor visitor) { + List args = getArguments(); + return visitor.visitCoalesce(args.get(0), args.get(1)); + } + + @Override + public Type typeCheck(Scope scope) { + List args = getArguments(); + + if (args.size() != 2) { + throw new IllegalArgumentException("Coalesce requires exactly 2 arguments, got " + args.size()); + } + + Type leftType = args.get(0).typeCheck(scope); + Type rightType = args.get(1).typeCheck(scope); + Type leftInner = getInnerType(leftType); + Type rightInner = getInnerType(rightType); + + // Determine result type using isA + Type resultType; + if (leftInner.equals(rightInner)) { + resultType = leftInner; + } else if (leftInner.isA(rightInner)) { + resultType = rightInner; // right is more general + } else if (rightInner.isA(leftInner)) { + resultType = leftInner; // left is more general + } else { + throw new IllegalArgumentException(String.format( + "Type mismatch in coalesce: %s and %s have no common type", leftType, rightType)); + } + + // Only return Optional if both sides can be empty + if (leftType instanceof OptionalType && rightType instanceof OptionalType) { + return Type.optionalType(resultType); + } + + return resultType; + } + + private static Type getInnerType(Type t) { + return (t instanceof OptionalType) ? ((OptionalType) t).inner() : t; + } + + /** + * A {@link FunctionDefinition} for the {@link Coalesce} function. + */ + public static final class Definition implements FunctionDefinition { + private Definition() {} + + @Override + public String getId() { + return ID; + } + + @Override + public List getArguments() { + return Arrays.asList(Type.anyType(), Type.anyType()); + } + + @Override + public Type getReturnType() { + return Type.anyType(); + } + + @Override + public Value evaluate(List arguments) { + // Specialized in the ExpressionVisitor, so this doesn't need an implementation. + return null; + } + + @Override + public Coalesce createFunction(FunctionNode functionNode) { + return new Coalesce(functionNode); + } + + @Override + public int getCostHeuristic() { + // Coalesce can short-circuit, so it's cheap + return 1; + } + } +} From d140cd04c9dc40099415ce6fd5fac48efcfb52b8 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Mon, 11 Aug 2025 17:00:46 -0500 Subject: [PATCH 12/23] Improve initial ordering and add coalesce Our previous initial ordering could result in pathalogical orderings if it decided to moving something very early from the CFG to very late. This is in fact what happened when I added a coalesce method: it moved an early discriminating condition to very late, which blew up the BDD from ~40K nodes to 5.1M. This taught me that we really shouldn't throw away the ordering found in the CFG, and instead should leverage it when determining the initial ordering since it inherently gates logic and keeps related conditions together. So now the initial ordering is based on the CFG ordering and also on cone analysis (basically how many downstream nodes a node affects). We now get an initial ordering ~3K nodes, and with the coalesce method, we can now sift S3 down to ~800 nodes instead of ~1000. The coalesce function is added here so that we can fold bind-then-test conditions into a single condition. The current endpoints type system has strict nullability requirements. So you can't do a substring test and pass that directly into something that expects a non-null value. You have to first do the nullable function, then assign that to a value, then the next condition is inherently guarded and only called if the assigned value is non-null (the assignment acts as an implicit guard). The coalsce function allows us to identify these patterns and inline the test into a single condition by defaulting null to the zero value of the return type (integer=0, string="", array=[]). We only coalesce when the comparison is not to literally the zero value. When coalesce was added, it uncovered the original brittle ordering, leading to the much improved ordering in this PR. --- .../aws/language/functions/AwsPartition.java | 5 - .../functions/IsVirtualHostableS3Bucket.java | 5 - .../aws/language/functions/ParseArn.java | 5 - .../language/evaluation/type/ArrayType.java | 10 + .../language/evaluation/type/BooleanType.java | 9 + .../language/evaluation/type/IntegerType.java | 9 + .../evaluation/type/OptionalType.java | 7 + .../language/evaluation/type/StringType.java | 9 + .../language/evaluation/type/Type.java | 15 + .../evaluation/value/BooleanValue.java | 10 +- .../language/evaluation/value/EmptyValue.java | 2 + .../language/evaluation/value/Value.java | 4 +- .../syntax/expressions/Expression.java | 22 -- .../expressions/functions/BooleanEquals.java | 5 - .../expressions/functions/Coalesce.java | 78 ++-- .../functions/FunctionDefinition.java | 12 - .../syntax/expressions/functions/GetAttr.java | 5 - .../syntax/expressions/functions/IsSet.java | 5 - .../functions/IsValidHostLabel.java | 5 - .../functions/LibraryFunction.java | 9 - .../syntax/expressions/functions/Not.java | 5 - .../expressions/functions/ParseUrl.java | 5 - .../expressions/functions/StringEquals.java | 5 - .../expressions/functions/Substring.java | 5 - .../expressions/functions/UriEncode.java | 5 - .../expressions/literal/RecordLiteral.java | 9 - .../expressions/literal/StringLiteral.java | 23 -- .../expressions/literal/TupleLiteral.java | 9 - .../language/syntax/rule/Rule.java | 18 +- .../rulesengine/logic/bdd/BddCompiler.java | 8 +- .../logic/bdd/BddEquivalenceChecker.java | 3 +- .../rulesengine/logic/bdd/BddTrait.java | 2 +- .../logic/bdd/CfgConeAnalysis.java | 160 +++++++++ .../logic/bdd/CfgGuidedOrdering.java | 148 ++++++++ .../logic/bdd/ConditionDependencyGraph.java | 105 ------ .../logic/bdd/DefaultOrderingStrategy.java | 103 ------ .../logic/bdd/OrderConstraints.java | 92 ----- ...ingStrategy.java => OrderingStrategy.java} | 16 +- .../logic/bdd/SiftingOptimization.java | 26 +- .../rulesengine/logic/cfg/CfgBuilder.java | 3 +- .../logic/cfg/CoalesceTransform.java | 285 +++++++++++++++ .../logic/cfg/ConditionDependencyGraph.java | 315 ++++++++++++++++ .../logic/cfg/ReferenceRewriter.java | 164 +++++++++ .../rulesengine/logic/cfg/SsaTransform.java | 211 ++--------- .../logic/cfg/VariableAnalysis.java | 245 +++++++++++++ .../syntax/functions/CoalesceTest.java | 136 +++++++ .../logic/bdd/BddCompilerTest.java | 89 +++-- .../logic/bdd/BddEquivalenceCheckerTest.java | 73 ++-- .../smithy/rulesengine/logic/bdd/BddTest.java | 6 +- .../logic/bdd/CfgConeAnalysisTest.java | 233 ++++++++++++ .../logic/bdd/CfgGuidedOrderingTest.java | 335 ++++++++++++++++++ .../bdd/DefaultOrderingStrategyTest.java | 200 ----------- .../logic/bdd/OrderConstraintsTest.java | 133 ------- .../smithy/rulesengine/logic/cfg/CfgTest.java | 3 +- .../ConditionDependencyGraphTest.java | 5 +- .../logic/cfg/ReferenceRewriterTest.java | 181 ++++++++++ ...iguatorTest.java => SsaTransformTest.java} | 6 +- .../logic/cfg/VariableAnalysisTest.java | 327 +++++++++++++++++ 58 files changed, 2856 insertions(+), 1072 deletions(-) create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysis.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrdering.java delete mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java delete mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java delete mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraints.java rename smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/{ConditionOrderingStrategy.java => OrderingStrategy.java} (65%) create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraph.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriter.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysisTest.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrderingTest.java delete mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java delete mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java rename smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/{bdd => cfg}/ConditionDependencyGraphTest.java (96%) create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java rename smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/{VariableDisambiguatorTest.java => SsaTransformTest.java} (97%) create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsPartition.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsPartition.java index 7f1178a8d70..8a91ee62b98 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsPartition.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/AwsPartition.java @@ -184,11 +184,6 @@ public Value evaluate(List arguments) { public AwsPartition createFunction(FunctionNode functionNode) { return new AwsPartition(functionNode); } - - @Override - public int getCostHeuristic() { - return 6; - } } /** diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/IsVirtualHostableS3Bucket.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/IsVirtualHostableS3Bucket.java index 26e442bfbcd..d71b285d9d5 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/IsVirtualHostableS3Bucket.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/IsVirtualHostableS3Bucket.java @@ -97,11 +97,6 @@ public Value evaluate(List arguments) { public IsVirtualHostableS3Bucket createFunction(FunctionNode functionNode) { return new IsVirtualHostableS3Bucket(functionNode); } - - @Override - public int getCostHeuristic() { - return 8; - } } /** diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/ParseArn.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/ParseArn.java index e87ea48869e..1a4fa8c5283 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/ParseArn.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/ParseArn.java @@ -125,10 +125,5 @@ public Value evaluate(List arguments) { public ParseArn createFunction(FunctionNode functionNode) { return new ParseArn(functionNode); } - - @Override - public int getCostHeuristic() { - return 9; - } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/ArrayType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/ArrayType.java index ba20c757ca9..faeadf6e358 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/ArrayType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/ArrayType.java @@ -4,12 +4,17 @@ */ package software.amazon.smithy.rulesengine.language.evaluation.type; +import java.util.Collections; import java.util.Objects; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; /** * The "array" type, which contains entries of a member type. */ public final class ArrayType extends AbstractType { + + private static final Optional ZERO = Optional.of(Literal.tupleLiteral(Collections.emptyList())); private final Type member; ArrayType(Type member) { @@ -51,4 +56,9 @@ public int hashCode() { public String toString() { return String.format("ArrayType[%s]", member); } + + @Override + public Optional getZeroValue() { + return ZERO; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java index 68f11544649..af0cfd62543 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/BooleanType.java @@ -4,11 +4,15 @@ */ package software.amazon.smithy.rulesengine.language.evaluation.type; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + /** * The "boolean" type. */ public final class BooleanType extends AbstractType { + private static final Optional ZERO = Optional.of(Literal.of(false)); static final BooleanType INSTANCE = new BooleanType(); BooleanType() {} @@ -17,4 +21,9 @@ public final class BooleanType extends AbstractType { public BooleanType expectBooleanType() { return this; } + + @Override + public Optional getZeroValue() { + return ZERO; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java index dc7834e2871..713e047cac3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/IntegerType.java @@ -4,11 +4,15 @@ */ package software.amazon.smithy.rulesengine.language.evaluation.type; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + /** * The "integer" type. */ public final class IntegerType extends AbstractType { + private static final Optional ZERO = Optional.of(Literal.of(0)); static final IntegerType INSTANCE = new IntegerType(); IntegerType() {} @@ -17,4 +21,9 @@ public final class IntegerType extends AbstractType { public IntegerType expectIntegerType() { return this; } + + @Override + public Optional getZeroValue() { + return ZERO; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/OptionalType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/OptionalType.java index d96389ce3b4..fa6150a16c9 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/OptionalType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/OptionalType.java @@ -5,7 +5,9 @@ package software.amazon.smithy.rulesengine.language.evaluation.type; import java.util.Objects; +import java.util.Optional; import software.amazon.smithy.rulesengine.language.error.InnerParseError; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; /** * The "optional" type, a container for a type that may or may not be present. @@ -78,4 +80,9 @@ public int hashCode() { public String toString() { return String.format("OptionalType[%s]", inner); } + + @Override + public Optional getZeroValue() { + return inner.getZeroValue(); + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java index 7765ce7a2e5..1546d8113ca 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/StringType.java @@ -4,11 +4,15 @@ */ package software.amazon.smithy.rulesengine.language.evaluation.type; +import java.util.Optional; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + /** * The "string" type. */ public final class StringType extends AbstractType { + private static final Optional ZERO = Optional.of(Literal.of("")); static final StringType INSTANCE = new StringType(); StringType() {} @@ -17,4 +21,9 @@ public final class StringType extends AbstractType { public StringType expectStringType() { return this; } + + @Override + public Optional getZeroValue() { + return ZERO; + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java index e472f3f44a6..1313819497e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/type/Type.java @@ -6,8 +6,10 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import software.amazon.smithy.rulesengine.language.error.InnerParseError; import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -129,4 +131,17 @@ default StringType expectStringType() throws InnerParseError { default TupleType expectTupleType() throws InnerParseError { throw new InnerParseError("Expected tuple but found " + this); } + + /** + * Gets the default zero-value of the type as a Literal. + * + *

    Strings, booleans, integers, and arrays have zero values. Other types do not. E.g., a map might have + * required properties, and the behavior of a tuple _seems_ to imply that each member is required. Optionals + * return the zero value of its inner type. + * + * @return the default zero value. + */ + default Optional getZeroValue() { + return Optional.empty(); + } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java index 63a632c0d6c..0d9281a2d53 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/BooleanValue.java @@ -14,9 +14,17 @@ * A boolean value of true or false. */ public final class BooleanValue extends Value { + + static final BooleanValue TRUE = new BooleanValue(true); + static final BooleanValue FALSE = new BooleanValue(false); + private final boolean value; - BooleanValue(boolean value) { + static BooleanValue create(boolean v) { + return v ? TRUE : FALSE; + } + + private BooleanValue(boolean value) { super(SourceLocation.none()); this.value = value; } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java index 2409d9211d5..7b08d2458ec 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/EmptyValue.java @@ -12,6 +12,8 @@ * An empty value. */ public final class EmptyValue extends Value { + static final EmptyValue INSTANCE = new EmptyValue(); + public EmptyValue() { super(SourceLocation.none()); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java index b9caeb0da1e..7abe193a2d5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/value/Value.java @@ -112,7 +112,7 @@ public static ArrayValue arrayValue(List value) { * @return returns the created BooleanValue. */ public static BooleanValue booleanValue(boolean value) { - return new BooleanValue(value); + return BooleanValue.create(value); } /** @@ -121,7 +121,7 @@ public static BooleanValue booleanValue(boolean value) { * @return returns the created EmptyValue. */ public static EmptyValue emptyValue() { - return new EmptyValue(); + return EmptyValue.INSTANCE; } /** diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java index 7f8398dfa87..da1e29bcf31 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/Expression.java @@ -36,7 +36,6 @@ @SmithyUnstableApi public abstract class Expression extends SyntaxElement implements FromSourceLocation, ToNode, TypeCheck { private final SourceLocation sourceLocation; - private Integer cachedComplexity; private Set cachedReferences; private Type cachedType; @@ -165,27 +164,6 @@ protected Set calculateReferences() { return Collections.emptySet(); } - /** - * Get the complexity heuristic of the expression, based on functions, references, etc. - * - * @return the complexity heuristic. - */ - public final int getComplexity() { - if (cachedComplexity == null) { - cachedComplexity = calculateComplexity(); - } - return cachedComplexity; - } - - /** - * Calculates the complexity of the expression, to be overridden by implementations. - * - * @return complexity estimate. - */ - protected int calculateComplexity() { - return 1; - } - /** * Invoke the {@link ExpressionVisitor} functions for this expression. * diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java index 0eeaa8dbec8..066cf1cd966 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/BooleanEquals.java @@ -102,10 +102,5 @@ public Value evaluate(List arguments) { public BooleanEquals createFunction(FunctionNode functionNode) { return new BooleanEquals(functionNode); } - - @Override - public int getCostHeuristic() { - return 2; - } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java index 1abefd9032b..3e8061dcb44 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -7,6 +7,7 @@ import java.util.Arrays; import java.util.List; import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.AnyType; import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; @@ -22,6 +23,8 @@ *

    Type checking rules: *

      *
    • {@code coalesce(T, T) => T} (same types)
    • + *
    • {@code coalesce(T, AnyType) => T} (AnyType adapts to concrete type)
    • + *
    • {@code coalesce(AnyType, T) => T} (AnyType adapts to concrete type)
    • *
    • {@code coalesce(T, S) => S} (if T.isA(S), i.e., S is more general)
    • *
    • {@code coalesce(T, S) => T} (if S.isA(T), i.e., T is more general)
    • *
    • {@code coalesce(Optional, S) => common_type(T, S)} (unwraps optional)
    • @@ -29,6 +32,9 @@ *
    • {@code coalesce(Optional, Optional) => Optional}
    • *
    * + *

    Special handling for AnyType: Since AnyType can masquerade as any type, when coalescing + * with a concrete type, the concrete type is used as the result type. + * *

    Supports chaining: * {@code coalesce(opt1, coalesce(opt2, coalesce(opt3, default)))} */ @@ -67,6 +73,29 @@ public R accept(ExpressionVisitor visitor) { return visitor.visitCoalesce(args.get(0), args.get(1)); } + // Type checking rules for coalesce: + // + // This function returns the first non-empty value with type-safe fallback handling. + // The type resolution follows these rules: + // + // 1. If both types are identical, use that type + // 2. Special handling for AnyType: Since AnyType.isA() always returns true (it can masquerade as any type), we + // need to handle it specially. When coalescing AnyType with a concrete type, we use the concrete type as the + // result, since AnyType can adapt to it at runtime. + // 3. For other types, we use the isA relationship to find the more general type: + // - If left.isA(right), then right is more general, use right + // - If right.isA(left), then left is more general, use left + // 4. If no type relationship exists, throw a type mismatch error + // + // The result is wrapped in Optional only if BOTH inputs are Optional, since coalesce(optional, required) + // guarantees a non-empty result. + // + // Examples: + // - coalesce(String, String) => String + // - coalesce(Optional, String) => String + // - coalesce(Optional, Optional) => Optional + // - coalesce(String, AnyType) => String (AnyType adapts) + // - coalesce(SubType, SuperType) => SuperType (more general) @Override public Type typeCheck(Scope scope) { List args = getArguments(); @@ -77,21 +106,9 @@ public Type typeCheck(Scope scope) { Type leftType = args.get(0).typeCheck(scope); Type rightType = args.get(1).typeCheck(scope); - Type leftInner = getInnerType(leftType); - Type rightInner = getInnerType(rightType); - - // Determine result type using isA - Type resultType; - if (leftInner.equals(rightInner)) { - resultType = leftInner; - } else if (leftInner.isA(rightInner)) { - resultType = rightInner; // right is more general - } else if (rightInner.isA(leftInner)) { - resultType = leftInner; // left is more general - } else { - throw new IllegalArgumentException(String.format( - "Type mismatch in coalesce: %s and %s have no common type", leftType, rightType)); - } + + // Find the least upper bound (most specific common type) + Type resultType = lubForCoalesce(leftType, rightType); // Only return Optional if both sides can be empty if (leftType instanceof OptionalType && rightType instanceof OptionalType) { @@ -101,6 +118,28 @@ public Type typeCheck(Scope scope) { return resultType; } + // Finds the least upper bound (LUB) for coalesce type checking. + // The LUB is the most specific type that both input types can be assigned to. + // Special handling for AnyType: it adapts to concrete types rather than dominating them. + private static Type lubForCoalesce(Type a, Type b) { + Type ai = getInnerType(a); + Type bi = getInnerType(b); + + if (ai.equals(bi)) { + return ai; + } else if (ai instanceof AnyType) { + return bi; // AnyType adapts to concrete type + } else if (bi instanceof AnyType) { + return ai; // AnyType adapts to concrete type + } else if (ai.isA(bi)) { + return bi; // bi is more general + } else if (bi.isA(ai)) { + return ai; // ai is more general + } + + throw new IllegalArgumentException("Type mismatch in coalesce: " + a + " and " + b + " have no common type"); + } + private static Type getInnerType(Type t) { return (t instanceof OptionalType) ? ((OptionalType) t).inner() : t; } @@ -128,19 +167,12 @@ public Type getReturnType() { @Override public Value evaluate(List arguments) { - // Specialized in the ExpressionVisitor, so this doesn't need an implementation. - return null; + throw new UnsupportedOperationException("Coalesce evaluation is handled by ExpressionVisitor"); } @Override public Coalesce createFunction(FunctionNode functionNode) { return new Coalesce(functionNode); } - - @Override - public int getCostHeuristic() { - // Coalesce can short-circuit, so it's cheap - return 1; - } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java index 1573db78fa6..e8878c85ed6 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java @@ -45,16 +45,4 @@ public interface FunctionDefinition { * @return the created LibraryFunction implementation. */ LibraryFunction createFunction(FunctionNode functionNode); - - /** - * Get the relative "cost" of the function as compared to the baseline of "isset" which equals 1. - * - *

    If this function is considered more computationally expensive, then it has a value higher than 1. Otherwise, - * it has a value equal to 1. Defaults to "4" for unknown functions. - * - * @return the relative cost. - */ - default int getCostHeuristic() { - return 4; - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/GetAttr.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/GetAttr.java index 82049b7c9cd..8e719afb25c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/GetAttr.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/GetAttr.java @@ -238,11 +238,6 @@ public Value evaluate(List arguments) { public GetAttr createFunction(FunctionNode functionNode) { return new GetAttr(functionNode); } - - @Override - public int getCostHeuristic() { - return 7; - } } public interface Part { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsSet.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsSet.java index 9d55683bbcb..ee867a5cec6 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsSet.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsSet.java @@ -99,10 +99,5 @@ public Value evaluate(List arguments) { public IsSet createFunction(FunctionNode functionNode) { return new IsSet(functionNode); } - - @Override - public int getCostHeuristic() { - return 1; - } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsValidHostLabel.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsValidHostLabel.java index df45c277b97..4043f2da9aa 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsValidHostLabel.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/IsValidHostLabel.java @@ -93,11 +93,6 @@ public Value evaluate(List arguments) { public IsValidHostLabel createFunction(FunctionNode functionNode) { return new IsValidHostLabel(functionNode); } - - @Override - public int getCostHeuristic() { - return 8; - } } /** diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java index 3c9bcfddbe5..3f32dfd6252 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java @@ -245,13 +245,4 @@ private static boolean isReference(Expression arg) { } return false; } - - @Override - protected int calculateComplexity() { - int complexity = getFunctionDefinition().getCostHeuristic(); - for (Expression arg : getArguments()) { - complexity += arg.getComplexity(); - } - return complexity; - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Not.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Not.java index 473e7e9d118..9b6330a7224 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Not.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Not.java @@ -87,10 +87,5 @@ public Value evaluate(List arguments) { public Not createFunction(FunctionNode functionNode) { return new Not(functionNode); } - - @Override - public int getCostHeuristic() { - return 2; - } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/ParseUrl.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/ParseUrl.java index e5a0bdd2a12..da9f3ca37f3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/ParseUrl.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/ParseUrl.java @@ -119,11 +119,6 @@ public Value evaluate(List arguments) { public ParseUrl createFunction(FunctionNode functionNode) { return new ParseUrl(functionNode); } - - @Override - public int getCostHeuristic() { - return 10; - } } /** diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java index 63645232cf6..8892e3ca32f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/StringEquals.java @@ -111,10 +111,5 @@ public Value evaluate(List arguments) { public StringEquals createFunction(FunctionNode functionNode) { return new StringEquals(functionNode); } - - @Override - public int getCostHeuristic() { - return 3; - } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java index 391daf0a7b5..70d5d41194c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Substring.java @@ -105,11 +105,6 @@ public Value evaluate(List arguments) { public Substring createFunction(FunctionNode functionNode) { return new Substring(functionNode); } - - @Override - public int getCostHeuristic() { - return 5; - } } /** diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/UriEncode.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/UriEncode.java index eca33d9de5c..b4815364538 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/UriEncode.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/UriEncode.java @@ -99,10 +99,5 @@ public Value evaluate(List arguments) { public UriEncode createFunction(FunctionNode functionNode) { return new UriEncode(functionNode); } - - @Override - public int getCostHeuristic() { - return 8; - } } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java index 92ecd99baf8..01454ef1724 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/RecordLiteral.java @@ -83,13 +83,4 @@ protected Set calculateReferences() { } return references; } - - @Override - protected int calculateComplexity() { - int complexity = 1; - for (Literal value : members().values()) { - complexity += value.getComplexity(); - } - return complexity; - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java index f8597c51387..9e9326f34af 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/StringLiteral.java @@ -86,27 +86,4 @@ protected Set calculateReferences() { return references; } - - @Override - protected int calculateComplexity() { - Template template = value(); - if (template.isStatic()) { - return 1; - } - - int complexity = 1; - if (template.getParts().size() > 1) { - // Multiple parts are expensive - complexity += 8; - } - - for (Template.Part part : template.getParts()) { - if (part instanceof Template.Dynamic) { - Template.Dynamic dynamic = (Template.Dynamic) part; - complexity += dynamic.toExpression().getComplexity(); - } - } - - return complexity; - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java index 9c269228559..71672ce554b 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/literal/TupleLiteral.java @@ -89,13 +89,4 @@ protected Set calculateReferences() { } return references; } - - @Override - protected int calculateComplexity() { - int complexity = 1; - for (Literal member : members()) { - complexity += member.getComplexity(); - } - return complexity; - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java index 6ef6943d306..33b1efa5090 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/Rule.java @@ -135,20 +135,24 @@ public Optional getDocumentation() { public abstract T accept(RuleValueVisitor visitor); /** - * Get a new Rule of the same type that has the same values, but no conditions. + * Get a new Rule of the same type that has the same values, but with the given conditions. * - * @return the rule without conditions. + * @param conditions Conditions to use. + * @return the rule with the given conditions. * @throws UnsupportedOperationException if it is a TreeRule or Condition rule. */ - public Rule withoutConditions() { - if (getConditions().isEmpty()) { + public final Rule withConditions(List conditions) { + if (getConditions().equals(conditions)) { return this; } else if (this instanceof ErrorRule) { - return new ErrorRule(ErrorRule.builder(this), ((ErrorRule) this).getError()); + return new ErrorRule(ErrorRule.builder(this).conditions(conditions), ((ErrorRule) this).getError()); } else if (this instanceof EndpointRule) { - return new EndpointRule(EndpointRule.builder(this), ((EndpointRule) this).getEndpoint()); + return new EndpointRule(EndpointRule.builder(this).conditions(conditions), + ((EndpointRule) this).getEndpoint()); + } else if (this instanceof TreeRule) { + return new TreeRule(TreeRule.builder(this).conditions(conditions), ((TreeRule) this).getRules()); } else { - throw new UnsupportedOperationException("Cannot remove conditions from " + this); + throw new UnsupportedOperationException("Unknown rule type: " + this.getClass()); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java index 3d8741d6890..ee944970289 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompiler.java @@ -28,7 +28,7 @@ final class BddCompiler { private final Cfg cfg; private final BddBuilder bddBuilder; - private final ConditionOrderingStrategy orderingStrategy; + private final OrderingStrategy orderingStrategy; // Condition ordering private List orderedConditions; @@ -43,7 +43,11 @@ final class BddCompiler { // Simple cache to avoid recomputing identical subgraphs private final Map nodeCache = new HashMap<>(); - BddCompiler(Cfg cfg, ConditionOrderingStrategy orderingStrategy, BddBuilder bddBuilder) { + BddCompiler(Cfg cfg, BddBuilder bddBuilder) { + this(cfg, OrderingStrategy.initialOrdering(cfg), bddBuilder); + } + + BddCompiler(Cfg cfg, OrderingStrategy orderingStrategy, BddBuilder bddBuilder) { this.cfg = Objects.requireNonNull(cfg, "CFG cannot be null"); this.orderingStrategy = Objects.requireNonNull(orderingStrategy, "Ordering strategy cannot be null"); this.bddBuilder = Objects.requireNonNull(bddBuilder, "BDD builder cannot be null"); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java index b95d234b432..e59e9a47573 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java @@ -6,6 +6,7 @@ import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -329,7 +330,7 @@ private boolean resultsEqual(Rule r1, Rule r2) { } else if (r1 == null || r2 == null) { return false; } else { - return r1.withoutConditions().equals(r2.withoutConditions()); + return r1.withConditions(Collections.emptyList()).equals(r2.withConditions(Collections.emptyList())); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java index ca20c33a7ef..feecc4d38fe 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java @@ -65,7 +65,7 @@ private BddTrait(Builder builder) { * @return the BddTrait containing the compiled BDD and all context */ public static BddTrait from(Cfg cfg) { - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); if (compiler.getOrderedConditions().size() != bdd.getConditionCount()) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysis.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysis.java new file mode 100644 index 00000000000..f7645ac2b3d --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysis.java @@ -0,0 +1,160 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.CfgNode; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionNode; +import software.amazon.smithy.rulesengine.logic.cfg.ResultNode; + +/** + * Analyzes a CFG to compute cone information for each condition. + * + *

    A "cone" is the subgraph of the CFG that is reachable from a given condition node. Think of it as the + * "downstream impact" of a condition. A condition with a large cone controls many downstream decisions (high impact). + * A condition with many reachable result nodes in its cone affects many endpoints. Conditions that appear early in the + * CFG (low dominator depth) are "gates" that control access to large portions of the decision tree. + */ +final class CfgConeAnalysis { + /** + * Dominator depth for each condition, or how many edges from CFG root to first occurrence. + * Initialized to MAX_VALUE, then updated to minimum depth encountered during traversal. + */ + private final int[] dominatorDepth; + + /** Number of result nodes (endpoints/errors) reachable from each condition's cone. */ + private final int[] reachableResults; + + /** Cache of computed cone information for each CFG node to avoid redundant traversals. */ + private final Map coneCache = new HashMap<>(); + + /** Maps conditions to their indices for quick lookups. */ + private final Map conditionToIndex; + + /** + * Creates a new cone analysis for the given CFG and conditions. + * + * @param cfg the control flow graph to analyze + * @param conditions array of conditions in the rule set + * @param conditionToIndex mapping from conditions to their indices + */ + public CfgConeAnalysis(Cfg cfg, Condition[] conditions, Map conditionToIndex) { + this.conditionToIndex = conditionToIndex; + int n = conditions.length; + this.dominatorDepth = new int[n]; + this.reachableResults = new int[n]; + Arrays.fill(dominatorDepth, Integer.MAX_VALUE); + analyzeCfgNode(cfg.getRoot(), 0); + } + + /** + * Recursively analyzes a CFG node and its subtree, computing cone information. + * + * @param node the current CFG node being analyzed + * @param depth the current depth in the CFG traversal (edges from root) + * @return cone information for this subtree + */ + private ConeInfo analyzeCfgNode(CfgNode node, int depth) { + if (node == null) { + return ConeInfo.empty(); + } + + ConeInfo cached = coneCache.get(node); + if (cached != null) { + if (cached.inProgress) { + throw new IllegalStateException("Cycle detected in CFG during cone analysis: " + node); + } + return cached; + } + + // Cycle guard: if a transform accidentally introduced a cycle, fail fast. + coneCache.put(node, ConeInfo.IN_PROGRESS); + ConeInfo info; + + if (node instanceof ResultNode) { + info = ConeInfo.singleResult(); + } else if (node instanceof ConditionNode) { + ConditionNode condNode = (ConditionNode) node; + Condition condition = condNode.getCondition().getCondition(); + Integer conditionIdx = conditionToIndex.get(condition); + if (conditionIdx == null) { + throw new IllegalStateException("Condition not indexed in CFG: " + condition); + } + + // Handle conditions that appear multiple times by updating dominator depth. + // Keep the minimum depth where this condition appears. + dominatorDepth[conditionIdx] = Math.min(dominatorDepth[conditionIdx], depth); + + ConeInfo trueBranchCone = analyzeCfgNode(condNode.getTrueBranch(), depth + 1); + ConeInfo falseBranchCone = analyzeCfgNode(condNode.getFalseBranch(), depth + 1); + info = ConeInfo.combine(trueBranchCone, falseBranchCone); + + // Update the maximum result count this condition can influence + reachableResults[conditionIdx] = Math.max(reachableResults[conditionIdx], info.resultNodes); + } else { + throw new UnsupportedOperationException("Unknown node type: " + node); + } + + coneCache.put(node, info); + return info; + } + + /** + * Gets the dominator depth of a condition, the minimum depth at which this condition appears in the CFG. + * + *

    Lower values indicate conditions that appear earlier in the decision tree and have more influence over + * the overall control flow. + * + * @param conditionIdx the index of the condition + * @return the minimum depth, or Integer.MAX_VALUE if never encountered + */ + public int dominatorDepth(int conditionIdx) { + return dominatorDepth[conditionIdx]; + } + + /** + * Gets the cone size as the number of reachable result nodes for a condition, representing how many different + * endpoints/errors can be reached downstream of the condition. + * + *

    Larger values indicate conditions that have broader impact on the final outcome. + * + * @param conditionIdx the index of the condition + * @return the number of result nodes in this condition's cone + */ + public int coneSize(int conditionIdx) { + return reachableResults[conditionIdx]; + } + + private static final class ConeInfo { + private static final ConeInfo IN_PROGRESS = new ConeInfo(0, true); + + final int resultNodes; + final boolean inProgress; + + private ConeInfo(int resultNodes, boolean inProgress) { + this.resultNodes = resultNodes; + this.inProgress = inProgress; + } + + private static ConeInfo empty() { + return new ConeInfo(0, false); + } + + private static ConeInfo singleResult() { + return new ConeInfo(1, false); + } + + private static ConeInfo combine(ConeInfo trueBranch, ConeInfo falseBranch) { + if (trueBranch.inProgress || falseBranch.inProgress) { + throw new IllegalStateException("Cycle detected in CFG during cone analysis (branch in-progress)"); + } + return new ConeInfo(trueBranch.resultNodes + falseBranch.resultNodes, false); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrdering.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrdering.java new file mode 100644 index 00000000000..009a3b724dc --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrdering.java @@ -0,0 +1,148 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.BitSet; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionDependencyGraph; + +/** + * Orders conditions by following the natural structure of the CFG. + * + *

    This strategy has proven to be the most effective for BDD construction because it preserves the locality that + * rule authors built into their decision trees. Conditions that are evaluated together in the original rules stay + * together in the BDD, enabling better node sharing. This ordering implementation flattens the tree structure while + * respecting data dependencies. + */ +final class CfgGuidedOrdering implements OrderingStrategy { + private static final Logger LOGGER = Logger.getLogger(CfgGuidedOrdering.class.getName()); + + /** How many distinct consumers make an isSet() a "gate". */ + private static final int GATE_SUCCESSOR_THRESHOLD = 2; + + private final Cfg cfg; + + CfgGuidedOrdering(Cfg cfg) { + this.cfg = cfg; + } + + @Override + public List orderConditions(Condition[] conditions) { + long startTime = System.currentTimeMillis(); + + ConditionDependencyGraph deps = new ConditionDependencyGraph(Arrays.asList(conditions)); + Map conditionToIndex = deps.getConditionToIndex(); + CfgConeAnalysis cones = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + List order = buildCfgOrder(conditions, deps, cones); + + List result = new ArrayList<>(); + for (int id : order) { + result.add(conditions[id]); + } + + long elapsed = System.currentTimeMillis() - startTime; + LOGGER.info(() -> String.format("Initial ordering: %d conditions in %dms", conditions.length, elapsed)); + return result; + } + + // Builds an ordering using a topological sort that prefers conditions based on their position in the CFG. + private List buildCfgOrder(Condition[] conditions, ConditionDependencyGraph deps, CfgConeAnalysis cones) { + List result = new ArrayList<>(); + BitSet placed = new BitSet(); + BitSet ready = new BitSet(); + + // Start with conditions that have no dependencies + for (int i = 0; i < conditions.length; i++) { + if (deps.getPredecessorCount(i) == 0) { + ready.set(i); + } + } + + while (!ready.isEmpty()) { + int chosen = getNext(ready, cones, conditions, deps); + result.add(chosen); + placed.set(chosen); + ready.clear(chosen); + + // Make successors ready if all their dependencies are satisfied + for (int succ : deps.getSuccessors(chosen)) { + if (!placed.get(succ) && allPredecessorsPlaced(succ, deps, placed)) { + ready.set(succ); + } + } + } + + if (result.size() != conditions.length) { + throw new IllegalStateException("Topological ordering incomplete (possible cyclic deps). Placed=" + + result.size() + " of " + conditions.length); + } + + return result; + } + + /** + * Selects the next condition to place based on CFG structure. + * + *

    1. Pick conditions closest to the CFG root (minimum depth) + * 2. Prefer "gate" conditions that guard many branches + * 3. Break ties with cone size (bigger is more discriminating) + * 4. Tie-break with ID for determinism + */ + private int getNext(BitSet ready, CfgConeAnalysis cones, Condition[] conditions, ConditionDependencyGraph deps) { + int best = -1; + int bestDepth = Integer.MAX_VALUE; + int bestCone = -1; + boolean bestIsGate = false; + + for (int i = ready.nextSetBit(0); i >= 0; i = ready.nextSetBit(i + 1)) { + int depth = cones.dominatorDepth(i); + if (depth > bestDepth) { + continue; // Skip if worse + } + + int cone = cones.coneSize(i); + boolean isGate = isIsSet(conditions[i]) && deps.getSuccessorCount(i) > GATE_SUCCESSOR_THRESHOLD; + + if (depth < bestDepth) { + // New best if shallower + best = i; + bestDepth = depth; + bestCone = cone; + bestIsGate = isGate; + } else if (!bestIsGate && isGate) { + // Gates win + best = i; + bestCone = cone; + bestIsGate = true; + } else if (bestIsGate == isGate && (cone > bestCone || (cone == bestCone && i < best))) { + // Same gate status, so pick larger cone or lower ID for stability + best = i; + bestCone = cone; + } + } + + return best; + } + + private boolean allPredecessorsPlaced(int id, ConditionDependencyGraph deps, BitSet placed) { + for (int pred : deps.getPredecessors(id)) { + if (!placed.get(pred)) { + return false; + } + } + return true; + } + + private static boolean isIsSet(Condition c) { + return c.getFunction().getFunctionDefinition() == IsSet.getDefinition(); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java deleted file mode 100644 index d4a678cce99..00000000000 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraph.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic.bdd; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; - -/** - * Immutable graph of dependencies between conditions. - * - *

    This class performs the expensive AST analysis once to extract: - *

      - *
    • Variable definitions - which conditions define which variables
    • - *
    • Variable usage - which conditions use which variables
    • - *
    • Raw dependencies - which conditions must come before others
    • - *
    - */ -final class ConditionDependencyGraph { - private final List conditions; - private final Map> dependencies; - private final Map> variableDefiners; - private final Map> isSetConditions; - - /** - * Creates a dependency graph by analyzing the given conditions. - * - * @param conditions the conditions to analyze - */ - public ConditionDependencyGraph(List conditions) { - this.conditions = Collections.unmodifiableList(new ArrayList<>(conditions)); - this.variableDefiners = new LinkedHashMap<>(); - this.isSetConditions = new LinkedHashMap<>(); - - // Categorize all conditions - for (Condition cond : conditions) { - // Track variable definition - if (cond.getResult().isPresent()) { - String definedVar = cond.getResult().get().toString(); - variableDefiners.computeIfAbsent(definedVar, k -> new LinkedHashSet<>()).add(cond); - } - - // Track isSet conditions - if (isIsset(cond)) { - for (String var : cond.getFunction().getReferences()) { - isSetConditions.computeIfAbsent(var, k -> new LinkedHashSet<>()).add(cond); - } - } - } - - // Compute dependencies - Map> deps = new LinkedHashMap<>(); - for (Condition cond : conditions) { - Set condDeps = new LinkedHashSet<>(); - - for (String usedVar : cond.getFunction().getReferences()) { - // Must come after any condition that defines this variable - condDeps.addAll(variableDefiners.getOrDefault(usedVar, Collections.emptySet())); - - // Non-isSet conditions must come after isSet checks for undefined variables - if (!isIsset(cond)) { - condDeps.addAll(isSetConditions.getOrDefault(usedVar, Collections.emptySet())); - } - } - - condDeps.remove(cond); // Remove self-dependencies - if (!condDeps.isEmpty()) { - deps.put(cond, Collections.unmodifiableSet(condDeps)); - } - } - - this.dependencies = Collections.unmodifiableMap(deps); - } - - /** - * Gets the dependencies for a condition. - * - * @param condition the condition to query - * @return set of conditions that must come before it (never null) - */ - public Set getDependencies(Condition condition) { - return dependencies.getOrDefault(condition, Collections.emptySet()); - } - - /** - * Gets the number of conditions in this dependency graph. - * - * @return the number of conditions - */ - public int size() { - return conditions.size(); - } - - private static boolean isIsset(Condition cond) { - return cond.getFunction().getFunctionDefinition() == IsSet.getDefinition(); - } -} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java deleted file mode 100644 index b9cc2653399..00000000000 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategy.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic.bdd; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; - -/** - * Orders conditions for BDD construction while respecting variable dependencies. - * - *

    The ordering ensures that: - *

      - *
    • Variables are defined before they are used
    • - *
    • isSet checks come before value checks for the same variable
    • - *
    • Simpler conditions are evaluated first (when dependencies allow)
    • - *
    - */ -final class DefaultOrderingStrategy { - private DefaultOrderingStrategy() {} - - static List orderConditions(Condition[] conditions) { - ConditionDependencyGraph deps = new ConditionDependencyGraph(Arrays.asList(conditions)); - return sort(conditions, deps); - } - - private static List sort( - Condition[] conditions, - ConditionDependencyGraph deps - ) { - List result = new ArrayList<>(); - Set visited = new HashSet<>(); - Set visiting = new HashSet<>(); - - // Sort conditions by priority - List queue = new ArrayList<>(); - Collections.addAll(queue, conditions); - - queue.sort(Comparator - // fewer deps first - .comparingInt((Condition c) -> deps.getDependencies(c).size()) - // isSet() before everything else - .thenComparingInt(c -> c.getFunction().getFunctionDefinition() == IsSet.getDefinition() ? 0 : 1) - // variable-defining conditions first - .thenComparingInt(c -> c.getResult().isPresent() ? 0 : 1) - // fewer references first - .thenComparingInt(c -> c.getFunction().getReferences().size()) - // lower complexity first - .thenComparingInt(c -> c.getFunction().getComplexity()) - // stable tie-breaker - .thenComparing(Condition::toString)); - - // Visit in priority order - for (Condition cond : queue) { - if (!visited.contains(cond)) { - visit(cond, deps, visited, visiting, result); - } - } - - return result; - } - - private static void visit( - Condition cond, - ConditionDependencyGraph depGraph, - Set visited, - Set visiting, - List result - ) { - if (visiting.contains(cond)) { - throw new IllegalStateException("Circular dependency detected involving: " + cond); - } - - if (visited.contains(cond)) { - return; - } - - visiting.add(cond); - - // Visit dependencies first - Set deps = depGraph.getDependencies(cond); - if (!deps.isEmpty()) { - List sortedDeps = new ArrayList<>(deps); - sortedDeps.sort(Comparator.comparingInt(c -> c.getFunction().getComplexity())); - - for (Condition dep : sortedDeps) { - visit(dep, depGraph, visited, visiting, result); - } - } - - visiting.remove(cond); - visited.add(cond); - result.add(cond); - } -} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraints.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraints.java deleted file mode 100644 index 5f2ca2df997..00000000000 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraints.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic.bdd; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; - -/** - * Order-specific constraints derived from a dependency graph. - * - *

    This class efficiently computes position-based constraints for a specific ordering of conditions, using the - * pre-computed dependency graph. It can be created cheaply for each new ordering during optimization. - */ -final class OrderConstraints { - private final Condition[] conditions; - private final Map conditionToIndex; - private final int[] minValidPosition; - private final int[] maxValidPosition; - - /** - * Creates order constraints for a specific ordering. - * - * @param graph the pre-computed dependency graph - * @param conditions the conditions in their specific order - */ - public OrderConstraints(ConditionDependencyGraph graph, List conditions) { - int n = conditions.size(); - if (n != graph.size()) { - throw new IllegalArgumentException( - "Condition count (" + n + ") doesn't match dependency graph size (" + graph.size() + ")"); - } - - this.conditions = conditions.toArray(new Condition[0]); - this.conditionToIndex = new HashMap<>(n * 2); - this.minValidPosition = new int[n]; - this.maxValidPosition = new int[n]; - - // Build index mapping - for (int i = 0; i < n; i++) { - conditionToIndex.put(this.conditions[i], i); - } - - // Build dependencies and compute valid positions in one pass - for (int i = 0; i < n; i++) { - maxValidPosition[i] = n - 1; // Initialize max position - for (Condition dep : graph.getDependencies(this.conditions[i])) { - Integer depIndex = conditionToIndex.get(dep); - if (depIndex != null) { - // This condition must come after its dependency - minValidPosition[i] = Math.max(minValidPosition[i], depIndex + 1); - // The dependency must come before this condition - maxValidPosition[depIndex] = Math.min(maxValidPosition[depIndex], i - 1); - } - } - } - } - - /** - * Checks if moving a condition from one position to another would violate dependencies. - * - * @param from current position - * @param to target position - * @return true if the move is valid - */ - public boolean canMove(int from, int to) { - return from == to || (to >= minValidPosition[from] && to <= maxValidPosition[from]); - } - - /** - * Gets the minimum valid position for a condition. - * - * @param conditionIndex the condition index - * @return the minimum position where this condition can be placed - */ - public int getMinValidPosition(int conditionIndex) { - return minValidPosition[conditionIndex]; - } - - /** - * Gets the maximum valid position for a condition. - * - * @param conditionIndex the condition index - * @return the maximum position where this condition can be placed - */ - public int getMaxValidPosition(int conditionIndex) { - return maxValidPosition[conditionIndex]; - } -} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderingStrategy.java similarity index 65% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderingStrategy.java index 06f9e48cd59..46220e01b13 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionOrderingStrategy.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderingStrategy.java @@ -6,12 +6,13 @@ import java.util.List; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; /** * Strategy interface for ordering conditions in a BDD. */ @FunctionalInterface -interface ConditionOrderingStrategy { +interface OrderingStrategy { /** * Orders the given conditions for BDD construction. * @@ -21,20 +22,19 @@ interface ConditionOrderingStrategy { List orderConditions(Condition[] conditions); /** - * Default ordering strategy that uses the existing ConditionOrderer. + * Creates an initial ordering strategy using the given CFG. * - * @return return the default ordering strategy. + * @param cfg CFG to process. + * @return the initial ordering strategy. */ - static ConditionOrderingStrategy defaultOrdering() { - return DefaultOrderingStrategy::orderConditions; + static OrderingStrategy initialOrdering(Cfg cfg) { + return new CfgGuidedOrdering(cfg); } /** * Fixed ordering strategy that uses a pre-determined order. - * - * @return a fixed ordering strategy. */ - static ConditionOrderingStrategy fixed(List ordering) { + static OrderingStrategy fixed(List ordering) { return conditions -> ordering; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index d71eef0a349..e375bd8a1e0 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -14,6 +14,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.rulesengine.logic.cfg.ConditionDependencyGraph; import software.amazon.smithy.utils.SmithyBuilder; /** @@ -60,7 +61,7 @@ public final class SiftingOptimization implements Function { private enum OptimizationEffort { COARSE(12, 4, 0), MEDIUM(2, 18, 5), - GRANULAR(1, 30, 6); + GRANULAR(1, 48, 10); final int sampleRate; final int maxPositions; @@ -253,7 +254,7 @@ private OptimizationResult runOptimizationPass( ) { int improvements = 0; - OrderConstraints constraints = new OrderConstraints(dependencyGraph, orderView); + ConditionDependencyGraph.OrderConstraints constraints = dependencyGraph.createOrderConstraints(orderView); Bdd bestBdd = null; int bestSize = currentSize; List bestResults = null; @@ -287,7 +288,7 @@ private OptimizationResult runOptimizationPass( improvements++; // Update constraints after successful move - constraints = new OrderConstraints(dependencyGraph, orderView); + constraints = dependencyGraph.createOrderConstraints(orderView); } } @@ -295,7 +296,7 @@ private OptimizationResult runOptimizationPass( } private OptimizationResult performAdjacentSwaps(Condition[] order, List orderView, int currentSize) { - OrderConstraints constraints = new OrderConstraints(dependencyGraph, orderView); + ConditionDependencyGraph.OrderConstraints constraints = dependencyGraph.createOrderConstraints(orderView); Bdd bestBdd = null; int bestSize = currentSize; List bestResults = null; @@ -339,7 +340,11 @@ private PositionCount findBestPosition( .orElse(null); } - private List getPositions(int varIdx, OrderConstraints constraints, OptimizationEffort effort) { + private List getPositions( + int varIdx, + ConditionDependencyGraph.OrderConstraints constraints, + OptimizationEffort effort + ) { int min = constraints.getMinValidPosition(varIdx); int max = constraints.getMaxValidPosition(varIdx); int range = max - min; @@ -348,7 +353,12 @@ private List getPositions(int varIdx, OrderConstraints constraints, Opt : getStrategicPositions(varIdx, min, max, range, constraints, effort); } - private List getExhaustivePositions(int varIdx, int min, int max, OrderConstraints constraints) { + private List getExhaustivePositions( + int varIdx, + int min, + int max, + ConditionDependencyGraph.OrderConstraints constraints + ) { List positions = new ArrayList<>(max - min); for (int p = min; p < max; p++) { if (p != varIdx && constraints.canMove(varIdx, p)) { @@ -363,7 +373,7 @@ private List getStrategicPositions( int min, int max, int range, - OrderConstraints constraints, + ConditionDependencyGraph.OrderConstraints constraints, OptimizationEffort effort ) { List positions = new ArrayList<>(effort.maxPositions); @@ -426,7 +436,7 @@ private static void move(Condition[] arr, int from, int to) { */ private BddCompilationResult compileBddWithResults(List ordering) { BddBuilder builder = threadBuilder.get().reset(); - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.fixed(ordering), builder); + BddCompiler compiler = new BddCompiler(cfg, OrderingStrategy.fixed(ordering), builder); Bdd bdd = compiler.compile(); return new BddCompilationResult(bdd, compiler.getIndexedResults()); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java index 9bedc35478b..bd7f874e555 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -4,6 +4,7 @@ */ package software.amazon.smithy.rulesengine.logic.cfg; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -86,7 +87,7 @@ public CfgNode createCondition(ConditionReference condRef, CfgNode trueBranch, C * @return a result node (cached if identical rule already seen) */ public CfgNode createResult(Rule rule) { - Rule canonical = rule.withoutConditions(); + Rule canonical = rule.withConditions(Collections.emptyList()); Rule interned = resultCache.computeIfAbsent(canonical, k -> k); return resultNodeCache.computeIfAbsent(interned, ResultNode::new); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java new file mode 100644 index 00000000000..2d9d29af337 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java @@ -0,0 +1,285 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; +import software.amazon.smithy.rulesengine.language.evaluation.type.RecordType; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +/** + * Coalesces bind-then-use patterns in conditions, identifying conditions that bind a variable followed immediately by + * a condition that uses that variable, and merges them using coalesce. + */ +final class CoalesceTransform { + private static final Logger LOGGER = Logger.getLogger(CoalesceTransform.class.getName()); + + private final Map coalesceCache = new HashMap<>(); + private int coalesceCount = 0; + private int cacheHits = 0; + private int skippedNoZeroValue = 0; + private int skippedMultipleUses = 0; + private final Set skippedRecordTypes = new HashSet<>(); + + static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + CoalesceTransform transform = new CoalesceTransform(); + + List transformedRules = new ArrayList<>(); + for (Rule rule : ruleSet.getRules()) { + transformedRules.add(transform.transformRule(rule)); + } + + if (LOGGER.isLoggable(Level.INFO)) { + StringBuilder msg = new StringBuilder(); + msg.append("\n=== Coalescing Transform Complete ===\n"); + msg.append("Total: ").append(transform.coalesceCount).append(" coalesced, "); + msg.append(transform.cacheHits).append(" cache hits, "); + msg.append(transform.skippedNoZeroValue).append(" skipped (no zero value), "); + msg.append(transform.skippedMultipleUses).append(" skipped (multiple uses)"); + if (!transform.skippedRecordTypes.isEmpty()) { + msg.append("\nSkipped record-returning functions: ").append(transform.skippedRecordTypes); + } + LOGGER.info(msg.toString()); + } + + return EndpointRuleSet.builder() + .parameters(ruleSet.getParameters()) + .rules(transformedRules) + .version(ruleSet.getVersion()) + .build(); + } + + private Rule transformRule(Rule rule) { + Set eliminatedConditions = new HashSet<>(); + List conditions = rule.getConditions(); + Map localVarUsage = countLocalVariableUsage(conditions); + List transformedConditions = transformConditions(conditions, eliminatedConditions, localVarUsage); + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + List transformedNestedRules = new ArrayList<>(); + boolean nestedChanged = false; + + for (Rule nestedRule : treeRule.getRules()) { + Rule transformedNested = transformRule(nestedRule); + transformedNestedRules.add(transformedNested); + if (transformedNested != nestedRule) { + nestedChanged = true; + } + } + + if (!transformedConditions.equals(conditions) || nestedChanged) { + return TreeRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(transformedConditions) + .treeRule(transformedNestedRules); + } + } else if (!transformedConditions.equals(conditions)) { + // For other rule types, just update conditions + return rule.withConditions(transformedConditions); + } + + return rule; + } + + private Map countLocalVariableUsage(List conditions) { + Map usage = new HashMap<>(); + + // Count how many times each variable is used within this specific rule + for (Condition condition : conditions) { + for (String ref : condition.getFunction().getReferences()) { + usage.merge(ref, 1, Integer::sum); + } + } + + return usage; + } + + private List transformConditions( + List conditions, + Set eliminatedConditions, + Map localVarUsage + ) { + List result = new ArrayList<>(); + + for (int i = 0; i < conditions.size(); i++) { + Condition current = conditions.get(i); + if (eliminatedConditions.contains(current)) { + continue; + } + + // Check if this is a bind that can be coalesced with the next condition + if (i + 1 < conditions.size() && current.getResult().isPresent()) { + String var = current.getResult().get().toString(); + Condition next = conditions.get(i + 1); + + if (canCoalesce(var, current, next, localVarUsage)) { + // Create coalesced condition + Condition coalesced = createCoalescedCondition(current, next, var); + result.add(coalesced); + // Mark both conditions as eliminated + eliminatedConditions.add(current); + eliminatedConditions.add(next); + // Skip the next condition + i++; + continue; + } + } + + // No coalescing possible, keep the condition as-is + result.add(current); + } + + return result; + } + + private boolean canCoalesce(String var, Condition bind, Condition use, Map localVarUsage) { + if (!use.getFunction().getReferences().contains(var)) { + // The use condition must reference the variable + return false; + } else if (use.getFunction().getFunctionDefinition() == IsSet.getDefinition()) { + // Never coalesce into presence checks (isSet) + return false; + } + + // Check if variable is only used once in this local rule context (even if it appears multiple times globally) + Integer localUses = localVarUsage.get(var); + if (localUses == null || localUses > 1) { + skippedMultipleUses++; + return false; + } + + // Get the actual return type (could be Optional or T) + Type type = bind.getFunction().getFunctionDefinition().getReturnType(); + + // Check if we can get a zero value for this type. For OptionalType, we use the inner type's zero value + Type innerType = type; + if (type instanceof OptionalType) { + innerType = ((OptionalType) type).inner(); + } + + if (innerType instanceof RecordType) { + skippedNoZeroValue++; + skippedRecordTypes.add(bind.getFunction().getName()); + return false; + } + + if (!innerType.getZeroValue().isPresent()) { + skippedNoZeroValue++; + return false; + } + + return true; + } + + private Condition createCoalescedCondition(Condition bind, Condition use, String var) { + LibraryFunction bindExpr = bind.getFunction(); + LibraryFunction useExpr = use.getFunction(); + + // Get the type and its zero value + Type type = bindExpr.getFunctionDefinition().getReturnType(); + Type innerType = type; + if (type instanceof OptionalType) { + innerType = ((OptionalType) type).inner(); + } + + Literal zero = innerType.getZeroValue().get(); + + // Create cache key based on canonical representations + String bindCanonical = bindExpr.canonicalize().toString(); + String zeroCanonical = zero.toString(); + String useCanonical = useExpr.canonicalize().toString(); + String resultVar = use.getResult().map(Identifier::toString).orElse(""); + + CoalesceKey key = new CoalesceKey(bindCanonical, zeroCanonical, useCanonical, var, resultVar); + Condition cached = coalesceCache.get(key); + if (cached != null) { + cacheHits++; + return cached; + } + + Expression coalesced = Coalesce.ofExpressions(bindExpr, zero); + + // Replace the variable reference in the use expression + Map replacements = new HashMap<>(); + replacements.put(var, coalesced); + ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + Expression replaced = rewriter.rewrite(useExpr); + LibraryFunction canonicalized = ((LibraryFunction) replaced).canonicalize(); + + Condition.Builder builder = Condition.builder().fn(canonicalized); + if (use.getResult().isPresent()) { + builder.result(use.getResult().get()); + } + + Condition result = builder.build(); + coalesceCache.put(key, result); + coalesceCount++; + + if (LOGGER.isLoggable(Level.FINE)) { + LOGGER.fine("Coalesced #" + coalesceCount + ":\n" + + " " + var + " = " + bind.getFunction() + "\n" + + " " + use.getFunction() + "\n" + + " => " + canonicalized); + } + + return result; + } + + private static final class CoalesceKey { + final String bindFunction; + final String zeroValue; + final String useFunction; + final String replacedVar; + final String resultVar; + final int hashCode; + + CoalesceKey(String bindFunction, String zeroValue, String useFunction, String replacedVar, String resultVar) { + this.bindFunction = bindFunction; + this.zeroValue = zeroValue; + this.useFunction = useFunction; + this.replacedVar = replacedVar; + this.resultVar = resultVar; + this.hashCode = Objects.hash(bindFunction, zeroValue, useFunction, replacedVar, resultVar); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } else if (!(o instanceof CoalesceKey)) { + return false; + } + CoalesceKey that = (CoalesceKey) o; + return bindFunction.equals(that.bindFunction) && zeroValue.equals(that.zeroValue) + && useFunction.equals(that.useFunction) + && replacedVar.equals(that.replacedVar) + && resultVar.equals(that.resultVar); + } + + @Override + public int hashCode() { + return hashCode; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraph.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraph.java new file mode 100644 index 00000000000..5b4730f3633 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraph.java @@ -0,0 +1,315 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; + +/** + * Graph of dependencies between conditions based on variable definitions and usage. + * + *

    This class performs AST analysis once to extract: + *

      + *
    • Variable definitions - which conditions define which variables
    • + *
    • Variable usage - which conditions use which variables
    • + *
    • Dependencies - which conditions must come before others
    • + *
    + */ +public final class ConditionDependencyGraph { + private final List conditions; + private final Map conditionToIndex; + private final Map> dependencies; + private final Map> variableDefiners; + private final Map> isSetConditions; + + // Indexed dependency information for fast access + private final List> predecessors; + private final List> successors; + + /** + * Creates a dependency graph by analyzing the given conditions. + * + * @param conditions the conditions to analyze + */ + public ConditionDependencyGraph(List conditions) { + this.conditions = Collections.unmodifiableList(new ArrayList<>(conditions)); + this.conditionToIndex = new HashMap<>(); + this.variableDefiners = new HashMap<>(); + this.isSetConditions = new HashMap<>(); + + int n = conditions.size(); + for (int i = 0; i < n; i++) { + conditionToIndex.put(conditions.get(i), i); + } + + // Initialize indexed structures + this.predecessors = new ArrayList<>(n); + this.successors = new ArrayList<>(n); + for (int i = 0; i < n; i++) { + predecessors.add(new HashSet<>()); + successors.add(new HashSet<>()); + } + + // Categorize all conditions + for (Condition cond : conditions) { + // Track variable definition + if (cond.getResult().isPresent()) { + String definedVar = cond.getResult().get().toString(); + variableDefiners.computeIfAbsent(definedVar, k -> new HashSet<>()).add(cond); + } + + // Track isSet conditions + if (isIsSet(cond)) { + for (String var : cond.getFunction().getReferences()) { + isSetConditions.computeIfAbsent(var, k -> new HashSet<>()).add(cond); + } + } + } + + // Compute dependencies + Map> deps = new HashMap<>(); + Map> producers = new HashMap<>(); + Map> isSetters = new HashMap<>(); + + // Build producer and isSet indices using Identifier + for (int i = 0; i < n; i++) { + Condition c = conditions.get(i); + + if (c.getResult().isPresent()) { + Identifier var = c.getResult().get(); + producers.computeIfAbsent(var, k -> new HashSet<>()).add(i); + } + + if (isIsSet(c)) { + for (String ref : c.getFunction().getReferences()) { + Identifier var = Identifier.of(ref); + isSetters.computeIfAbsent(var, k -> new HashSet<>()).add(i); + } + } + } + + // Build both object-based and index-based dependencies + for (int i = 0; i < n; i++) { + Condition cond = conditions.get(i); + Set condDeps = new HashSet<>(); + + for (String usedVar : cond.getFunction().getReferences()) { + // Object-based dependencies + condDeps.addAll(variableDefiners.getOrDefault(usedVar, Collections.emptySet())); + if (!isIsSet(cond)) { + condDeps.addAll(isSetConditions.getOrDefault(usedVar, Collections.emptySet())); + } + + // Index-based dependencies + Identifier var = Identifier.of(usedVar); + for (int prod : producers.getOrDefault(var, Collections.emptySet())) { + if (prod != i) { + predecessors.get(i).add(prod); + successors.get(prod).add(i); + } + } + + if (!isIsSet(cond)) { + for (int setter : isSetters.getOrDefault(var, Collections.emptySet())) { + if (setter != i) { + predecessors.get(i).add(setter); + successors.get(setter).add(i); + } + } + } + } + + condDeps.remove(cond); // Remove self-dependencies + if (!condDeps.isEmpty()) { + deps.put(cond, Collections.unmodifiableSet(condDeps)); + } + } + + this.dependencies = Collections.unmodifiableMap(deps); + } + + /** + * Gets the dependencies for a condition. + * + * @param condition the condition to query + * @return set of conditions that must come before it (never null) + */ + public Set getDependencies(Condition condition) { + return dependencies.getOrDefault(condition, Collections.emptySet()); + } + + /** + * Gets the predecessors (dependencies) for a condition by index. + * + * @param index the condition index + * @return set of predecessor indices + */ + public Set getPredecessors(int index) { + return predecessors.get(index); + } + + /** + * Gets the successors (dependents) for a condition by index. + * + * @param index the condition index + * @return set of successor indices + */ + public Set getSuccessors(int index) { + return successors.get(index); + } + + /** + * Gets the number of predecessors for a condition. + * + * @param index the condition index + * @return the predecessor count + */ + public int getPredecessorCount(int index) { + return predecessors.get(index).size(); + } + + /** + * Gets the number of successors for a condition. + * + * @param index the condition index + * @return the successor count + */ + public int getSuccessorCount(int index) { + return successors.get(index).size(); + } + + /** + * Checks if there's a dependency from one condition to another. + * + * @param from the dependent condition index + * @param to the dependency condition index + * @return true if 'from' depends on 'to' + */ + public boolean hasDependency(int from, int to) { + return predecessors.get(from).contains(to); + } + + /** + * Creates order constraints for a specific ordering of conditions. + * + * @param ordering the ordering to compute constraints for + * @return the order constraints + */ + public OrderConstraints createOrderConstraints(List ordering) { + return new OrderConstraints(ordering); + } + + /** + * Gets the index mapping for conditions. + * + * @return map from condition to index + */ + public Map getConditionToIndex() { + return Collections.unmodifiableMap(conditionToIndex); + } + + /** + * Gets the number of conditions in this dependency graph. + * + * @return the number of conditions + */ + public int size() { + return conditions.size(); + } + + private static boolean isIsSet(Condition cond) { + return cond.getFunction().getFunctionDefinition() == IsSet.getDefinition(); + } + + /** + * Order-specific constraints for a particular condition ordering. + */ + public final class OrderConstraints { + private final Condition[] orderedConditions; + private final Map orderIndex; + private final int[] minValidPosition; + private final int[] maxValidPosition; + + private OrderConstraints(List ordering) { + int n = ordering.size(); + if (n != conditions.size()) { + throw new IllegalArgumentException( + "Ordering size (" + n + ") doesn't match dependency graph size (" + conditions.size() + ")"); + } + + this.orderedConditions = ordering.toArray(new Condition[0]); + this.orderIndex = new HashMap<>(n * 2); + this.minValidPosition = new int[n]; + this.maxValidPosition = new int[n]; + + // Build index mapping for this ordering + for (int i = 0; i < n; i++) { + orderIndex.put(orderedConditions[i], i); + } + + // Compute valid positions based on dependencies + for (int i = 0; i < n; i++) { + maxValidPosition[i] = n - 1; // Initialize max position + + Condition cond = orderedConditions[i]; + Integer originalIdx = conditionToIndex.get(cond); + if (originalIdx == null) { + throw new IllegalArgumentException("Condition not in dependency graph: " + cond); + } + + // Check all dependencies + for (int depIdx : predecessors.get(originalIdx)) { + Condition depCond = conditions.get(depIdx); + Integer depOrderIdx = orderIndex.get(depCond); + if (depOrderIdx != null) { + // This condition must come after its dependency + minValidPosition[i] = Math.max(minValidPosition[i], depOrderIdx + 1); + // The dependency must come before this condition + maxValidPosition[depOrderIdx] = Math.min(maxValidPosition[depOrderIdx], i - 1); + } + } + } + } + + /** + * Checks if moving a condition from one position to another would violate dependencies. + * + * @param from current position + * @param to target position + * @return true if the move is valid + */ + public boolean canMove(int from, int to) { + return from == to || (to >= minValidPosition[from] && to <= maxValidPosition[from]); + } + + /** + * Gets the minimum valid position for a condition. + * + * @param positionIndex the position index in the ordering + * @return the minimum position where this condition can be placed + */ + public int getMinValidPosition(int positionIndex) { + return minValidPosition[positionIndex]; + } + + /** + * Gets the maximum valid position for a condition. + * + * @param positionIndex the position index in the ordering + * @return the maximum position where this condition can be placed + */ + public int getMaxValidPosition(int positionIndex) { + return maxValidPosition[positionIndex]; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriter.java new file mode 100644 index 00000000000..6dda29af584 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriter.java @@ -0,0 +1,164 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Predicate; +import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; + +/** + * Utility for rewriting references within expression trees. + */ +final class ReferenceRewriter { + + private final Function referenceTransformer; + private final Predicate shouldRewrite; + + /** + * Creates a new reference rewriter. + * + * @param referenceTransformer function to transform references + * @param shouldRewrite predicate to determine if an expression needs rewriting + */ + ReferenceRewriter( + Function referenceTransformer, + Predicate shouldRewrite + ) { + this.referenceTransformer = referenceTransformer; + this.shouldRewrite = shouldRewrite; + } + + /** + * Rewrites references within an expression tree. + * + * @param expression the expression to rewrite + * @return the rewritten expression, or the original if no changes needed + */ + public Expression rewrite(Expression expression) { + if (!shouldRewrite.test(expression)) { + return expression; + } + + if (expression instanceof StringLiteral) { + return rewriteStringLiteral((StringLiteral) expression); + } else if (expression instanceof TupleLiteral) { + return rewriteTupleLiteral((TupleLiteral) expression); + } else if (expression instanceof RecordLiteral) { + return rewriteRecordLiteral((RecordLiteral) expression); + } else if (expression instanceof Reference) { + return referenceTransformer.apply((Reference) expression); + } else if (expression instanceof LibraryFunction) { + return rewriteLibraryFunction((LibraryFunction) expression); + } + + return expression; + } + + private Expression rewriteStringLiteral(StringLiteral str) { + Template template = str.value(); + if (template.isStatic()) { + return str; + } + + StringBuilder templateBuilder = new StringBuilder(); + boolean changed = false; + + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + Template.Dynamic dynamic = (Template.Dynamic) part; + Expression original = dynamic.toExpression(); + Expression rewritten = rewrite(original); + if (rewritten != original) { + changed = true; + } + templateBuilder.append('{').append(rewritten).append('}'); + } else { + templateBuilder.append(((Template.Literal) part).getValue()); + } + } + + return changed ? Literal.stringLiteral(Template.fromString(templateBuilder.toString())) : str; + } + + private Expression rewriteTupleLiteral(TupleLiteral tuple) { + List rewrittenMembers = new ArrayList<>(); + boolean changed = false; + + for (Literal member : tuple.members()) { + Literal rewritten = (Literal) rewrite(member); + rewrittenMembers.add(rewritten); + if (rewritten != member) { + changed = true; + } + } + + return changed ? Literal.tupleLiteral(rewrittenMembers) : tuple; + } + + private Expression rewriteRecordLiteral(RecordLiteral record) { + Map rewrittenMembers = new LinkedHashMap<>(); + boolean changed = false; + + for (Map.Entry entry : record.members().entrySet()) { + Literal original = entry.getValue(); + Literal rewritten = (Literal) rewrite(original); + rewrittenMembers.put(entry.getKey(), rewritten); + if (rewritten != original) { + changed = true; + } + } + + return changed ? Literal.recordLiteral(rewrittenMembers) : record; + } + + private Expression rewriteLibraryFunction(LibraryFunction fn) { + List rewrittenArgs = new ArrayList<>(); + boolean changed = false; + + for (Expression arg : fn.getArguments()) { + Expression rewritten = rewrite(arg); + rewrittenArgs.add(rewritten); + if (rewritten != arg) { + changed = true; + } + } + + if (!changed) { + return fn; + } + + FunctionNode node = FunctionNode.builder() + .name(Node.from(fn.getName())) + .arguments(rewrittenArgs) + .build(); + return fn.getFunctionDefinition().createFunction(node); + } + + /** + * Creates a simple rewriter that replaces specific references. + * + * @param replacements map of variable names to replacement expressions + * @return a reference rewriter that performs the replacements + */ + public static ReferenceRewriter forReplacements(Map replacements) { + return new ReferenceRewriter( + ref -> replacements.getOrDefault(ref.getName().toString(), ref), + expr -> expr.getReferences().stream().anyMatch(replacements::containsKey)); + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java index 08956b65f93..127bfc83bad 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java @@ -14,20 +14,12 @@ import java.util.List; import java.util.Map; import java.util.Set; -import software.amazon.smithy.model.node.Node; import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionNode; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; -import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; @@ -44,44 +36,39 @@ * *

    Note that this transform is only applied when the reassignment is done using different * arguments than previously seen assignments of the same variable name. - * - *

    TODO: This transform does not yet introduce phi nodes at control flow merge points. - * We need to add an OR function to the rules engine to do that. */ final class SsaTransform { // Stack of scopes, each mapping original variable names to their current SSA names private final Deque> scopeStack = new ArrayDeque<>(); - - // Cache of already rewritten conditions to avoid redundant work private final Map rewrittenConditions = new IdentityHashMap<>(); - - // Cache of already rewritten rules private final Map rewrittenRules = new IdentityHashMap<>(); + private final VariableAnalysis variableAnalysis; + private final ReferenceRewriter referenceRewriter; - // Set of input parameter names that should never be rewritten - private final Set inputParams; - - // Map from variable name -> expression -> SSA name - // Pre-computed to ensure consistent naming across the tree - private final Map> variableExpressionMappings; - - private SsaTransform(Set inputParams, Map> variableExpressionMappings) { + private SsaTransform(VariableAnalysis variableAnalysis) { // Start with an empty global scope scopeStack.push(new HashMap<>()); - this.inputParams = inputParams; - this.variableExpressionMappings = variableExpressionMappings; + this.variableAnalysis = variableAnalysis; + + // Create a reference rewriter that uses our scope resolution + this.referenceRewriter = new ReferenceRewriter( + ref -> { + String originalName = ref.getName().toString(); + String uniqueName = resolveReference(originalName); + return Expression.getReference(Identifier.of(uniqueName)); + }, + this::needsRewriting); } static EndpointRuleSet transform(EndpointRuleSet ruleSet) { - Set inputParameters = extractInputParameters(ruleSet); + ruleSet = CoalesceTransform.transform(ruleSet); - // Collect all variable bindings and create unique names for each unique expression - Map> variableBindings = collectVariableBindings(ruleSet.getRules()); - Map> variableExpressionMappings = createExpressionMappings(variableBindings); + // Use VariableAnalysis to get all the information we need + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); // Rewrite with the pre-computed mappings - SsaTransform ssaTransform = new SsaTransform(inputParameters, variableExpressionMappings); + SsaTransform ssaTransform = new SsaTransform(analysis); List rewrittenRules = new ArrayList<>(ruleSet.getRules().size()); for (Rule original : ruleSet.getRules()) { rewrittenRules.add(ssaTransform.processRule(original)); @@ -94,80 +81,6 @@ static EndpointRuleSet transform(EndpointRuleSet ruleSet) { .build(); } - // Collect a set of input parameter names. We use these to know that expressions that only work with an input - // parameter and use the same arguments can be kept as-is rather than need a cloned and renamed assignment. - private static Set extractInputParameters(EndpointRuleSet ruleSet) { - Set inputParameters = new HashSet<>(); - for (Parameter param : ruleSet.getParameters()) { - inputParameters.add(param.getName().toString()); - } - return inputParameters; - } - - private static Map> collectVariableBindings(List rules) { - Map> variableBindings = new HashMap<>(); - collectBindingsFromRules(rules, variableBindings); - return variableBindings; - } - - private static void collectBindingsFromRules(List rules, Map> variableBindings) { - for (Rule rule : rules) { - collectBindingsFromRule(rule, variableBindings); - } - } - - private static void collectBindingsFromRule(Rule rule, Map> variableBindings) { - for (Condition condition : rule.getConditions()) { - if (condition.getResult().isPresent()) { - String varName = condition.getResult().get().toString(); - String expression = condition.getFunction().toString(); - variableBindings.computeIfAbsent(varName, k -> new HashSet<>()).add(expression); - } - } - - if (rule instanceof TreeRule) { - TreeRule treeRule = (TreeRule) rule; - collectBindingsFromRules(treeRule.getRules(), variableBindings); - } - } - - /** - * Creates a mapping from variable name -> expression -> SSA name. - * Variables assigned multiple times get unique SSA names (x, x_1, x_2, etc). - */ - private static Map> createExpressionMappings(Map> bindings) { - Map> result = new HashMap<>(); - for (Map.Entry> entry : bindings.entrySet()) { - String varName = entry.getKey(); - Set expressions = entry.getValue(); - result.put(varName, createMappingForVariable(varName, expressions)); - } - - return result; - } - - private static Map createMappingForVariable(String varName, Set expressions) { - Map mapping = new HashMap<>(); - - if (expressions.size() == 1) { - // Only one expression for this variable, so no SSA renaming needed - String expression = expressions.iterator().next(); - mapping.put(expression, varName); - } else { - // Multiple expressions, so create unique SSA names - List sortedExpressions = new ArrayList<>(expressions); - sortedExpressions.sort(String::compareTo); // Ensure deterministic ordering - - for (int i = 0; i < sortedExpressions.size(); i++) { - String expression = sortedExpressions.get(i); - String uniqueName = (i == 0) ? varName : varName + "_" + i; - mapping.put(expression, uniqueName); - } - } - - return mapping; - } - private Rule processRule(Rule rule) { enterScope(); Rule rewrittenRule = rewriteRule(rule); @@ -208,7 +121,7 @@ private Condition rewriteCondition(Condition condition) { boolean needsUniqueBinding = false; if (hasBinding) { String varName = condition.getResult().get().toString(); - Map expressionMap = variableExpressionMappings.get(varName); + Map expressionMap = variableAnalysis.getExpressionMappings().get(varName); if (expressionMap != null) { uniqueBindingName = expressionMap.get(fn.toString()); needsUniqueBinding = uniqueBindingName != null && !uniqueBindingName.equals(varName); @@ -224,7 +137,7 @@ private Condition rewriteCondition(Condition condition) { } // Rewrite the expression - LibraryFunction rewrittenExpr = (LibraryFunction) rewriteExpression(fn); + LibraryFunction rewrittenExpr = (LibraryFunction) referenceRewriter.rewrite(fn); boolean exprChanged = rewrittenExpr != fn; // Build the rewritten condition @@ -253,11 +166,11 @@ private Condition rewriteCondition(Condition condition) { } private Set filterOutInputParameters(Set references) { - if (references.isEmpty() || inputParams.isEmpty()) { + if (references.isEmpty() || variableAnalysis.getInputParams().isEmpty()) { return references; } Set filtered = new HashSet<>(references); - filtered.removeAll(inputParams); + filtered.removeAll(variableAnalysis.getInputParams()); return filtered; } @@ -324,7 +237,7 @@ private Rule rewriteEndpointRule( Endpoint endpoint = rule.getEndpoint(); // Rewrite endpoint components to use SSA names - Expression rewrittenUrl = rewriteExpression(endpoint.getUrl()); + Expression rewrittenUrl = referenceRewriter.rewrite(endpoint.getUrl()); Map> rewrittenHeaders = rewriteHeaders(endpoint.getHeaders()); Map rewrittenProperties = rewriteProperties(endpoint.getProperties()); @@ -347,7 +260,7 @@ private Rule rewriteEndpointRule( } private Rule rewriteErrorRule(ErrorRule rule, List rewrittenConditions, boolean conditionsChanged) { - Expression rewrittenError = rewriteExpression(rule.getError()); + Expression rewrittenError = referenceRewriter.rewrite(rule.getError()); if (conditionsChanged || rewrittenError != rule.getError()) { return ErrorRule.builder() @@ -388,7 +301,7 @@ private Map> rewriteHeaders(Map> entry : headers.entrySet()) { List rewrittenValues = new ArrayList<>(); for (Expression expr : entry.getValue()) { - rewrittenValues.add(rewriteExpression(expr)); + rewrittenValues.add(referenceRewriter.rewrite(expr)); } rewritten.put(entry.getKey(), rewrittenValues); } @@ -398,7 +311,7 @@ private Map> rewriteHeaders(Map rewriteProperties(Map properties) { Map rewritten = new LinkedHashMap<>(); for (Map.Entry entry : properties.entrySet()) { - Expression rewrittenExpr = rewriteExpression(entry.getValue()); + Expression rewrittenExpr = referenceRewriter.rewrite(entry.getValue()); if (!(rewrittenExpr instanceof Literal)) { throw new IllegalStateException("Property value must be a literal"); } @@ -407,81 +320,9 @@ private Map rewriteProperties(Map prop return rewritten; } - // Recursively rewrites an expression to use SSA names - private Expression rewriteExpression(Expression expression) { - if (!needsRewriting(expression)) { - return expression; - } - - if (expression instanceof StringLiteral) { - return rewriteStringLiteral((StringLiteral) expression); - } else if (expression instanceof TupleLiteral) { - return rewriteTupleLiteral((TupleLiteral) expression); - } else if (expression instanceof RecordLiteral) { - return rewriteRecordLiteral((RecordLiteral) expression); - } else if (expression instanceof Reference) { - return rewriteReference((Reference) expression); - } else if (expression instanceof LibraryFunction) { - return rewriteLibraryFunction((LibraryFunction) expression); - } - - return expression; - } - - private Expression rewriteStringLiteral(StringLiteral str) { - Template template = str.value(); - if (template.isStatic()) { - return str; - } - - StringBuilder templateBuilder = new StringBuilder(); - for (Template.Part part : template.getParts()) { - if (part instanceof Template.Dynamic) { - Template.Dynamic dynamic = (Template.Dynamic) part; - Expression rewritten = rewriteExpression(dynamic.toExpression()); - templateBuilder.append('{').append(rewritten).append('}'); - } else { - templateBuilder.append(((Template.Literal) part).getValue()); - } - } - return Literal.stringLiteral(Template.fromString(templateBuilder.toString())); - } - - private Expression rewriteTupleLiteral(TupleLiteral tuple) { - List rewrittenMembers = new ArrayList<>(); - for (Literal member : tuple.members()) { - rewrittenMembers.add((Literal) rewriteExpression(member)); - } - return Literal.tupleLiteral(rewrittenMembers); - } - - private Expression rewriteRecordLiteral(RecordLiteral record) { - Map rewrittenMembers = new LinkedHashMap<>(); - for (Map.Entry entry : record.members().entrySet()) { - rewrittenMembers.put(entry.getKey(), (Literal) rewriteExpression(entry.getValue())); - } - return Literal.recordLiteral(rewrittenMembers); - } - - private Expression rewriteReference(Reference ref) { - String originalName = ref.getName().toString(); - String uniqueName = resolveReference(originalName); - return Expression.getReference(Identifier.of(uniqueName)); - } - - private Expression rewriteLibraryFunction(LibraryFunction fn) { - List rewrittenArgs = new ArrayList<>(fn.getArguments()); - rewrittenArgs.replaceAll(this::rewriteExpression); - FunctionNode node = FunctionNode.builder() - .name(Node.from(fn.getName())) - .arguments(rewrittenArgs) - .build(); - return fn.getFunctionDefinition().createFunction(node); - } - private String resolveReference(String originalName) { // Input parameters are never rewritten - return inputParams.contains(originalName) + return variableAnalysis.getInputParams().contains(originalName) ? originalName : scopeStack.peek().getOrDefault(originalName, originalName); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java new file mode 100644 index 00000000000..f870024c60a --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java @@ -0,0 +1,245 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +/** + * Analyzes variables in an endpoint rule set, collecting bindings, reference counts, + * and other metadata needed for SSA transformation and optimization. + */ +final class VariableAnalysis { + private final Set inputParams; + private final Map> bindings; + private final Map referenceCounts; + private final Map> expressionMappings; + + private VariableAnalysis( + Set inputParams, + Map> bindings, + Map referenceCounts + ) { + this.inputParams = inputParams; + this.bindings = bindings; + this.referenceCounts = referenceCounts; + this.expressionMappings = createExpressionMappings(bindings); + } + + static VariableAnalysis analyze(EndpointRuleSet ruleSet) { + Set inputParameters = extractInputParameters(ruleSet); + Map> variableBindings = new HashMap<>(); + Map referenceCounts = new HashMap<>(); + + Visitor visitor = new Visitor(variableBindings, referenceCounts); + for (Rule rule : ruleSet.getRules()) { + visitor.visitRule(rule); + } + + return new VariableAnalysis(inputParameters, variableBindings, referenceCounts); + } + + Set getInputParams() { + return inputParams; + } + + /** + * Gets the mapping from variable name to expression to SSA name. + * This is used to determine when variables need SSA renaming. + */ + Map> getExpressionMappings() { + return expressionMappings; + } + + /** + * Gets the reference count for a variable. + * + * @param variableName the variable name + * @return the number of times the variable is referenced, or 0 if not found + */ + int getReferenceCount(String variableName) { + return referenceCounts.getOrDefault(variableName, 0); + } + + /** + * Checks if a variable is referenced exactly once. + * + * @param variableName the variable name + * @return true if the variable is referenced exactly once + */ + boolean isReferencedOnce(String variableName) { + return getReferenceCount(variableName) == 1; + } + + /** + * Checks if a variable has only one binding expression. + * + * @param variableName the variable name + * @return true if the variable has exactly one binding + */ + boolean hasSingleBinding(String variableName) { + Set expressions = bindings.get(variableName); + return expressions != null && expressions.size() == 1; + } + + /** + * Checks if a variable is safe to inline (single binding, single reference). + * + * @param variableName the variable name + * @return true if the variable can be safely inlined + */ + boolean isSafeToInline(String variableName) { + return hasSingleBinding(variableName) && isReferencedOnce(variableName); + } + + private static Set extractInputParameters(EndpointRuleSet ruleSet) { + Set inputParameters = new HashSet<>(); + for (Parameter param : ruleSet.getParameters()) { + inputParameters.add(param.getName().toString()); + } + return inputParameters; + } + + /** + * Creates a mapping from variable name to expression to SSA name. + * Variables assigned multiple times get unique SSA names (x, x_1, x_2, etc). + */ + private static Map> createExpressionMappings( + Map> bindings + ) { + Map> result = new HashMap<>(); + for (Map.Entry> entry : bindings.entrySet()) { + String varName = entry.getKey(); + Set expressions = entry.getValue(); + result.put(varName, createMappingForVariable(varName, expressions)); + } + return result; + } + + private static Map createMappingForVariable( + String varName, + Set expressions + ) { + Map mapping = new HashMap<>(); + + if (expressions.size() == 1) { + // Only one expression for this variable, so no SSA renaming needed + String expression = expressions.iterator().next(); + mapping.put(expression, varName); + } else { + // Multiple expressions, so create unique SSA names + List sortedExpressions = new ArrayList<>(expressions); + sortedExpressions.sort(String::compareTo); // Ensure deterministic ordering + + for (int i = 0; i < sortedExpressions.size(); i++) { + String expression = sortedExpressions.get(i); + String uniqueName = (i == 0) ? varName : varName + "_" + i; + mapping.put(expression, uniqueName); + } + } + + return mapping; + } + + // Visitor that collects variable bindings and reference counts. + private static class Visitor { + private final Map> variableBindings; + private final Map referenceCounts; + + Visitor(Map> variableBindings, Map referenceCounts) { + this.variableBindings = variableBindings; + this.referenceCounts = referenceCounts; + } + + void visitRule(Rule rule) { + for (Condition condition : rule.getConditions()) { + if (condition.getResult().isPresent()) { + String varName = condition.getResult().get().toString(); + String expression = condition.getFunction().toString(); + variableBindings.computeIfAbsent(varName, k -> new HashSet<>()) + .add(expression); + } + + countReferences(condition.getFunction()); + } + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (Rule nestedRule : treeRule.getRules()) { + visitRule(nestedRule); + } + } else if (rule instanceof EndpointRule) { + EndpointRule endpointRule = (EndpointRule) rule; + Endpoint endpoint = endpointRule.getEndpoint(); + countReferences(endpoint.getUrl()); + for (List headerValues : endpoint.getHeaders().values()) { + for (Expression expr : headerValues) { + countReferences(expr); + } + } + for (Literal literal : endpoint.getProperties().values()) { + countReferences(literal); + } + } else if (rule instanceof ErrorRule) { + ErrorRule errorRule = (ErrorRule) rule; + countReferences(errorRule.getError()); + } + } + + private void countReferences(Expression expression) { + if (expression instanceof Reference) { + Reference ref = (Reference) expression; + String name = ref.getName().toString(); + referenceCounts.merge(name, 1, Integer::sum); + } else if (expression instanceof StringLiteral) { + StringLiteral str = (StringLiteral) expression; + Template template = str.value(); + if (!template.isStatic()) { + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + Template.Dynamic dynamic = (Template.Dynamic) part; + countReferences(dynamic.toExpression()); + } + } + } + } else if (expression instanceof LibraryFunction) { + LibraryFunction fn = (LibraryFunction) expression; + for (Expression arg : fn.getArguments()) { + countReferences(arg); + } + } else if (expression instanceof TupleLiteral) { + TupleLiteral tuple = (TupleLiteral) expression; + for (Literal member : tuple.members()) { + countReferences(member); + } + } else if (expression instanceof RecordLiteral) { + RecordLiteral record = (RecordLiteral) expression; + for (Literal value : record.members().values()) { + countReferences(value); + } + } + } + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java new file mode 100644 index 00000000000..66e80ce32f4 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java @@ -0,0 +1,136 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + +public class CoalesceTest { + + @Test + void testCoalesceWithSameTypes() { + Expression left = Literal.of("default"); + Expression right = Literal.of("fallback"); + Coalesce coalesce = Coalesce.ofExpressions(left, right); + + Scope scope = new Scope<>(); + Type resultType = coalesce.typeCheck(scope); + + assertEquals(Type.stringType(), resultType); + } + + @Test + void testCoalesceWithOptionalLeft() { + Expression optionalVar = Expression.getReference(Identifier.of("maybeValue")); + Expression fallback = Literal.of("default"); + Coalesce coalesce = Coalesce.ofExpressions(optionalVar, fallback); + + Scope scope = new Scope<>(); + scope.insert("maybeValue", Type.optionalType(Type.stringType())); + + Type resultType = coalesce.typeCheck(scope); + + // Should unwrap optional and return non-optional String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testCoalesceWithBothOptional() { + Expression var1 = Expression.getReference(Identifier.of("maybe1")); + Expression var2 = Expression.getReference(Identifier.of("maybe2")); + Coalesce coalesce = Coalesce.ofExpressions(var1, var2); + + Scope scope = new Scope<>(); + scope.insert("maybe1", Type.optionalType(Type.stringType())); + scope.insert("maybe2", Type.optionalType(Type.stringType())); + + Type resultType = coalesce.typeCheck(scope); + + // Both optional means result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testCoalesceWithCompatibleTypes() { + // Test with optional types that should resolve to non-optional + Expression optionalString = Expression.getReference(Identifier.of("optional")); + Expression requiredString = Expression.getReference(Identifier.of("required")); + Coalesce coalesce = Coalesce.ofExpressions(optionalString, requiredString); + + Scope scope = new Scope<>(); + scope.insert("optional", Type.optionalType(Type.stringType())); + scope.insert("required", Type.stringType()); + + Type resultType = coalesce.typeCheck(scope); + + // When coalescing Optional with String, should return String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testCoalesceWithIncompatibleTypes() { + Expression stringExpr = Literal.of("text"); + Expression intExpr = Literal.of(42); + Coalesce coalesce = Coalesce.ofExpressions(stringExpr, intExpr); + + Scope scope = new Scope<>(); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> coalesce.typeCheck(scope)); + assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); + } + + @Test + void testCoalesceNonOptionalWithNonOptional() { + Expression int1 = Literal.of(42); + Expression int2 = Literal.of(100); + Coalesce coalesce = Coalesce.ofExpressions(int1, int2); + + Scope scope = new Scope<>(); + Type resultType = coalesce.typeCheck(scope); + + // Two non-optionals of same type should return that type + assertEquals(Type.integerType(), resultType); + } + + @Test + void testCoalesceOptionalWithNonOptional() { + Expression optionalInt = Expression.getReference(Identifier.of("maybeInt")); + Expression defaultInt = Literal.of(0); + Coalesce coalesce = Coalesce.ofExpressions(optionalInt, defaultInt); + + Scope scope = new Scope<>(); + scope.insert("maybeInt", Type.optionalType(Type.integerType())); + + Type resultType = coalesce.typeCheck(scope); + + // Optional coalesced with Int should return Int + assertEquals(Type.integerType(), resultType); + } + + @Test + void testCoalesceArrayTypes() { + Expression arr1 = Expression.getReference(Identifier.of("array1")); + Expression arr2 = Expression.getReference(Identifier.of("array2")); + Coalesce coalesce = Coalesce.ofExpressions(arr1, arr2); + + Scope scope = new Scope<>(); + scope.insert("array1", Type.arrayType(Type.stringType())); + scope.insert("array2", Type.arrayType(Type.stringType())); + + Type resultType = coalesce.typeCheck(scope); + + assertEquals(Type.arrayType(Type.stringType()), resultType); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java index 0c6e57eac7b..40f7fa09f1a 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddCompilerTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; @@ -58,15 +59,15 @@ void testCompileSimpleEndpointRule() { .build(); Cfg cfg = Cfg.from(ruleSet); - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); assertNotNull(bdd); assertEquals(1, bdd.getConditionCount()); - // Results include: endpoint when condition true, no match when false - assertTrue(bdd.getResultCount() >= 2); - assertTrue(bdd.getRootRef() > 0); + // Results include: NoMatchRule, endpoint, and possibly a terminal for the false branch + assertEquals(3, bdd.getResultCount()); + assertTrue(bdd.getRootRef() != 0); } @Test @@ -82,13 +83,13 @@ void testCompileErrorRule() { .build(); Cfg cfg = Cfg.from(ruleSet); - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); assertEquals(1, bdd.getConditionCount()); - // Similar to endpoint rule - assertTrue(bdd.getResultCount() >= 2); + // Results: NoMatchRule, error, and possibly a terminal + assertEquals(3, bdd.getResultCount()); } @Test @@ -104,12 +105,12 @@ void testCompileTreeRule() { EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(treeRule).build(); Cfg cfg = Cfg.from(ruleSet); - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); assertEquals(2, bdd.getConditionCount()); - assertTrue(bdd.getNodeCount() > 2); // Should have multiple nodes + assertTrue(bdd.getNodeCount() >= 2); // Should have multiple nodes } @Test @@ -143,15 +144,19 @@ void testCompileWithCustomOrdering() { assertNotNull(condB, "Could not find condition for B"); // Use fixed ordering (B before A) - ConditionOrderingStrategy customOrdering = ConditionOrderingStrategy.fixed(Arrays.asList(condB, condA)); + OrderingStrategy customOrdering = OrderingStrategy.fixed(Arrays.asList(condB, condA)); BddCompiler compiler = new BddCompiler(cfg, customOrdering, new BddBuilder()); Bdd bdd = compiler.compile(); - // Verify ordering was applied by checking the compiled BDD - // Since we don't have access to conditions from Bdd anymore, we just verify compilation succeeded + List orderedConditions = compiler.getOrderedConditions(); + + // Verify ordering was applied assertEquals(2, bdd.getConditionCount()); - assertNotNull(bdd); + assertEquals(2, orderedConditions.size()); + // Verify B comes before A in the ordering + assertEquals(condB, orderedConditions.get(0)); + assertEquals(condA, orderedConditions.get(1)); } @Test @@ -160,14 +165,14 @@ void testCompileEmptyRuleSet() { EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(Parameters.builder().build()).build(); Cfg cfg = Cfg.from(ruleSet); - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); assertEquals(0, bdd.getConditionCount()); - // Even with no rules, there's still a result (no match) - assertTrue(bdd.getResultCount() > 0); + // Even with no rules, there's the NoMatchRule and possibly a terminal + assertEquals(2, bdd.getResultCount()); // Should have at least terminal node - assertTrue(bdd.getNodeCount() > 0); + assertEquals(1, bdd.getNodeCount()); } @Test @@ -193,12 +198,54 @@ void testCompileSameResultMultiplePaths() { .build(); Cfg cfg = Cfg.from(ruleSet); - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); + List results = compiler.getIndexedResults(); - // The BDD compiler might create separate result nodes even for same endpoint - // depending on how the CFG is structured - assertEquals(3, bdd.getResultCount()); + // Should have 2 conditions + assertEquals(2, bdd.getConditionCount()); + // Results: NoMatchRule at index 0, plus the endpoint(s) + // The compiler may deduplicate identical endpoints or keep them separate + assertTrue(bdd.getResultCount() >= 2); + assertTrue(bdd.getResultCount() <= 3); + + // Verify NoMatchRule is always at index 0 + assertEquals("NoMatchRule", results.get(0).getClass().getSimpleName()); + } + + @Test + void testCompileWithReduction() { + // Test that the BDD is properly reduced after compilation + Rule rule = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("Region")).build(), + Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(REGION_PARAM) + .addParameter(BUCKET_PARAM) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddBuilder builder = new BddBuilder(); + BddCompiler compiler = new BddCompiler(cfg, builder); + + Bdd bdd = compiler.compile(); + + // The BDD should be reduced (no redundant nodes) + assertNotNull(bdd); + assertEquals(2, bdd.getConditionCount()); + + // After reduction, we should have a minimal BDD + // For 2 conditions with AND semantics leading to one endpoint: + // We expect approximately 3-4 nodes (depending on the exact structure) + assertTrue(bdd.getNodeCount() <= 5, "BDD should be reduced to minimal form"); } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java index 403b27bfc78..458f696bdf6 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceCheckerTest.java @@ -37,25 +37,19 @@ void testSimpleEquivalentBdd() { .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) .endpoint(TestHelpers.endpoint("https://example.com")), // Default case - ErrorRule.builder() - .error(Literal.of("No region provided")))) + ErrorRule.builder().error(Literal.of("No region provided")))) .build(); - // Convert to CFG Cfg cfg = Cfg.from(ruleSet); - - // Create BDD from CFG - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); - // Create checker BddEquivalenceChecker checker = BddEquivalenceChecker.of( cfg, bdd, compiler.getOrderedConditions(), compiler.getIndexedResults()); - // Should pass verification assertDoesNotThrow(checker::verify); } @@ -64,14 +58,11 @@ void testEmptyRulesetEquivalence() { // Empty ruleset with a default endpoint EndpointRuleSet ruleSet = EndpointRuleSet.builder() .parameters(Parameters.builder().build()) - .rules(ListUtils.of( - EndpointRule.builder() - .endpoint(TestHelpers.endpoint("https://default.com")))) + .rules(ListUtils.of(EndpointRule.builder().endpoint(TestHelpers.endpoint("https://default.com")))) .build(); Cfg cfg = Cfg.from(ruleSet); - - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); BddEquivalenceChecker checker = BddEquivalenceChecker.of( @@ -97,13 +88,11 @@ void testMultipleConditionsEquivalence() { Condition.builder().fn(TestHelpers.isSet("Region")).build(), Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) .endpoint(TestHelpers.endpoint("https://example.com")), - // Default case - ErrorRule.builder() - .error(Literal.of("Missing required parameters")))) + ErrorRule.builder().error(Literal.of("Missing required parameters")))) .build(); Cfg cfg = Cfg.from(ruleSet); - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); BddEquivalenceChecker checker = BddEquivalenceChecker.of( @@ -130,9 +119,8 @@ void testSetMaxSamples() { .endpoint(TestHelpers.endpoint("https://example" + i + ".com"))); } - // Add default case - rules.add(ErrorRule.builder() - .error(Literal.of("No parameters set"))); + // default case + rules.add(ErrorRule.builder().error(Literal.of("No parameters set"))); EndpointRuleSet ruleSet = EndpointRuleSet.builder() .parameters(paramsBuilder.build()) @@ -140,7 +128,7 @@ void testSetMaxSamples() { .build(); Cfg cfg = Cfg.from(ruleSet); - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); BddEquivalenceChecker checker = BddEquivalenceChecker.of( @@ -166,12 +154,11 @@ void testSetMaxDuration() { EndpointRule.builder() .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) .endpoint(TestHelpers.endpoint("https://example.com")), - ErrorRule.builder() - .error(Literal.of("No region provided")))) + ErrorRule.builder().error(Literal.of("No region provided")))) .build(); Cfg cfg = Cfg.from(ruleSet); - BddCompiler compiler = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); BddEquivalenceChecker checker = BddEquivalenceChecker.of( @@ -185,4 +172,42 @@ void testSetMaxDuration() { assertDoesNotThrow(checker::verify); } + + @Test + void testLargeNumberOfConditions() { + // Test with 25 conditions to ensure it uses sampling rather than exhaustive testing + Parameters.Builder paramsBuilder = Parameters.builder(); + List conditions = new ArrayList<>(); + + for (int i = 0; i < 25; i++) { + String paramName = String.format("Param%02d", i); + paramsBuilder.addParameter(Parameter.builder().name(paramName).type(ParameterType.STRING).build()); + conditions.add(Condition.builder().fn(TestHelpers.isSet(paramName)).build()); + } + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(paramsBuilder.build()) + .rules(ListUtils.of( + EndpointRule.builder() + .conditions(conditions) + .endpoint(TestHelpers.endpoint("https://example.com")), + ErrorRule.builder().error(Literal.of("Not all parameters set")))) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); + Bdd bdd = compiler.compile(); + + BddEquivalenceChecker checker = BddEquivalenceChecker.of( + cfg, + bdd, + compiler.getOrderedConditions(), + compiler.getIndexedResults()); + + // Set reasonable limits for large condition sets + checker.setMaxSamples(10000); + checker.setMaxDuration(Duration.ofSeconds(5)); + + assertDoesNotThrow(checker::verify); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java index 9f320f50635..a0c2483052d 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTest.java @@ -74,7 +74,7 @@ void testFromRuleSet() { .build(); Cfg cfg = Cfg.from(ruleSet); - Bdd bdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); + Bdd bdd = new BddCompiler(cfg, OrderingStrategy.initialOrdering(cfg), new BddBuilder()).compile(); assertTrue(bdd.getConditionCount() > 0); assertTrue(bdd.getResultCount() > 0); @@ -89,7 +89,7 @@ void testFromCfg() { .build(); Cfg cfg = Cfg.from(ruleSet); - Bdd bdd = new BddCompiler(cfg, ConditionOrderingStrategy.defaultOrdering(), new BddBuilder()).compile(); + Bdd bdd = new BddCompiler(cfg, OrderingStrategy.initialOrdering(cfg), new BddBuilder()).compile(); assertEquals(0, bdd.getConditionCount()); // No conditions assertTrue(bdd.getResultCount() > 0); @@ -297,7 +297,7 @@ void testEquals() { assertNotEquals(bdd1, bdd6); // Different root ref (use result reference) - Bdd bdd7 = new Bdd(Bdd.RESULT_OFFSET + 0, 1, 1, 2, consumer -> { + Bdd bdd7 = new Bdd(Bdd.RESULT_OFFSET, 1, 1, 2, consumer -> { consumer.accept(-1, 1, -1); consumer.accept(0, 1, -1); }); diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysisTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysisTest.java new file mode 100644 index 00000000000..df1c9a6ec80 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgConeAnalysisTest.java @@ -0,0 +1,233 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.utils.ListUtils; + +class CfgConeAnalysisTest { + + @Test + void testSingleConditionSingleResult() { + // Simple rule: if Region is set, return endpoint + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + Map conditionToIndex = new HashMap<>(); + for (int i = 0; i < conditions.length; i++) { + conditionToIndex.put(conditions[i], i); + } + + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // The single condition should have dominator depth 0 (at root) + assertEquals(0, analysis.dominatorDepth(0)); + // Should reach 2 result nodes (the endpoint and the terminal/no-match) + assertEquals(2, analysis.coneSize(0)); + } + + @Test + void testChainedConditions() { + // Rule with two conditions in sequence (AND logic) + Rule rule = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("Region")).build(), + Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + Map conditionToIndex = new HashMap<>(); + for (int i = 0; i < conditions.length; i++) { + conditionToIndex.put(conditions[i], i); + } + + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // First condition should be at depth 0 + assertEquals(0, analysis.dominatorDepth(0)); + // Second condition should be at depth 1 (one edge from root) + assertEquals(1, analysis.dominatorDepth(1)); + + // Both conditions lead to the same result + assertTrue(analysis.coneSize(0) >= 1); + assertTrue(analysis.coneSize(1) >= 1); + } + + @Test + void testMultipleBranches() { + // Two separate rules leading to different endpoints + Rule rule1 = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://regional.com")); + + Rule rule2 = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://bucket.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + Map conditionToIndex = new HashMap<>(); + for (int i = 0; i < conditions.length; i++) { + conditionToIndex.put(conditions[i], i); + } + + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // Both conditions should be at the root level (depth 0 or 1) + assertTrue(analysis.dominatorDepth(0) <= 1); + assertTrue(analysis.dominatorDepth(1) <= 1); + + // Each condition reaches at least one result + assertTrue(analysis.coneSize(0) >= 1); + assertTrue(analysis.coneSize(1) >= 1); + } + + @Test + void testNestedTreeRule() { + // Tree rule with nested structure + Rule innerRule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://bucket.com")); + + Rule treeRule = TreeRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .treeRule(innerRule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(treeRule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + Map conditionToIndex = new HashMap<>(); + for (int i = 0; i < conditions.length; i++) { + conditionToIndex.put(conditions[i], i); + } + + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // Region condition should be at root (depth 0) + int regionIdx = -1; + int bucketIdx = -1; + for (int i = 0; i < conditions.length; i++) { + if (conditions[i].toString().contains("Region")) { + regionIdx = i; + } else if (conditions[i].toString().contains("Bucket")) { + bucketIdx = i; + } + } + + if (regionIdx >= 0) { + assertEquals(0, analysis.dominatorDepth(regionIdx)); + } + + // Bucket condition should be deeper (at least depth 1) + if (bucketIdx >= 0) { + assertTrue(analysis.dominatorDepth(bucketIdx) >= 1); + } + } + + @Test + void testErrorRule() { + // Rule that returns an error instead of endpoint + Rule rule = ErrorRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("InvalidParam")).build()) + .error("Invalid parameter provided"); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("InvalidParam").type(ParameterType.STRING).build()) + .build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + Map conditionToIndex = new HashMap<>(); + for (int i = 0; i < conditions.length; i++) { + conditionToIndex.put(conditions[i], i); + } + + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // The condition should be at root + assertEquals(0, analysis.dominatorDepth(0)); + // Should reach 2 results (the error and terminal/no-match) + assertEquals(2, analysis.coneSize(0)); + } + + @Test + void testEmptyCfg() { + // Empty ruleset + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + Condition[] conditions = cfg.getConditions(); + + // Should have no conditions + assertEquals(0, conditions.length); + + // Analysis should handle empty CFG gracefully + Map conditionToIndex = new HashMap<>(); + CfgConeAnalysis analysis = new CfgConeAnalysis(cfg, conditions, conditionToIndex); + + // No assertions needed - just verify it doesn't throw + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrderingTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrderingTest.java new file mode 100644 index 00000000000..fa7374c1357 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrderingTest.java @@ -0,0 +1,335 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.bdd; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.rulesengine.logic.cfg.Cfg; +import software.amazon.smithy.utils.ListUtils; + +class CfgGuidedOrderingTest { + + @Test + void testSimpleOrdering() { + // Single rule with one condition + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + assertNotNull(ordered); + assertEquals(1, ordered.size()); + assertTrue(ordered.get(0).toString().contains("Region")); + } + + @Test + void testDependencyOrdering() { + // Rule with variable dependencies: x = isSet(A), then use x + Condition defineX = Condition.builder() + .fn(TestHelpers.isSet("A")) + .result(Identifier.of("x")) + .build(); + + Condition useX = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("x")), + Literal.of(true))) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(defineX, useX) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("A").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + + // After SSA and coalesce transforms, there might be only one condition + Condition[] conditions = cfg.getConditions(); + + // If coalesce merged them, we'll have 1 condition. Otherwise 2. + assertTrue(conditions.length >= 1 && conditions.length <= 2); + + CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + List ordered = ordering.orderConditions(conditions); + + assertEquals(conditions.length, ordered.size()); + + // If we still have 2 conditions, verify dependency order + if (ordered.size() == 2) { + // Find which condition defines x and which uses it + int defineIndex = -1; + int useIndex = -1; + for (int i = 0; i < ordered.size(); i++) { + if (ordered.get(i).getResult().isPresent() && + ordered.get(i).getResult().get().toString().equals("x")) { + defineIndex = i; + } else if (ordered.get(i).toString().contains("x")) { + useIndex = i; + } + } + + // Define must come before use (if both exist) + if (defineIndex >= 0 && useIndex >= 0) { + assertTrue(defineIndex < useIndex, "Definition of x must come before its use"); + } + } + } + + @Test + void testGateConditionPriority() { + // Create a gate condition (isSet) that multiple other conditions depend on + Condition gate = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + // Multiple conditions that use hasRegion + Condition branch1 = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .build(); + + Condition branch2 = Condition.builder() + .fn(TestHelpers.isSet("Bucket")) + .build(); + + Rule rule1 = EndpointRule.builder() + .conditions(gate, branch1) + .endpoint(TestHelpers.endpoint("https://example1.com")); + + Rule rule2 = EndpointRule.builder() + .conditions(gate, branch2) + .endpoint(TestHelpers.endpoint("https://example2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + // Gate condition should be ordered early since multiple branches depend on it + assertNotNull(ordered); + assertTrue(ordered.size() >= 2); + + // Find the gate condition + int gateIndex = -1; + for (int i = 0; i < ordered.size(); i++) { + if (ordered.get(i).getResult().isPresent() && + ordered.get(i).getResult().get().toString().equals("hasRegion")) { + gateIndex = i; + break; + } + } + + // Gate should be ordered before conditions that depend on it + assertTrue(gateIndex >= 0, "Gate condition should be in the ordering"); + assertTrue(gateIndex < ordered.size() - 1, "Gate should not be last"); + } + + @Test + void testNestedTreeOrdering() { + // Nested tree structure to test depth-based ordering + Rule innerRule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Bucket")).build()) + .endpoint(TestHelpers.endpoint("https://bucket.com")); + + Rule treeRule = TreeRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .treeRule(innerRule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(treeRule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + assertEquals(2, ordered.size()); + + // Region should come before Bucket due to tree structure + int regionIndex = -1; + int bucketIndex = -1; + for (int i = 0; i < ordered.size(); i++) { + if (ordered.get(i).toString().contains("Region")) { + regionIndex = i; + } else if (ordered.get(i).toString().contains("Bucket")) { + bucketIndex = i; + } + } + + assertTrue(regionIndex < bucketIndex, "Region should be ordered before Bucket"); + } + + @Test + void testMultipleIndependentConditions() { + // Multiple conditions with no dependencies + Rule rule = EndpointRule.builder() + .conditions( + Condition.builder().fn(TestHelpers.isSet("A")).build(), + Condition.builder().fn(TestHelpers.isSet("B")).build(), + Condition.builder().fn(TestHelpers.isSet("C")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("A").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("B").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("C").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + // Should order all conditions + assertEquals(3, ordered.size()); + + // Order should be deterministic (based on CFG structure). Run twice to ensure consistency. + List ordered2 = ordering.orderConditions(cfg.getConditions()); + assertEquals(ordered, ordered2, "Ordering should be deterministic"); + } + + @Test + void testEmptyConditions() { + // Ruleset with no conditions + Rule rule = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://default.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder().build()) + .addRule(rule) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + assertNotNull(ordered); + assertEquals(0, ordered.size()); + } + + @Test + void testIsSetGatePriority() { + // Test that isSet conditions used by multiple consumers get priority + Condition isSetGate = Condition.builder() + .fn(IsSet.ofExpressions(Expression.getReference(Identifier.of("Input")))) + .result(Identifier.of("hasInput")) + .build(); + + // Multiple rules use the hasInput variable + Rule rule1 = EndpointRule.builder() + .conditions( + isSetGate, + Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasInput")), + Literal.of(true))) + .build()) + .endpoint(TestHelpers.endpoint("https://example1.com")); + + Rule rule2 = EndpointRule.builder() + .conditions( + isSetGate, + Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasInput")), + Literal.of(false))) + .build()) + .endpoint(TestHelpers.endpoint("https://example2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Input").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + Cfg cfg = Cfg.from(ruleSet); + CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + + List ordered = ordering.orderConditions(cfg.getConditions()); + + // The isSet gate should be ordered early + assertNotNull(ordered); + + // After transforms, the structure might change + // Just verify that if an isSet condition exists, it's prioritized + boolean hasIsSet = false; + for (Condition c : ordered) { + if (c.getFunction().getFunctionDefinition() == IsSet.getDefinition()) { + hasIsSet = true; + break; + } + } + + // The test's purpose is to verify prioritization works, not specific positions + // After coalesce transform, the isSet might be merged into other conditions + assertFalse(ordered.isEmpty(), "Should have at least one condition"); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java deleted file mode 100644 index f04b26fef19..00000000000 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/DefaultOrderingStrategyTest.java +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic.bdd; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.util.List; -import org.junit.jupiter.api.Test; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.ParseUrl; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.logic.TestHelpers; - -class DefaultOrderingStrategyTest { - - @Test - void testIsSetComesFirst() { - // isSet should be ordered before other conditions - Condition isSetCond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - Condition stringEqualsCond = Condition.builder() - .fn(StringEquals.ofExpressions(Literal.of("{Region}"), Literal.of("us-east-1"))) - .build(); - - Condition[] conditions = {stringEqualsCond, isSetCond}; - - List ordered = DefaultOrderingStrategy.orderConditions(conditions); - - // isSet should come first - assertEquals(isSetCond, ordered.get(0)); - assertEquals(stringEqualsCond, ordered.get(1)); - } - - @Test - void testVariableDefiningConditionsFirst() { - // Conditions that define variables should come before those that don't - Condition definer = Condition.builder() - .fn(TestHelpers.isSet("Region")) - .result(Identifier.of("hasRegion")) - .build(); - - Condition nonDefiner = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); - - Condition[] conditions = {nonDefiner, definer}; - - List ordered = DefaultOrderingStrategy.orderConditions(conditions); - - // Variable-defining condition should come first - assertEquals(definer, ordered.get(0)); - assertEquals(nonDefiner, ordered.get(1)); - } - - @Test - void testDependencyOrdering() { - // Condition that defines a variable - Condition definer = Condition.builder() - .fn(TestHelpers.isSet("Region")) - .result(Identifier.of("hasRegion")) - .build(); - - // Condition that uses the variable - Condition user = Condition.builder() - .fn(BooleanEquals.ofExpressions(Literal.of("{hasRegion}"), Literal.of(true))) - .build(); - - Condition[] conditions = {user, definer}; - - List ordered = DefaultOrderingStrategy.orderConditions(conditions); - - // Definer must come before user - assertEquals(definer, ordered.get(0)); - assertEquals(user, ordered.get(1)); - } - - @Test - void testComplexityOrdering() { - // Simple condition - Condition simple = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - - // Complex condition (parseURL has higher cost) - Condition complex = Condition.builder().fn(ParseUrl.ofExpressions(Literal.of("https://example.com"))).build(); - - Condition[] conditions = {complex, simple}; - - List ordered = DefaultOrderingStrategy.orderConditions(conditions); - - // Simple should come before complex - assertEquals(simple, ordered.get(0)); - assertEquals(complex, ordered.get(1)); - } - - @Test - void testCircularDependencyDetection() { - // Create conditions with circular dependency - // Note: This is a pathological case that shouldn't happen in practice - Condition cond1 = Condition.builder() - .fn(BooleanEquals.ofExpressions(Literal.of("{var2}"), Literal.of(true))) - .result(Identifier.of("var1")) - .build(); - - Condition cond2 = Condition.builder() - .fn(BooleanEquals.ofExpressions(Literal.of("{var1}"), Literal.of(true))) - .result(Identifier.of("var2")) - .build(); - - Condition[] conditions = {cond1, cond2}; - - assertThrows(IllegalStateException.class, - () -> DefaultOrderingStrategy.orderConditions(conditions)); - } - - @Test - void testMultiLevelDependencies() { - // A -> B -> C dependency chain - Condition condA = Condition.builder() - .fn(TestHelpers.isSet("input")) - .result(Identifier.of("var1")) - .build(); - - Condition condB = Condition.builder() - .fn(BooleanEquals.ofExpressions(Literal.of("{var1}"), Literal.of(true))) - .result(Identifier.of("var2")) - .build(); - - Condition condC = Condition.builder() - .fn(BooleanEquals.ofExpressions(Literal.of("{var2}"), Literal.of(false))) - .build(); - - // Mix up the order - Condition[] conditions = {condC, condA, condB}; - - List ordered = DefaultOrderingStrategy.orderConditions(conditions); - - assertEquals(condA, ordered.get(0)); - assertEquals(condB, ordered.get(1)); - assertEquals(condC, ordered.get(2)); - } - - @Test - void testStableSortForEqualPriority() { - // Two similar conditions with no dependencies use stable sort - Condition cond1 = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - Condition cond2 = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); - - Condition[] conditions = {cond1, cond2}; - - List ordered = DefaultOrderingStrategy.orderConditions(conditions); - - // Order should be deterministic based on toString - assertEquals(2, ordered.size()); - assertTrue(ordered.contains(cond1)); - assertTrue(ordered.contains(cond2)); - } - - @Test - void testEmptyConditions() { - Condition[] conditions = new Condition[0]; - - List ordered = DefaultOrderingStrategy.orderConditions(conditions); - - assertEquals(0, ordered.size()); - } - - @Test - void testSingleCondition() { - Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - - Condition[] conditions = {cond}; - - List ordered = DefaultOrderingStrategy.orderConditions(conditions); - - assertEquals(1, ordered.size()); - assertEquals(cond, ordered.get(0)); - } - - @Test - void testIsSetDependencyForSameVariable() { - // isSet and value check for same variable - Condition isSet = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - - Condition valueCheck = Condition.builder() - .fn(StringEquals.ofExpressions(Literal.of("{Region}"), Literal.of("us-east-1"))) - .build(); - - // Put value check first to test ordering - Condition[] conditions = {valueCheck, isSet}; - - List ordered = DefaultOrderingStrategy.orderConditions(conditions); - - // isSet must come before value check - assertEquals(isSet, ordered.get(0)); - assertEquals(valueCheck, ordered.get(1)); - } -} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java deleted file mode 100644 index ea872c6dd7f..00000000000 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/OrderConstraintsTest.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ -package software.amazon.smithy.rulesengine.logic.bdd; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import org.junit.jupiter.api.Test; -import software.amazon.smithy.rulesengine.language.syntax.Identifier; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.logic.TestHelpers; - -class OrderConstraintsTest { - - @Test - void testIndependentConditions() { - // Two conditions with no dependencies - Condition cond1 = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - Condition cond2 = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); - - List conditions = Arrays.asList(cond1, cond2); - ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); - OrderConstraints constraints = new OrderConstraints(graph, conditions); - - // Both conditions can be placed anywhere - assertTrue(constraints.canMove(0, 1)); - assertTrue(constraints.canMove(1, 0)); - - assertEquals(0, constraints.getMinValidPosition(0)); - assertEquals(1, constraints.getMaxValidPosition(0)); - assertEquals(0, constraints.getMinValidPosition(1)); - assertEquals(1, constraints.getMaxValidPosition(1)); - } - - @Test - void testDependentConditions() { - // cond1 defines var, cond2 uses it - Condition cond1 = Condition.builder() - .fn(TestHelpers.isSet("Region")) - .result(Identifier.of("hasRegion")) - .build(); - - Condition cond2 = Condition.builder() - .fn(BooleanEquals.ofExpressions(Literal.of("{hasRegion}"), Literal.of(true))) - .build(); - - List conditions = Arrays.asList(cond1, cond2); - ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); - OrderConstraints constraints = new OrderConstraints(graph, conditions); - - // cond1 can only stay in place (cannot move past its dependent) - assertTrue(constraints.canMove(0, 0)); // Stay in place - assertFalse(constraints.canMove(0, 1)); // Cannot move past cond2 - - // cond2 cannot move before cond1 - assertFalse(constraints.canMove(1, 0)); - assertTrue(constraints.canMove(1, 1)); // Stay in place - - assertEquals(0, constraints.getMinValidPosition(0)); - assertEquals(0, constraints.getMaxValidPosition(0)); // Must come before cond2 - assertEquals(1, constraints.getMinValidPosition(1)); // Must come after cond1 - assertEquals(1, constraints.getMaxValidPosition(1)); - } - - @Test - void testChainedDependencies() { - // A -> B -> C dependency chain - Condition condA = Condition.builder().fn(TestHelpers.isSet("input")).result(Identifier.of("var1")).build(); - Condition condB = Condition.builder() - .fn(BooleanEquals.ofExpressions(Literal.of("{var1}"), Literal.of(true))) - .result(Identifier.of("var2")) - .build(); - Condition condC = Condition.builder() - .fn(BooleanEquals.ofExpressions(Literal.of("{var2}"), Literal.of(false))) - .build(); - - List conditions = Arrays.asList(condA, condB, condC); - ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); - OrderConstraints constraints = new OrderConstraints(graph, conditions); - - // A can only be at position 0 - assertEquals(0, constraints.getMinValidPosition(0)); - assertEquals(0, constraints.getMaxValidPosition(0)); - - // B must be between A and C - assertEquals(1, constraints.getMinValidPosition(1)); - assertEquals(1, constraints.getMaxValidPosition(1)); - - // C must be last - assertEquals(2, constraints.getMinValidPosition(2)); - assertEquals(2, constraints.getMaxValidPosition(2)); - - // No movement possible in this rigid chain - assertFalse(constraints.canMove(0, 1)); - assertFalse(constraints.canMove(1, 0)); - assertFalse(constraints.canMove(1, 2)); - assertFalse(constraints.canMove(2, 1)); - } - - @Test - void testCanMoveToSamePosition() { - Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - List conditions = Collections.singletonList(cond); - ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); - OrderConstraints constraints = new OrderConstraints(graph, conditions); - - // Moving to same position is always allowed - assertTrue(constraints.canMove(0, 0)); - } - - @Test - void testMismatchedSizes() { - Condition cond1 = Condition.builder().fn(TestHelpers.isSet("Region")).build(); - List graphConditions = Collections.singletonList(cond1); - ConditionDependencyGraph graph = new ConditionDependencyGraph(graphConditions); - - // Try to create constraints with more conditions than in graph - List conditions = Arrays.asList( - cond1, - Condition.builder().fn(TestHelpers.isSet("Bucket")).build()); - - assertThrows(IllegalArgumentException.class, () -> new OrderConstraints(graph, conditions)); - } -} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java index 583ca4bac7c..888d833f9d8 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java @@ -14,6 +14,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.ArrayList; +import java.util.Collections; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -66,7 +67,7 @@ void fromCreatesSimpleCfg() { // Root should be a result node for a simple endpoint rule assertInstanceOf(ResultNode.class, cfg.getRoot()); ResultNode resultNode = (ResultNode) cfg.getRoot(); - assertEquals(rule.withoutConditions(), resultNode.getResult()); + assertEquals(rule.withConditions(Collections.emptyList()), resultNode.getResult()); } @Test diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraphTest.java similarity index 96% rename from smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java rename to smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraphTest.java index 0932dec64c7..f118ce41e65 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/ConditionDependencyGraphTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ConditionDependencyGraphTest.java @@ -2,12 +2,13 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.rulesengine.logic.bdd; +package software.amazon.smithy.rulesengine.logic.cfg; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Set; import org.junit.jupiter.api.Test; @@ -98,7 +99,7 @@ void testUnknownConditionReturnsEmptyDependencies() { Condition known = Condition.builder().fn(TestHelpers.isSet("Region")).build(); Condition unknown = Condition.builder().fn(TestHelpers.isSet("Bucket")).build(); - List conditions = Arrays.asList(known); + List conditions = Collections.singletonList(known); ConditionDependencyGraph graph = new ConditionDependencyGraph(conditions); // Getting dependencies for unknown condition returns empty set diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java new file mode 100644 index 00000000000..8edac1db7d7 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java @@ -0,0 +1,181 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; +import software.amazon.smithy.utils.ListUtils; +import software.amazon.smithy.utils.MapUtils; + +class ReferenceRewriterTest { + + @Test + void testSimpleReferenceReplacement() { + // Create a rewriter that replaces "x" with "y" + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("y"))); + + ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + + // Test rewriting a simple reference + Reference original = Expression.getReference(Identifier.of("x")); + Expression rewritten = rewriter.rewrite(original); + + assertEquals("y", ((Reference) rewritten).getName().toString()); + } + + @Test + void testNoRewriteNeeded() { + // Create a rewriter with no relevant replacements + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("y"))); + + ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + + // Reference to "z" should not be rewritten + Reference original = Expression.getReference(Identifier.of("z")); + Expression rewritten = rewriter.rewrite(original); + + assertEquals(original, rewritten); + } + + @Test + void testRewriteInStringLiteral() { + // Create a string literal with template variable + Template template = Template.fromString("Value is {x}"); + Literal original = Literal.stringLiteral(template); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("newVar"))); + + ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertInstanceOf(StringLiteral.class, rewritten); + StringLiteral rewrittenStr = (StringLiteral) rewritten; + assertTrue(rewrittenStr.toString().contains("newVar")); + assertNotEquals(original, rewritten); + } + + @Test + void testRewriteInTupleLiteral() { + // Create a tuple with references + Literal original = Literal.tupleLiteral(ListUtils.of(Literal.of("constant"), Literal.of("{x}"))); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("replaced"))); + + ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertInstanceOf(TupleLiteral.class, rewritten); + TupleLiteral rewrittenTuple = (TupleLiteral) rewritten; + assertEquals(2, rewrittenTuple.members().size()); + assertTrue(rewrittenTuple.members().get(1).toString().contains("replaced")); + } + + @Test + void testRewriteInRecordLiteral() { + // Create a record with references + Literal original = Literal.recordLiteral(MapUtils.of( + Identifier.of("field1"), + Literal.of("value1"), + Identifier.of("field2"), + Literal.of("{x}"))); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("newX"))); + + ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertInstanceOf(RecordLiteral.class, rewritten); + RecordLiteral rewrittenRecord = (RecordLiteral) rewritten; + assertEquals(2, rewrittenRecord.members().size()); + assertTrue(rewrittenRecord.members().get(Identifier.of("field2")).toString().contains("newX")); + } + + @Test + void testRewriteInLibraryFunction() { + // Create a function that uses references + Expression original = StringEquals.ofExpressions( + Expression.getReference(Identifier.of("x")), + Literal.of("test")); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("replacedVar"))); + + ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertTrue(rewritten.toString().contains("replacedVar")); + assertNotEquals(original, rewritten); + } + + @Test + void testMultipleReplacements() { + // Create a function with multiple references + Expression original = BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("a")), + Expression.getReference(Identifier.of("b"))); + + Map replacements = new HashMap<>(); + replacements.put("a", Expression.getReference(Identifier.of("x"))); + replacements.put("b", Expression.getReference(Identifier.of("y"))); + + ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertTrue(rewritten.toString().contains("x")); + assertTrue(rewritten.toString().contains("y")); + } + + @Test + void testNestedRewriting() { + // Create nested functions with references + Expression inner = IsSet.ofExpressions(Expression.getReference(Identifier.of("x"))); + Expression original = BooleanEquals.ofExpressions(inner, Literal.of(true)); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("newVar"))); + + ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertTrue(rewritten.toString().contains("newVar")); + assertNotEquals(original, rewritten); + } + + @Test + void testStaticStringNotRewritten() { + // Static strings without templates should not be rewritten + Literal original = Literal.of("static string"); + + Map replacements = new HashMap<>(); + replacements.put("x", Expression.getReference(Identifier.of("y"))); + + ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + Expression rewritten = rewriter.rewrite(original); + + assertEquals(original, rewritten); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableDisambiguatorTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java similarity index 97% rename from smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableDisambiguatorTest.java rename to smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java index 0ad5494d09d..8a0d2b6e14f 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableDisambiguatorTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java @@ -26,7 +26,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; -public class VariableDisambiguatorTest { +public class SsaTransformTest { @Test void testNoDisambiguationNeeded() { @@ -53,7 +53,6 @@ void testNoDisambiguationNeeded() { EndpointRuleSet result = SsaTransform.transform(original); - // Should be unchanged assertEquals(original, result); } @@ -82,11 +81,9 @@ void testSimpleShadowing() { List resultRules = result.getRules(); assertEquals(2, resultRules.size()); - // First rule should keep "temp" EndpointRule resultRule1 = (EndpointRule) resultRules.get(0); assertEquals("temp", resultRule1.getConditions().get(0).getResult().get().toString()); - // Second rule should have "temp_1" EndpointRule resultRule2 = (EndpointRule) resultRules.get(1); assertEquals("temp_1", resultRule2.getConditions().get(0).getResult().get().toString()); } @@ -143,7 +140,6 @@ void testErrorRuleHandling() { EndpointRuleSet result = SsaTransform.transform(original); - // Should handle error rules without issues assertEquals(1, result.getRules().size()); assertInstanceOf(ErrorRule.class, result.getRules().get(0)); } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java new file mode 100644 index 00000000000..204483ef615 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java @@ -0,0 +1,327 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Map; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.utils.ListUtils; + +class VariableAnalysisTest { + + @Test + void testSimpleVariableBinding() { + // Rule with one variable binding + Condition condition = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(condition) + .endpoint(TestHelpers.endpoint("https://example.com")); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build()) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertTrue(analysis.getInputParams().contains("Region")); + assertTrue(analysis.hasSingleBinding("hasRegion")); + assertEquals(0, analysis.getReferenceCount("hasRegion")); + } + + @Test + void testVariableReference() { + // Define and use a variable + Condition define = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition use = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(define, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(1, analysis.getReferenceCount("hasRegion")); + assertTrue(analysis.isReferencedOnce("hasRegion")); + + assertTrue(analysis.isSafeToInline("hasRegion")); + } + + @Test + void testMultipleBindings() { + // Same variable assigned in different branches + Rule rule1 = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("x")) + .build()) + .endpoint(TestHelpers.endpoint("https://example1.com")); + + Rule rule2 = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.isSet("Bucket")) + .result(Identifier.of("x")) + .build()) + .endpoint(TestHelpers.endpoint("https://example2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertFalse(analysis.hasSingleBinding("x")); + + // Not safe to inline when multiple bindings exist + assertFalse(analysis.isSafeToInline("x")); + + // Should have different SSA names for different expressions + Map> mappings = analysis.getExpressionMappings(); + assertNotNull(mappings.get("x")); + assertEquals(2, mappings.get("x").size()); + } + + @Test + void testMultipleReferences() { + // Variable referenced multiple times + Condition define = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition use1 = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .build(); + + Condition use2 = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(false))) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(define, use1, use2) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(2, analysis.getReferenceCount("hasRegion")); + assertFalse(analysis.isReferencedOnce("hasRegion")); + + assertFalse(analysis.isSafeToInline("hasRegion")); + } + + @Test + void testReferencesInEndpoint() { + // Variable used in endpoint URL - just use the Region parameter directly + Condition checkRegion = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .build(); + + Endpoint endpoint = Endpoint.builder() + .url(Literal.stringLiteral(Template.fromString("https://{Region}.example.com"))) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkRegion) + .endpoint(endpoint); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(2, analysis.getReferenceCount("Region")); + } + + @Test + void testReferencesInErrorRule() { + // First prove Region is set, then check if it's invalid + Condition checkRegion = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition checkInvalid = Condition.builder() + .fn(StringEquals.ofExpressions( + Expression.getReference(Identifier.of("Region")), + Literal.of("invalid"))) + .result(Identifier.of("isInvalid")) + .build(); + + // Use the Region value directly in the error message + Rule rule = ErrorRule.builder() + .conditions(checkRegion, checkInvalid) + .error(Literal.stringLiteral(Template.fromString("Invalid region: {Region}"))); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(3, analysis.getReferenceCount("Region")); + assertTrue(analysis.getReferenceCount("hasRegion") >= 0); + assertEquals(0, analysis.getReferenceCount("isInvalid")); + } + + @Test + void testNestedTreeRuleAnalysis() { + // Nested rules with variable bindings + Condition outerDefine = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition innerUse = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .build(); + + Rule innerRule = EndpointRule.builder() + .conditions(innerUse) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Rule treeRule = TreeRule.builder() + .conditions(outerDefine) + .treeRule(innerRule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(treeRule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertTrue(analysis.hasSingleBinding("hasRegion")); + assertEquals(1, analysis.getReferenceCount("hasRegion")); + assertTrue(analysis.isSafeToInline("hasRegion")); + } + + @Test + void testInputParametersIdentified() { + // Multiple input parameters + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("UseDualStack").type(ParameterType.BOOLEAN).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(3, analysis.getInputParams().size()); + assertTrue(analysis.getInputParams().contains("Region")); + assertTrue(analysis.getInputParams().contains("Bucket")); + assertTrue(analysis.getInputParams().contains("UseDualStack")); + } + + @Test + void testNoVariables() { + // Simple ruleset with no variable bindings + Rule rule = EndpointRule.builder() + .conditions(Condition.builder().fn(TestHelpers.isSet("Region")).build()) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + // No variables bound + assertEquals(0, analysis.getReferenceCount("anyVar")); + assertFalse(analysis.hasSingleBinding("anyVar")); + assertFalse(analysis.isSafeToInline("anyVar")); + + // But Region is an input parameter + assertTrue(analysis.getInputParams().contains("Region")); + } +} From 0dfebc883ac591814174c14c2b37b7151d17c4bb Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Wed, 13 Aug 2025 14:22:38 -0500 Subject: [PATCH 13/23] Improve sifting --- .../logic/bdd/SiftingOptimization.java | 485 +++++++++--------- 1 file changed, 239 insertions(+), 246 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index e375bd8a1e0..7fc06152a1a 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -7,10 +7,13 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.IdentityHashMap; import java.util.List; -import java.util.concurrent.ForkJoinPool; +import java.util.Map; import java.util.function.Function; import java.util.logging.Logger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; @@ -20,28 +23,18 @@ /** * BDD optimization using tiered parallel position evaluation with dependency-aware constraints. * - *

    This algorithm improves BDD size through a multi-stage approach: + *

    The optimization runs in three stages with decreasing granularity: *

      - *
    • Coarse optimization for large BDDs (fast reduction)
    • - *
    • Medium optimization for moderate BDDs (balanced approach)
    • - *
    • Granular optimization for small BDDs (maximum quality)
    • + *
    • Coarse: Fast reduction with large steps
    • + *
    • Medium: Balanced optimization
    • + *
    • Granular: Fine-tuned optimization for maximum reduction
    • *
    - * - *

    Each stage runs until reaching its target size or maximum passes. */ public final class SiftingOptimization implements Function { private static final Logger LOGGER = Logger.getLogger(SiftingOptimization.class.getName()); - // Default thresholds and passes for each optimization level - private static final int DEFAULT_COARSE_MIN_NODES = 50_000; - private static final int DEFAULT_COARSE_MAX_PASSES = 5; - private static final int DEFAULT_MEDIUM_MIN_NODES = 10_000; - private static final int DEFAULT_MEDIUM_MAX_PASSES = 5; - private static final int DEFAULT_GRANULAR_MAX_NODES = 10_000; - private static final int DEFAULT_GRANULAR_MAX_PASSES = 8; - - // When a variable has fewer than this many valid positions, try them all. - private static final int EXHAUSTIVE_THRESHOLD = 20; + // When to use a parallel stream + private static final int PARALLEL_THRESHOLD = 7; // Thread-local BDD builders to avoid allocation overhead private final ThreadLocal threadBuilder = ThreadLocal.withInitial(BddBuilder::new); @@ -59,18 +52,31 @@ public final class SiftingOptimization implements Function { // Internal effort levels for the tiered optimization stages. private enum OptimizationEffort { - COARSE(12, 4, 0), - MEDIUM(2, 18, 5), - GRANULAR(1, 48, 10); + COARSE(11, 4, 0, 20, 4_000, 6), + MEDIUM(2, 20, 6, 20, 1_000, 6), + GRANULAR(1, 50, 12, 20, 8_000, 12); final int sampleRate; final int maxPositions; final int nearbyRadius; - - OptimizationEffort(int sampleRate, int maxPositions, int nearbyRadius) { + final int exhaustiveThreshold; + final int defaultNodeThreshold; + final int defaultMaxPasses; + + OptimizationEffort( + int sampleRate, + int maxPositions, + int nearbyRadius, + int exhaustiveThreshold, + int defaultNodeThreshold, + int defaultMaxPasses + ) { this.sampleRate = sampleRate; this.maxPositions = maxPositions; this.nearbyRadius = nearbyRadius; + this.exhaustiveThreshold = exhaustiveThreshold; + this.defaultNodeThreshold = defaultNodeThreshold; + this.defaultMaxPasses = defaultMaxPasses; } } @@ -85,11 +91,6 @@ private SiftingOptimization(Builder builder) { this.dependencyGraph = new ConditionDependencyGraph(Arrays.asList(cfg.getConditions())); } - /** - * Creates a new builder for SiftingOptimization. - * - * @return a new builder instance - */ public static Builder builder() { return new Builder(); } @@ -106,20 +107,20 @@ public BddTrait apply(BddTrait trait) { private BddTrait doApply(BddTrait trait) { LOGGER.info("Starting BDD sifting optimization"); long startTime = System.currentTimeMillis(); - - // Pre-spin the ForkJoinPool for better first-pass performance - ForkJoinPool.commonPool().submit(() -> {}).join(); - OptimizationState state = initializeOptimization(trait); LOGGER.info(String.format("Initial size: %d nodes", state.initialSize)); - state = runCoarseStage(state); - state = runMediumStage(state); - state = runGranularStage(state); + state = runOptimizationStage("Coarse", state, OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); + state = runOptimizationStage("Medium", state, OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); + if (state.currentSize <= granularMaxNodes) { + state = runOptimizationStage("Granular", state, OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); + } else { + LOGGER.info("Skipping granular stage - too large"); + } + state = runAdjacentSwaps(state); double totalTimeInSeconds = (System.currentTimeMillis() - startTime) / 1000.0; - // Only rebuild if we found an improvement if (state.bestSize >= state.initialSize) { LOGGER.info(String.format("No improvements found in %fs", totalTimeInSeconds)); return trait; @@ -131,77 +132,30 @@ private BddTrait doApply(BddTrait trait) { (1.0 - (double) state.bestSize / state.initialSize) * 100, totalTimeInSeconds)); - // Rebuild the BddTrait with the optimized ordering and BDD - return trait.toBuilder() - .conditions(state.orderView) - .results(state.results) - .bdd(state.bestBdd) - .build(); + return trait.toBuilder().conditions(state.orderView).results(state.results).bdd(state.bestBdd).build(); } private OptimizationState initializeOptimization(BddTrait trait) { // Use the trait's existing ordering as the starting point List initialOrder = new ArrayList<>(trait.getConditions()); - Condition[] order = initialOrder.toArray(new Condition[0]); List orderView = Arrays.asList(order); - - // Get the initial size from the input BDD Bdd bdd = trait.getBdd(); - int initialSize = bdd.getNodeCount() - 1; // -1 for terminal - - // No need to recompile just for results—use the trait's own results - return new OptimizationState(order, - orderView, - bdd, - initialSize, - initialSize, - trait.getResults()); - } - - private OptimizationState runCoarseStage(OptimizationState state) { - if (state.currentSize <= coarseMinNodes) { - return state; - } - return runOptimizationStage(state, "Coarse", OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); - } - - private OptimizationState runMediumStage(OptimizationState state) { - if (state.currentSize <= mediumMinNodes) { - return state; - } - return runOptimizationStage(state, "Medium", OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); - } - - private OptimizationState runGranularStage(OptimizationState state) { - if (state.currentSize > granularMaxNodes) { - LOGGER.info(String.format("Skipping granular stage - BDD too large (%d nodes > %d threshold)", - state.currentSize, - granularMaxNodes)); - return state; - } - - // Run with no minimums - state = runOptimizationStage(state, "Granular", OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); - - // Also perform adjacent swaps in granular stage - OptimizationResult swapResult = performAdjacentSwaps(state.order, state.orderView, state.currentSize); - if (swapResult.improved) { - LOGGER.info(String.format("Adjacent swaps: %d -> %d nodes", state.currentSize, swapResult.size)); - return state.withResult(swapResult.bdd, swapResult.size, swapResult.results); - } - - return state; + int initialSize = bdd.getNodeCount() - 1; + return new OptimizationState(order, orderView, bdd, initialSize, initialSize, trait.getResults()); } private OptimizationState runOptimizationStage( - OptimizationState state, String stageName, + OptimizationState state, OptimizationEffort effort, int targetNodeCount, int maxPasses, double minReductionPercent ) { + if (targetNodeCount > 0 && state.currentSize <= targetNodeCount) { + return state; + } LOGGER.info(String.format("Stage: %s optimization (%d nodes%s)", stageName, @@ -209,24 +163,14 @@ private OptimizationState runOptimizationStage( targetNodeCount > 0 ? String.format(", target < %d", targetNodeCount) : "")); OptimizationState currentState = state; - for (int pass = 1; pass <= maxPasses; pass++) { - // Stop if we've reached the target if (targetNodeCount > 0 && currentState.currentSize <= targetNodeCount) { break; } int passStartSize = currentState.currentSize; - OptimizationResult result = runOptimizationPass( - currentState.order, - currentState.orderView, - currentState.currentSize, - effort); - - if (!result.improved) { - LOGGER.fine(String.format("%s pass %d found no improvements", stageName, pass)); - break; - } else { + OptimizationResult result = runPass(currentState, effort); + if (result.improved) { currentState = currentState.withResult(result.bdd, result.size, result.results); double reduction = (1.0 - (double) result.size / passStartSize) * 100; LOGGER.fine(String.format("%s pass %d: %d -> %d nodes (%.1f%% reduction)", @@ -235,112 +179,112 @@ private OptimizationState runOptimizationStage( passStartSize, result.size, reduction)); - // Check for diminishing returns if (minReductionPercent > 0 && reduction < minReductionPercent) { LOGGER.fine(String.format("%s optimization yielding diminishing returns", stageName)); break; } + } else { + LOGGER.fine(String.format("%s pass %d found no improvements", stageName, pass)); + break; } } return currentState; } - private OptimizationResult runOptimizationPass( - Condition[] order, - List orderView, - int currentSize, - OptimizationEffort effort - ) { + private OptimizationState runAdjacentSwaps(OptimizationState state) { + if (state.currentSize > granularMaxNodes) { + return state; + } - int improvements = 0; - ConditionDependencyGraph.OrderConstraints constraints = dependencyGraph.createOrderConstraints(orderView); - Bdd bestBdd = null; - int bestSize = currentSize; - List bestResults = null; + LOGGER.info("Running adjacent swaps optimization"); + OptimizationState currentState = state; - // Sample variables based on effort level - for (int varIdx = 0; varIdx < order.length; varIdx += effort.sampleRate) { - List positions = getPositions(varIdx, constraints, effort); - if (positions.isEmpty()) { - continue; - } else if (positions.size() > effort.maxPositions) { - positions = positions.subList(0, effort.maxPositions); + // Run multiple sweeps until no improvement + for (int sweep = 1; sweep <= 3; sweep++) { + OptimizationContext context = new OptimizationContext(currentState, dependencyGraph); + int startSize = currentState.currentSize; + + for (int i = 0; i < currentState.order.length - 1; i++) { + if (context.constraints.canMove(i, i + 1)) { + move(currentState.order, i, i + 1); + BddCompilationResult compilationResult = compileBddWithResults(currentState.orderView); + int swappedSize = compilationResult.bdd.getNodeCount() - 1; + if (swappedSize < context.bestSize) { + context = context.withImprovement( + new PositionResult(i + 1, + swappedSize, + compilationResult.bdd, + compilationResult.results)); + } else { + move(currentState.order, i + 1, i); // Swap back + } + } } - // Find best position - PositionCount best = findBestPosition(positions, order, bestSize, varIdx); - - if (best == null || best.count >= bestSize) { - continue; + if (context.improvements > 0) { + currentState = currentState.withResult(context.bestBdd, context.bestSize, context.bestResults); + LOGGER.fine(String.format("Adjacent swaps sweep %d: %d -> %d nodes", + sweep, + startSize, + context.bestSize)); + } else { + break; } + } + + return currentState; + } - // Move to best position and build BDD once - move(order, varIdx, best.position); - BddCompilationResult compilationResult = compileBddWithResults(orderView); - Bdd newBdd = compilationResult.bdd; - int newSize = newBdd.getNodeCount() - 1; + private OptimizationResult runPass(OptimizationState state, OptimizationEffort effort) { + OptimizationContext context = new OptimizationContext(state, dependencyGraph); - if (newSize < bestSize) { - bestBdd = newBdd; - bestSize = newSize; - bestResults = compilationResult.results; - improvements++; + List selectedConditions = IntStream.range(0, state.orderView.size()) + .filter(i -> i % effort.sampleRate == 0) + .mapToObj(state.orderView::get) + .collect(Collectors.toList()); - // Update constraints after successful move - constraints = dependencyGraph.createOrderConstraints(orderView); + for (Condition condition : selectedConditions) { + Integer varIdx = context.liveIndex.get(condition); + if (varIdx == null) { + continue; + } + + List positions = getStrategicPositions(varIdx, context.constraints, effort); + if (positions.isEmpty()) { + continue; } + + context = tryImprovePosition(context, varIdx, positions); } - return new OptimizationResult(bestBdd, bestSize, improvements > 0, bestResults); + return context.toResult(); } - private OptimizationResult performAdjacentSwaps(Condition[] order, List orderView, int currentSize) { - ConditionDependencyGraph.OrderConstraints constraints = dependencyGraph.createOrderConstraints(orderView); - Bdd bestBdd = null; - int bestSize = currentSize; - List bestResults = null; - boolean improved = false; - - for (int i = 0; i < order.length - 1; i++) { - if (constraints.canMove(i, i + 1)) { - move(order, i, i + 1); - int swappedSize = countNodes(orderView); - if (swappedSize < bestSize) { - BddCompilationResult compilationResult = compileBddWithResults(orderView); - bestBdd = compilationResult.bdd; - bestSize = swappedSize; - bestResults = compilationResult.results; - improved = true; - } else { - // Swap back if no improvement - move(order, i + 1, i); - } - } + private OptimizationContext tryImprovePosition(OptimizationContext context, int varIdx, List positions) { + PositionResult best = findBestPosition(positions, context, varIdx); + if (best != null && best.count <= context.bestSize) { // Accept ties + move(context.order, varIdx, best.position); + return context.withImprovement(best); } - return new OptimizationResult(bestBdd, bestSize, improved, bestResults); + return context; } - private PositionCount findBestPosition( - List positions, - final Condition[] currentOrder, - final int currentSize, - final int varIdx - ) { - return positions.parallelStream() + private PositionResult findBestPosition(List positions, OptimizationContext ctx, int varIdx) { + return (positions.size() > PARALLEL_THRESHOLD ? positions.parallelStream() : positions.stream()) .map(pos -> { - Condition[] threadOrder = currentOrder.clone(); - move(threadOrder, varIdx, pos); - int nodeCount = countNodes(Arrays.asList(threadOrder)); - return new PositionCount(pos, nodeCount); + Condition[] order = ctx.order.clone(); + move(order, varIdx, pos); + BddCompilationResult cr = compileBddWithResults(Arrays.asList(order)); + return new PositionResult(pos, cr.bdd.getNodeCount() - 1, cr.bdd, cr.results); }) - .filter(pc -> pc.count < currentSize) - .min(Comparator.comparingInt(pc -> pc.count)) + .filter(pr -> pr.count <= ctx.bestSize) + .min(Comparator.comparingInt((PositionResult pr) -> pr.count).thenComparingInt(pr -> pr.position)) .orElse(null); } - private List getPositions( + private static List getStrategicPositions( int varIdx, ConditionDependencyGraph.OrderConstraints constraints, OptimizationEffort effort @@ -348,73 +292,64 @@ private List getPositions( int min = constraints.getMinValidPosition(varIdx); int max = constraints.getMaxValidPosition(varIdx); int range = max - min; - return range <= EXHAUSTIVE_THRESHOLD - ? getExhaustivePositions(varIdx, min, max, constraints) - : getStrategicPositions(varIdx, min, max, range, constraints, effort); - } - private List getExhaustivePositions( - int varIdx, - int min, - int max, - ConditionDependencyGraph.OrderConstraints constraints - ) { - List positions = new ArrayList<>(max - min); - for (int p = min; p < max; p++) { - if (p != varIdx && constraints.canMove(varIdx, p)) { - positions.add(p); + if (range <= effort.exhaustiveThreshold) { + List positions = new ArrayList<>(range); + for (int p = min; p < max; p++) { + if (p != varIdx && constraints.canMove(varIdx, p)) { + positions.add(p); + } } + return positions; } - return positions; - } - private List getStrategicPositions( - int varIdx, - int min, - int max, - int range, - ConditionDependencyGraph.OrderConstraints constraints, - OptimizationEffort effort - ) { List positions = new ArrayList<>(effort.maxPositions); - // Boundaries (these are most likely to be optimal) + // Test extremes first since they often yield the best improvements if (min != varIdx && constraints.canMove(varIdx, min)) { positions.add(min); } + if (positions.size() >= effort.maxPositions) { + return positions; + } + if (max - 1 != varIdx && constraints.canMove(varIdx, max - 1)) { positions.add(max - 1); } + if (positions.size() >= effort.maxPositions) { + return positions; + } - // Nearby positions (only if effort includes nearbyRadius) - if (effort.nearbyRadius > 0) { - for (int offset = -effort.nearbyRadius; offset <= effort.nearbyRadius; offset++) { - if (offset != 0) { - int p = varIdx + offset; - if (p >= min && p < max && !positions.contains(p) && constraints.canMove(varIdx, p)) { - positions.add(p); - } + // Test local moves that preserve relative ordering with neighbors + for (int offset = -effort.nearbyRadius; offset <= effort.nearbyRadius; offset++) { + if (offset != 0) { + if (positions.size() >= effort.maxPositions) { + return positions; + } + int p = varIdx + offset; + if (p >= min && p < max && !positions.contains(p) && constraints.canMove(varIdx, p)) { + positions.add(p); } } } - // Adaptive sampling: fewer samples for smaller ranges + // Sample intermediate positions to find global improvements + if (positions.size() >= effort.maxPositions) { + return positions; + } + int maxSamples = Math.min(15, effort.maxPositions / 2); int samples = Math.min(maxSamples, Math.max(2, range / 4)); int step = Math.max(1, range / samples); - for (int p = min + step; p < max - step; p += step) { + for (int p = min + step; p < max - step && positions.size() < effort.maxPositions; p += step) { if (p != varIdx && !positions.contains(p) && constraints.canMove(varIdx, p)) { positions.add(p); } } - return positions; } - /** - * Moves an element in an array from one position to another. - */ private static void move(Condition[] arr, int from, int to) { if (from == to) { return; @@ -422,18 +357,21 @@ private static void move(Condition[] arr, int from, int to) { Condition moving = arr[from]; if (from < to) { - // Moving right: shift elements left System.arraycopy(arr, from + 1, arr, from, to - from); } else { - // Moving left: shift elements right System.arraycopy(arr, to, arr, to + 1, from - to); } arr[to] = moving; } - /** - * Compiles a BDD with the given condition ordering and returns both BDD and results. - */ + private static Map rebuildIndex(List orderView) { + Map index = new IdentityHashMap<>(); + for (int i = 0; i < orderView.size(); i++) { + index.put(orderView.get(i), i); + } + return index; + } + private BddCompilationResult compileBddWithResults(List ordering) { BddBuilder builder = threadBuilder.get().reset(); BddCompiler compiler = new BddCompiler(cfg, OrderingStrategy.fixed(ordering), builder); @@ -441,15 +379,72 @@ private BddCompilationResult compileBddWithResults(List ordering) { return new BddCompilationResult(bdd, compiler.getIndexedResults()); } - /** - * Counts nodes for a given ordering without keeping the BDD. - */ - private int countNodes(List ordering) { - Bdd bdd = compileBddWithResults(ordering).bdd; - return bdd.getNodeCount() - 1; // -1 for terminal + // Helper class to track optimization context within a pass + private static final class OptimizationContext { + final Condition[] order; + final List orderView; + final ConditionDependencyGraph dependencyGraph; + final ConditionDependencyGraph.OrderConstraints constraints; + final Map liveIndex; + final Bdd bestBdd; + final int bestSize; + final List bestResults; + final int improvements; + + OptimizationContext(OptimizationState state, ConditionDependencyGraph dependencyGraph) { + this.order = state.order; + this.orderView = state.orderView; + this.dependencyGraph = dependencyGraph; + this.constraints = dependencyGraph.createOrderConstraints(orderView); + this.liveIndex = rebuildIndex(orderView); + this.bestBdd = null; + this.bestSize = state.currentSize; + this.bestResults = null; + this.improvements = 0; + } + + private OptimizationContext( + Condition[] order, + List orderView, + ConditionDependencyGraph dependencyGraph, + ConditionDependencyGraph.OrderConstraints constraints, + Map liveIndex, + Bdd bestBdd, + int bestSize, + List bestResults, + int improvements + ) { + this.order = order; + this.orderView = orderView; + this.dependencyGraph = dependencyGraph; + this.constraints = constraints; + this.liveIndex = liveIndex; + this.bestBdd = bestBdd; + this.bestSize = bestSize; + this.bestResults = bestResults; + this.improvements = improvements; + } + + OptimizationContext withImprovement(PositionResult result) { + ConditionDependencyGraph.OrderConstraints newConstraints = + dependencyGraph.createOrderConstraints(orderView); + Map newIndex = rebuildIndex(orderView); + return new OptimizationContext(order, + orderView, + dependencyGraph, + newConstraints, + newIndex, + result.bdd, + result.count, + result.results, + improvements + 1); + } + + OptimizationResult toResult() { + return new OptimizationResult(bestBdd, bestSize, improvements > 0, bestResults); + } } - // Container for BDD compilation results private static final class BddCompilationResult { final Bdd bdd; final List results; @@ -460,18 +455,20 @@ private static final class BddCompilationResult { } } - // Position and its node count - private static final class PositionCount { + private static final class PositionResult { final int position; final int count; + final Bdd bdd; + final List results; - PositionCount(int position, int count) { + PositionResult(int position, int count, Bdd bdd, List results) { this.position = position; this.count = count; + this.bdd = bdd; + this.results = results; } } - // Result of an optimization pass private static final class OptimizationResult { final Bdd bdd; final int size; @@ -486,7 +483,6 @@ private static final class OptimizationResult { } } - // State tracking during optimization private static final class OptimizationState { final Condition[] order; final List orderView; @@ -518,17 +514,14 @@ OptimizationState withResult(Bdd newBdd, int newSize, List newResults) { } } - /** - * Builder for SiftingOptimization. - */ public static final class Builder implements SmithyBuilder { private Cfg cfg; - private int coarseMinNodes = DEFAULT_COARSE_MIN_NODES; - private int coarseMaxPasses = DEFAULT_COARSE_MAX_PASSES; - private int mediumMinNodes = DEFAULT_MEDIUM_MIN_NODES; - private int mediumMaxPasses = DEFAULT_MEDIUM_MAX_PASSES; - private int granularMaxNodes = DEFAULT_GRANULAR_MAX_NODES; - private int granularMaxPasses = DEFAULT_GRANULAR_MAX_PASSES; + private int coarseMinNodes = OptimizationEffort.COARSE.defaultNodeThreshold; + private int coarseMaxPasses = OptimizationEffort.COARSE.defaultMaxPasses; + private int mediumMinNodes = OptimizationEffort.MEDIUM.defaultNodeThreshold; + private int mediumMaxPasses = OptimizationEffort.MEDIUM.defaultMaxPasses; + private int granularMaxNodes = OptimizationEffort.GRANULAR.defaultNodeThreshold; + private int granularMaxPasses = OptimizationEffort.GRANULAR.defaultMaxPasses; private Builder() {} @@ -549,8 +542,8 @@ public Builder cfg(Cfg cfg) { *

    Coarse optimization runs until the BDD has fewer than minNodeCount nodes * or maxPasses have been completed. * - * @param minNodeCount the target size to stop coarse optimization (default: 50,000) - * @param maxPasses the maximum number of coarse passes (default: 3) + * @param minNodeCount the target size to stop coarse optimization (default: 4,000) + * @param maxPasses the maximum number of coarse passes (default: 6) * @return this builder */ public Builder coarseEffort(int minNodeCount, int maxPasses) { @@ -565,8 +558,8 @@ public Builder coarseEffort(int minNodeCount, int maxPasses) { *

    Medium optimization runs until the BDD has fewer than minNodeCount nodes * or maxPasses have been completed. * - * @param minNodeCount the target size to stop medium optimization (default: 10,000) - * @param maxPasses the maximum number of medium passes (default: 4) + * @param minNodeCount the target size to stop medium optimization (default: 1,000) + * @param maxPasses the maximum number of medium passes (default: 6) * @return this builder */ public Builder mediumEffort(int minNodeCount, int maxPasses) { @@ -581,8 +574,8 @@ public Builder mediumEffort(int minNodeCount, int maxPasses) { *

    Granular optimization only runs if the BDD has fewer than maxNodeCount nodes, * and runs for at most maxPasses. * - * @param maxNodeCount the maximum size to attempt granular optimization (default: 3,000) - * @param maxPasses the maximum number of granular passes (default: 2) + * @param maxNodeCount the maximum size to attempt granular optimization (default: 8,000) + * @param maxPasses the maximum number of granular passes (default: 12) * @return this builder */ public Builder granularEffort(int maxNodeCount, int maxPasses) { From 0df404a9a2587308574c6db8f30a935e440aef1d Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Wed, 13 Aug 2025 14:24:17 -0500 Subject: [PATCH 14/23] Improve transforms Rename CfgGuidedOrdering to InitialOrdering --- .../logic/bdd/BddEquivalenceChecker.java | 2 +- .../rulesengine/logic/bdd/BddTrait.java | 2 +- ...idedOrdering.java => InitialOrdering.java} | 7 +- .../logic/bdd/OrderingStrategy.java | 2 +- .../smithy/rulesengine/logic/cfg/Cfg.java | 13 +- .../logic/cfg/CoalesceTransform.java | 116 ++-- .../rulesengine/logic/cfg/SsaTransform.java | 85 +-- ...ferenceRewriter.java => TreeRewriter.java} | 158 ++++- .../logic/cfg/VariableAnalysis.java | 101 ++- .../cfg/VariableConsolidationTransform.java | 285 +++++++++ .../smithy/rulesengine/logic/TestHelpers.java | 74 ++- ...ringTest.java => InitialOrderingTest.java} | 16 +- .../rulesengine/logic/cfg/CfgBuilderTest.java | 2 +- .../smithy/rulesengine/logic/cfg/CfgTest.java | 7 +- .../logic/cfg/CoalesceTransformTest.java | 575 ++++++++++++++++++ .../logic/cfg/ReferenceRewriterTest.java | 18 +- .../logic/cfg/VariableAnalysisTest.java | 142 +++++ 17 files changed, 1347 insertions(+), 258 deletions(-) rename smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/{CfgGuidedOrdering.java => InitialOrdering.java} (96%) rename smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/{ReferenceRewriter.java => TreeRewriter.java} (53%) create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java rename smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/{CfgGuidedOrderingTest.java => InitialOrderingTest.java} (96%) create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransformTest.java diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java index e59e9a47573..eb60b7abba1 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddEquivalenceChecker.java @@ -60,7 +60,7 @@ private BddEquivalenceChecker(Cfg cfg, Bdd bdd, List conditions, List this.bdd = bdd; this.conditions = conditions; this.results = results; - this.parameters = new ArrayList<>(cfg.getRuleSet().getParameters().toList()); + this.parameters = new ArrayList<>(cfg.getParameters().toList()); for (int i = 0; i < conditions.size(); i++) { conditionToIndex.put(conditions.get(i), i); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java index feecc4d38fe..09315363fbd 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java @@ -73,7 +73,7 @@ public static BddTrait from(Cfg cfg) { } return builder() - .parameters(cfg.getRuleSet().getParameters()) + .parameters(cfg.getParameters()) .conditions(compiler.getOrderedConditions()) .results(compiler.getIndexedResults()) .bdd(bdd) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrdering.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrdering.java similarity index 96% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrdering.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrdering.java index 009a3b724dc..39c8aabc277 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrdering.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrdering.java @@ -23,15 +23,15 @@ * together in the BDD, enabling better node sharing. This ordering implementation flattens the tree structure while * respecting data dependencies. */ -final class CfgGuidedOrdering implements OrderingStrategy { - private static final Logger LOGGER = Logger.getLogger(CfgGuidedOrdering.class.getName()); +final class InitialOrdering implements OrderingStrategy { + private static final Logger LOGGER = Logger.getLogger(InitialOrdering.class.getName()); /** How many distinct consumers make an isSet() a "gate". */ private static final int GATE_SUCCESSOR_THRESHOLD = 2; private final Cfg cfg; - CfgGuidedOrdering(Cfg cfg) { + InitialOrdering(Cfg cfg) { this.cfg = cfg; } @@ -51,6 +51,7 @@ public List orderConditions(Condition[] conditions) { long elapsed = System.currentTimeMillis() - startTime; LOGGER.info(() -> String.format("Initial ordering: %d conditions in %dms", conditions.length, elapsed)); + result.forEach(System.out::println); return result; } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderingStrategy.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderingStrategy.java index 46220e01b13..3378c319427 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderingStrategy.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/OrderingStrategy.java @@ -28,7 +28,7 @@ interface OrderingStrategy { * @return the initial ordering strategy. */ static OrderingStrategy initialOrdering(Cfg cfg) { - return new CfgGuidedOrdering(cfg); + return new InitialOrdering(cfg); } /** diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java index 14d0af25c2b..b31869bf29e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java @@ -17,6 +17,7 @@ import java.util.NoSuchElementException; import java.util.Set; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; @@ -38,7 +39,7 @@ */ public final class Cfg implements Iterable { - private final EndpointRuleSet ruleSet; + private final Parameters parameters; private final CfgNode root; // Lazily computed condition data @@ -46,8 +47,12 @@ public final class Cfg implements Iterable { private Map conditionToIndex; Cfg(EndpointRuleSet ruleSet, CfgNode root) { - this.ruleSet = ruleSet; + this(ruleSet == null ? Parameters.builder().build() : ruleSet.getParameters(), root); + } + + Cfg(Parameters parameters, CfgNode root) { this.root = SmithyBuilder.requiredState("root", root); + this.parameters = parameters; } /** @@ -125,8 +130,8 @@ private synchronized void extractConditions() { this.conditionToIndex = indexMap; } - public EndpointRuleSet getRuleSet() { - return ruleSet; + public Parameters getParameters() { + return parameters; } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java index 2d9d29af337..365ebe3e77c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransform.java @@ -45,21 +45,17 @@ static EndpointRuleSet transform(EndpointRuleSet ruleSet) { CoalesceTransform transform = new CoalesceTransform(); List transformedRules = new ArrayList<>(); - for (Rule rule : ruleSet.getRules()) { - transformedRules.add(transform.transformRule(rule)); + for (int i = 0; i < ruleSet.getRules().size(); i++) { + transformedRules.add(transform.transformRule(ruleSet.getRules().get(i), "root/rule[" + i + "]")); } if (LOGGER.isLoggable(Level.INFO)) { - StringBuilder msg = new StringBuilder(); - msg.append("\n=== Coalescing Transform Complete ===\n"); - msg.append("Total: ").append(transform.coalesceCount).append(" coalesced, "); - msg.append(transform.cacheHits).append(" cache hits, "); - msg.append(transform.skippedNoZeroValue).append(" skipped (no zero value), "); - msg.append(transform.skippedMultipleUses).append(" skipped (multiple uses)"); - if (!transform.skippedRecordTypes.isEmpty()) { - msg.append("\nSkipped record-returning functions: ").append(transform.skippedRecordTypes); - } - LOGGER.info(msg.toString()); + LOGGER.info(String.format( + "Coalescing: %d coalesced, %d cache hits, %d skipped (no zero), %d skipped (multiple uses)", + transform.coalesceCount, + transform.cacheHits, + transform.skippedNoZeroValue, + transform.skippedMultipleUses)); } return EndpointRuleSet.builder() @@ -69,50 +65,31 @@ static EndpointRuleSet transform(EndpointRuleSet ruleSet) { .build(); } - private Rule transformRule(Rule rule) { - Set eliminatedConditions = new HashSet<>(); - List conditions = rule.getConditions(); - Map localVarUsage = countLocalVariableUsage(conditions); - List transformedConditions = transformConditions(conditions, eliminatedConditions, localVarUsage); - - if (rule instanceof TreeRule) { - TreeRule treeRule = (TreeRule) rule; - List transformedNestedRules = new ArrayList<>(); - boolean nestedChanged = false; - - for (Rule nestedRule : treeRule.getRules()) { - Rule transformedNested = transformRule(nestedRule); - transformedNestedRules.add(transformedNested); - if (transformedNested != nestedRule) { - nestedChanged = true; - } - } - - if (!transformedConditions.equals(conditions) || nestedChanged) { - return TreeRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(transformedConditions) - .treeRule(transformedNestedRules); + private Rule transformRule(Rule rule, String rulePath) { + // Count local usage for THIS rule's conditions + Map localVarUsage = new HashMap<>(); + for (Condition condition : rule.getConditions()) { + for (String ref : condition.getFunction().getReferences()) { + localVarUsage.merge(ref, 1, Integer::sum); } - } else if (!transformedConditions.equals(conditions)) { - // For other rule types, just update conditions - return rule.withConditions(transformedConditions); } - return rule; - } - - private Map countLocalVariableUsage(List conditions) { - Map usage = new HashMap<>(); + Set eliminatedConditions = new HashSet<>(); + List transformedConditions = transformConditions( + rule.getConditions(), + eliminatedConditions, + localVarUsage); - // Count how many times each variable is used within this specific rule - for (Condition condition : conditions) { - for (String ref : condition.getFunction().getReferences()) { - usage.merge(ref, 1, Integer::sum); - } + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + return TreeRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(transformedConditions) + .treeRule(TreeRewriter.transformNestedRules(treeRule, rulePath, this::transformRule)); } - return usage; + // CoalesceTransform only modifies conditions, not endpoints/errors + return rule.withConditions(transformedConditions); } private List transformConditions( @@ -128,25 +105,19 @@ private List transformConditions( continue; } - // Check if this is a bind that can be coalesced with the next condition if (i + 1 < conditions.size() && current.getResult().isPresent()) { String var = current.getResult().get().toString(); Condition next = conditions.get(i + 1); if (canCoalesce(var, current, next, localVarUsage)) { - // Create coalesced condition - Condition coalesced = createCoalescedCondition(current, next, var); - result.add(coalesced); - // Mark both conditions as eliminated + result.add(createCoalescedCondition(current, next, var)); eliminatedConditions.add(current); eliminatedConditions.add(next); - // Skip the next condition - i++; + i++; // Skip next continue; } } - // No coalescing possible, keep the condition as-is result.add(current); } @@ -155,28 +126,21 @@ private List transformConditions( private boolean canCoalesce(String var, Condition bind, Condition use, Map localVarUsage) { if (!use.getFunction().getReferences().contains(var)) { - // The use condition must reference the variable return false; - } else if (use.getFunction().getFunctionDefinition() == IsSet.getDefinition()) { - // Never coalesce into presence checks (isSet) + } + + if (use.getFunction().getFunctionDefinition() == IsSet.getDefinition()) { return false; } - // Check if variable is only used once in this local rule context (even if it appears multiple times globally) Integer localUses = localVarUsage.get(var); if (localUses == null || localUses > 1) { skippedMultipleUses++; return false; } - // Get the actual return type (could be Optional or T) Type type = bind.getFunction().getFunctionDefinition().getReturnType(); - - // Check if we can get a zero value for this type. For OptionalType, we use the inner type's zero value - Type innerType = type; - if (type instanceof OptionalType) { - innerType = ((OptionalType) type).inner(); - } + Type innerType = type instanceof OptionalType ? ((OptionalType) type).inner() : type; if (innerType instanceof RecordType) { skippedNoZeroValue++; @@ -196,16 +160,10 @@ private Condition createCoalescedCondition(Condition bind, Condition use, String LibraryFunction bindExpr = bind.getFunction(); LibraryFunction useExpr = use.getFunction(); - // Get the type and its zero value Type type = bindExpr.getFunctionDefinition().getReturnType(); - Type innerType = type; - if (type instanceof OptionalType) { - innerType = ((OptionalType) type).inner(); - } - + Type innerType = type instanceof OptionalType ? ((OptionalType) type).inner() : type; Literal zero = innerType.getZeroValue().get(); - // Create cache key based on canonical representations String bindCanonical = bindExpr.canonicalize().toString(); String zeroCanonical = zero.toString(); String useCanonical = useExpr.canonicalize().toString(); @@ -219,12 +177,10 @@ private Condition createCoalescedCondition(Condition bind, Condition use, String } Expression coalesced = Coalesce.ofExpressions(bindExpr, zero); - - // Replace the variable reference in the use expression Map replacements = new HashMap<>(); replacements.put(var, coalesced); - ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); - Expression replaced = rewriter.rewrite(useExpr); + + Expression replaced = TreeRewriter.forReplacements(replacements).rewrite(useExpr); LibraryFunction canonicalized = ((LibraryFunction) replaced).canonicalize(); Condition.Builder builder = Condition.builder().fn(canonicalized); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java index 127bfc83bad..da6eda2f02f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java @@ -10,7 +10,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.IdentityHashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -19,7 +18,6 @@ import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; @@ -39,20 +37,17 @@ */ final class SsaTransform { - // Stack of scopes, each mapping original variable names to their current SSA names private final Deque> scopeStack = new ArrayDeque<>(); private final Map rewrittenConditions = new IdentityHashMap<>(); private final Map rewrittenRules = new IdentityHashMap<>(); private final VariableAnalysis variableAnalysis; - private final ReferenceRewriter referenceRewriter; + private final TreeRewriter referenceRewriter; private SsaTransform(VariableAnalysis variableAnalysis) { - // Start with an empty global scope scopeStack.push(new HashMap<>()); this.variableAnalysis = variableAnalysis; - // Create a reference rewriter that uses our scope resolution - this.referenceRewriter = new ReferenceRewriter( + this.referenceRewriter = new TreeRewriter( ref -> { String originalName = ref.getName().toString(); String uniqueName = resolveReference(originalName); @@ -62,13 +57,10 @@ private SsaTransform(VariableAnalysis variableAnalysis) { } static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + ruleSet = VariableConsolidationTransform.transform(ruleSet); ruleSet = CoalesceTransform.transform(ruleSet); + SsaTransform ssaTransform = new SsaTransform(VariableAnalysis.analyze(ruleSet)); - // Use VariableAnalysis to get all the information we need - VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); - - // Rewrite with the pre-computed mappings - SsaTransform ssaTransform = new SsaTransform(analysis); List rewrittenRules = new ArrayList<>(ruleSet.getRules().size()); for (Rule original : ruleSet.getRules()) { rewrittenRules.add(ssaTransform.processRule(original)); @@ -88,12 +80,10 @@ private Rule processRule(Rule rule) { return rewrittenRule; } - // Enters a new scope, inheriting all variable mappings from the parent scope private void enterScope() { scopeStack.push(new HashMap<>(scopeStack.peek())); } - // Exits the current scope, reverting to the parent scope's variable mappings private void exitScope() { if (scopeStack.size() <= 1) { throw new IllegalStateException("Cannot exit global scope"); @@ -101,11 +91,9 @@ private void exitScope() { scopeStack.pop(); } - // Rewrites a condition's bindings and references to use SSA names private Condition rewriteCondition(Condition condition) { boolean hasBinding = condition.getResult().isPresent(); - // Check cache for non-binding conditions if (!hasBinding) { Condition cached = rewrittenConditions.get(condition); if (cached != null) { @@ -116,19 +104,21 @@ private Condition rewriteCondition(Condition condition) { LibraryFunction fn = condition.getFunction(); Set rewritableRefs = filterOutInputParameters(fn.getReferences()); - // Determine if this binding needs an SSA name String uniqueBindingName = null; boolean needsUniqueBinding = false; if (hasBinding) { String varName = condition.getResult().get().toString(); - Map expressionMap = variableAnalysis.getExpressionMappings().get(varName); - if (expressionMap != null) { - uniqueBindingName = expressionMap.get(fn.toString()); - needsUniqueBinding = uniqueBindingName != null && !uniqueBindingName.equals(varName); + + // Only need SSA rename if variable has multiple bindings + if (variableAnalysis.hasMultipleBindings(varName)) { + Map expressionMap = variableAnalysis.getExpressionMappings().get(varName); + if (expressionMap != null) { + uniqueBindingName = expressionMap.get(fn.toString()); + needsUniqueBinding = uniqueBindingName != null && !uniqueBindingName.equals(varName); + } } } - // Early return if no rewriting needed if (!needsRewriting(rewritableRefs) && !needsUniqueBinding) { if (!hasBinding) { rewrittenConditions.put(condition, condition); @@ -136,16 +126,12 @@ private Condition rewriteCondition(Condition condition) { return condition; } - // Rewrite the expression LibraryFunction rewrittenExpr = (LibraryFunction) referenceRewriter.rewrite(fn); boolean exprChanged = rewrittenExpr != fn; - // Build the rewritten condition Condition rewritten; if (hasBinding && uniqueBindingName != null) { - // Update scope with the SSA name scopeStack.peek().put(condition.getResult().get().toString(), uniqueBindingName); - if (needsUniqueBinding || exprChanged) { rewritten = condition.toBuilder().fn(rewrittenExpr).result(Identifier.of(uniqueBindingName)).build(); } else { @@ -157,7 +143,6 @@ private Condition rewriteCondition(Condition condition) { rewritten = condition; } - // Cache non-binding conditions if (!hasBinding) { rewrittenConditions.put(condition, rewritten); } @@ -169,12 +154,12 @@ private Set filterOutInputParameters(Set references) { if (references.isEmpty() || variableAnalysis.getInputParams().isEmpty()) { return references; } + Set filtered = new HashSet<>(references); filtered.removeAll(variableAnalysis.getInputParams()); return filtered; } - // Check if any references in scope need to be rewritten to SSA names private boolean needsRewriting(Set references) { if (references.isEmpty()) { return false; @@ -194,7 +179,6 @@ private boolean needsRewriting(Expression expression) { return needsRewriting(filterOutInputParameters(expression.getReferences())); } - // Rewrites a rule's conditions to use SSA names private Rule rewriteRule(Rule rule) { Rule cached = rewrittenRules.get(rule); if (cached != null) { @@ -234,26 +218,13 @@ private Rule rewriteEndpointRule( List rewrittenConditions, boolean conditionsChanged ) { - Endpoint endpoint = rule.getEndpoint(); + Endpoint rewrittenEndpoint = referenceRewriter.rewriteEndpoint(rule.getEndpoint()); - // Rewrite endpoint components to use SSA names - Expression rewrittenUrl = referenceRewriter.rewrite(endpoint.getUrl()); - Map> rewrittenHeaders = rewriteHeaders(endpoint.getHeaders()); - Map rewrittenProperties = rewriteProperties(endpoint.getProperties()); - - boolean endpointChanged = rewrittenUrl != endpoint.getUrl() - || !rewrittenHeaders.equals(endpoint.getHeaders()) - || !rewrittenProperties.equals(endpoint.getProperties()); - - if (conditionsChanged || endpointChanged) { + if (conditionsChanged || rewrittenEndpoint != rule.getEndpoint()) { return EndpointRule.builder() .description(rule.getDocumentation().orElse(null)) .conditions(rewrittenConditions) - .endpoint(Endpoint.builder() - .url(rewrittenUrl) - .headers(rewrittenHeaders) - .properties(rewrittenProperties) - .build()); + .endpoint(rewrittenEndpoint); } return rule; @@ -296,30 +267,6 @@ private Rule rewriteTreeRule(TreeRule rule, List rewrittenConditions, return rule; } - private Map> rewriteHeaders(Map> headers) { - Map> rewritten = new LinkedHashMap<>(); - for (Map.Entry> entry : headers.entrySet()) { - List rewrittenValues = new ArrayList<>(); - for (Expression expr : entry.getValue()) { - rewrittenValues.add(referenceRewriter.rewrite(expr)); - } - rewritten.put(entry.getKey(), rewrittenValues); - } - return rewritten; - } - - private Map rewriteProperties(Map properties) { - Map rewritten = new LinkedHashMap<>(); - for (Map.Entry entry : properties.entrySet()) { - Expression rewrittenExpr = referenceRewriter.rewrite(entry.getValue()); - if (!(rewrittenExpr instanceof Literal)) { - throw new IllegalStateException("Property value must be a literal"); - } - rewritten.put(entry.getKey(), (Literal) rewrittenExpr); - } - return rewritten; - } - private String resolveReference(String originalName) { // Input parameters are never rewritten return variableAnalysis.getInputParams().contains(originalName) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriter.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java similarity index 53% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriter.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java index 6dda29af584..7690a1cffd9 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriter.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/TreeRewriter.java @@ -8,9 +8,11 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; @@ -21,11 +23,15 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; /** * Utility for rewriting references within expression trees. */ -final class ReferenceRewriter { +final class TreeRewriter { + // A no-op rewriter that returns expressions unchanged. + static final TreeRewriter IDENTITY = new TreeRewriter(ref -> ref, expr -> false); private final Function referenceTransformer; private final Predicate shouldRewrite; @@ -36,7 +42,7 @@ final class ReferenceRewriter { * @param referenceTransformer function to transform references * @param shouldRewrite predicate to determine if an expression needs rewriting */ - ReferenceRewriter( + TreeRewriter( Function referenceTransformer, Predicate shouldRewrite ) { @@ -44,13 +50,45 @@ final class ReferenceRewriter { this.shouldRewrite = shouldRewrite; } + /** + * Creates a simple rewriter that replaces specific references. + * + * @param replacements map of variable names to replacement expressions + * @return a reference rewriter that performs the replacements + */ + static TreeRewriter forReplacements(Map replacements) { + if (replacements.isEmpty()) { + return IDENTITY; + } + return new TreeRewriter( + ref -> replacements.getOrDefault(ref.getName().toString(), ref), + expr -> expr.getReferences().stream().anyMatch(replacements::containsKey)); + } + + static List transformNestedRules( + TreeRule tree, + String parentPath, + BiFunction transformer + ) { + List result = new ArrayList<>(); + for (int i = 0; i < tree.getRules().size(); i++) { + Rule transformed = transformer.apply( + tree.getRules().get(i), + parentPath + "/tree/rule[" + i + "]"); + if (transformed != null) { + result.add(transformed); + } + } + return result; + } + /** * Rewrites references within an expression tree. * * @param expression the expression to rewrite * @return the rewritten expression, or the original if no changes needed */ - public Expression rewrite(Expression expression) { + Expression rewrite(Expression expression) { if (!shouldRewrite.test(expression)) { return expression; } @@ -70,6 +108,108 @@ public Expression rewrite(Expression expression) { return expression; } + Map> rewriteHeaders(Map> headers) { + if (headers.isEmpty()) { + return headers; + } + + Map> rewritten = null; + boolean changed = false; + + for (Map.Entry> entry : headers.entrySet()) { + List originalValues = entry.getValue(); + List rewrittenValues = null; + + for (int i = 0; i < originalValues.size(); i++) { + Expression original = originalValues.get(i); + Expression rewrittenExpr = rewrite(original); + + if (rewrittenExpr != original) { + if (rewrittenValues == null) { + rewrittenValues = new ArrayList<>(originalValues.subList(0, i)); + } + rewrittenValues.add(rewrittenExpr); + changed = true; + } else if (rewrittenValues != null) { + rewrittenValues.add(original); + } + } + + if (changed && rewritten == null) { + rewritten = new LinkedHashMap<>(); + // Copy all previous entries + for (Map.Entry> prev : headers.entrySet()) { + if (prev.getKey().equals(entry.getKey())) { + break; + } + rewritten.put(prev.getKey(), prev.getValue()); + } + } + + if (rewritten != null) { + rewritten.put(entry.getKey(), + rewrittenValues != null ? rewrittenValues : originalValues); + } + } + + return changed ? rewritten : headers; + } + + Map rewriteProperties(Map properties) { + if (properties.isEmpty()) { + return properties; + } + + Map rewritten = null; + boolean changed = false; + + for (Map.Entry entry : properties.entrySet()) { + Expression rewrittenExpr = rewrite(entry.getValue()); + + if (rewrittenExpr != entry.getValue()) { + if (!(rewrittenExpr instanceof Literal)) { + throw new IllegalStateException("Property value must be a literal"); + } + + if (rewritten == null) { + rewritten = new LinkedHashMap<>(); + // Copy all previous entries + for (Map.Entry prev : properties.entrySet()) { + if (prev.getKey().equals(entry.getKey())) { + break; + } + rewritten.put(prev.getKey(), prev.getValue()); + } + } + + rewritten.put(entry.getKey(), (Literal) rewrittenExpr); + changed = true; + } else if (rewritten != null) { + rewritten.put(entry.getKey(), entry.getValue()); + } + } + + return changed ? rewritten : properties; + } + + Endpoint rewriteEndpoint(Endpoint endpoint) { + Expression rewrittenUrl = rewrite(endpoint.getUrl()); + Map> rewrittenHeaders = rewriteHeaders(endpoint.getHeaders()); + Map rewrittenProperties = rewriteProperties(endpoint.getProperties()); + + // Only create new endpoint if something changed + if (rewrittenUrl != endpoint.getUrl() + || rewrittenHeaders != endpoint.getHeaders() + || rewrittenProperties != endpoint.getProperties()) { + return Endpoint.builder() + .url(rewrittenUrl) + .headers(rewrittenHeaders) + .properties(rewrittenProperties) + .build(); + } + return endpoint; + } + private Expression rewriteStringLiteral(StringLiteral str) { Template template = str.value(); if (template.isStatic()) { @@ -149,16 +289,4 @@ private Expression rewriteLibraryFunction(LibraryFunction fn) { .build(); return fn.getFunctionDefinition().createFunction(node); } - - /** - * Creates a simple rewriter that replaces specific references. - * - * @param replacements map of variable names to replacement expressions - * @return a reference rewriter that performs the replacements - */ - public static ReferenceRewriter forReplacements(Map replacements) { - return new ReferenceRewriter( - ref -> replacements.getOrDefault(ref.getName().toString(), ref), - expr -> expr.getReferences().stream().anyMatch(replacements::containsKey)); - } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java index f870024c60a..33d995239df 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java @@ -36,80 +36,62 @@ final class VariableAnalysis { private final Map> bindings; private final Map referenceCounts; private final Map> expressionMappings; + private final Map> expressionToVars; private VariableAnalysis( Set inputParams, Map> bindings, - Map referenceCounts + Map referenceCounts, + Map> expressionToVars ) { this.inputParams = inputParams; this.bindings = bindings; this.referenceCounts = referenceCounts; + this.expressionToVars = expressionToVars; this.expressionMappings = createExpressionMappings(bindings); } static VariableAnalysis analyze(EndpointRuleSet ruleSet) { Set inputParameters = extractInputParameters(ruleSet); - Map> variableBindings = new HashMap<>(); - Map referenceCounts = new HashMap<>(); - Visitor visitor = new Visitor(variableBindings, referenceCounts); + AnalysisVisitor visitor = new AnalysisVisitor(inputParameters); for (Rule rule : ruleSet.getRules()) { visitor.visitRule(rule); } - return new VariableAnalysis(inputParameters, variableBindings, referenceCounts); + return new VariableAnalysis( + inputParameters, + visitor.bindings, + visitor.referenceCounts, + visitor.expressionToVars); } Set getInputParams() { return inputParams; } - /** - * Gets the mapping from variable name to expression to SSA name. - * This is used to determine when variables need SSA renaming. - */ Map> getExpressionMappings() { return expressionMappings; } - /** - * Gets the reference count for a variable. - * - * @param variableName the variable name - * @return the number of times the variable is referenced, or 0 if not found - */ int getReferenceCount(String variableName) { return referenceCounts.getOrDefault(variableName, 0); } - /** - * Checks if a variable is referenced exactly once. - * - * @param variableName the variable name - * @return true if the variable is referenced exactly once - */ boolean isReferencedOnce(String variableName) { return getReferenceCount(variableName) == 1; } - /** - * Checks if a variable has only one binding expression. - * - * @param variableName the variable name - * @return true if the variable has exactly one binding - */ boolean hasSingleBinding(String variableName) { Set expressions = bindings.get(variableName); return expressions != null && expressions.size() == 1; } - /** - * Checks if a variable is safe to inline (single binding, single reference). - * - * @param variableName the variable name - * @return true if the variable can be safely inlined - */ + boolean hasMultipleBindings(String variableName) { + Set expressions = bindings.get(variableName); + return expressions != null && expressions.size() > 1; + } + boolean isSafeToInline(String variableName) { return hasSingleBinding(variableName) && isReferencedOnce(variableName); } @@ -122,10 +104,6 @@ private static Set extractInputParameters(EndpointRuleSet ruleSet) { return inputParameters; } - /** - * Creates a mapping from variable name to expression to SSA name. - * Variables assigned multiple times get unique SSA names (x, x_1, x_2, etc). - */ private static Map> createExpressionMappings( Map> bindings ) { @@ -145,13 +123,11 @@ private static Map createMappingForVariable( Map mapping = new HashMap<>(); if (expressions.size() == 1) { - // Only one expression for this variable, so no SSA renaming needed String expression = expressions.iterator().next(); mapping.put(expression, varName); } else { - // Multiple expressions, so create unique SSA names List sortedExpressions = new ArrayList<>(expressions); - sortedExpressions.sort(String::compareTo); // Ensure deterministic ordering + sortedExpressions.sort(String::compareTo); for (int i = 0; i < sortedExpressions.size(); i++) { String expression = sortedExpressions.get(i); @@ -163,23 +139,30 @@ private static Map createMappingForVariable( return mapping; } - // Visitor that collects variable bindings and reference counts. - private static class Visitor { - private final Map> variableBindings; - private final Map referenceCounts; + private static class AnalysisVisitor { + final Map> bindings = new HashMap<>(); + final Map referenceCounts = new HashMap<>(); + final Map> expressionToVars = new HashMap<>(); - Visitor(Map> variableBindings, Map referenceCounts) { - this.variableBindings = variableBindings; - this.referenceCounts = referenceCounts; + private final Set inputParams; + + AnalysisVisitor(Set inputParams) { + this.inputParams = inputParams; } void visitRule(Rule rule) { for (Condition condition : rule.getConditions()) { if (condition.getResult().isPresent()) { String varName = condition.getResult().get().toString(); - String expression = condition.getFunction().toString(); - variableBindings.computeIfAbsent(varName, k -> new HashSet<>()) + LibraryFunction fn = condition.getFunction(); + String expression = fn.toString(); + String canonical = fn.canonicalize().toString(); + + bindings.computeIfAbsent(varName, k -> new HashSet<>()) .add(expression); + + expressionToVars.computeIfAbsent(canonical, k -> new ArrayList<>()) + .add(varName); } countReferences(condition.getFunction()); @@ -194,17 +177,16 @@ void visitRule(Rule rule) { EndpointRule endpointRule = (EndpointRule) rule; Endpoint endpoint = endpointRule.getEndpoint(); countReferences(endpoint.getUrl()); - for (List headerValues : endpoint.getHeaders().values()) { - for (Expression expr : headerValues) { - countReferences(expr); - } - } - for (Literal literal : endpoint.getProperties().values()) { - countReferences(literal); - } + endpoint.getHeaders() + .values() + .stream() + .flatMap(List::stream) + .forEach(this::countReferences); + endpoint.getProperties() + .values() + .forEach(this::countReferences); } else if (rule instanceof ErrorRule) { - ErrorRule errorRule = (ErrorRule) rule; - countReferences(errorRule.getError()); + countReferences(((ErrorRule) rule).getError()); } } @@ -212,6 +194,7 @@ private void countReferences(Expression expression) { if (expression instanceof Reference) { Reference ref = (Reference) expression; String name = ref.getName().toString(); + // Count all references, including input parameters referenceCounts.merge(name, 1, Integer::sum); } else if (expression instanceof StringLiteral) { StringLiteral str = (StringLiteral) expression; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java new file mode 100644 index 00000000000..b0846e786b3 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java @@ -0,0 +1,285 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; + +/** + * Consolidates variable names for identical expressions and eliminates redundant bindings. + * + *

    This transform identifies conditions that compute the same expression but assign + * the result to different variable names, and either consolidates them to use the same + * name or eliminates redundant bindings when the same expression is already bound in + * an ancestor scope. + */ +final class VariableConsolidationTransform { + private static final Logger LOGGER = Logger.getLogger(VariableConsolidationTransform.class.getName()); + + // Global map of canonical expressions to their first variable name seen + private final Map globalExpressionToVar = new HashMap<>(); + + // Maps old variable names to new canonical names for rewriting references + private final Map variableRenameMap = new HashMap<>(); + + // Tracks conditions to eliminate (by their path in the tree) + private final Set conditionsToEliminate = new HashSet<>(); + + // Tracks all variables defined at each scope level to check for conflicts + private final Map> scopeDefinedVars = new HashMap<>(); + + private int consolidatedCount = 0; + private int eliminatedCount = 0; + private int skippedDueToShadowing = 0; + + public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + VariableConsolidationTransform transform = new VariableConsolidationTransform(); + return transform.consolidate(ruleSet); + } + + private EndpointRuleSet consolidate(EndpointRuleSet ruleSet) { + LOGGER.info("Starting variable consolidation transform"); + + for (int i = 0; i < ruleSet.getRules().size(); i++) { + collectDefinitions(ruleSet.getRules().get(i), "rule[" + i + "]"); + } + + for (int i = 0; i < ruleSet.getRules().size(); i++) { + discoverBindingsInRule(ruleSet.getRules().get(i), "rule[" + i + "]", new HashMap<>(), new HashSet<>()); + } + + List transformedRules = new ArrayList<>(); + for (int i = 0; i < ruleSet.getRules().size(); i++) { + transformedRules.add(transformRule(ruleSet.getRules().get(i), "rule[" + i + "]")); + } + + LOGGER.info(String.format("Variable consolidation: %d consolidated, %d eliminated, %d skipped due to shadowing", + consolidatedCount, + eliminatedCount, + skippedDueToShadowing)); + + return EndpointRuleSet.builder() + .parameters(ruleSet.getParameters()) + .rules(transformedRules) + .version(ruleSet.getVersion()) + .build(); + } + + private void collectDefinitions(Rule rule, String path) { + Set definedVars = new HashSet<>(); + + // Collect all variables defined at this scope level + for (Condition condition : rule.getConditions()) { + if (condition.getResult().isPresent()) { + definedVars.add(condition.getResult().get().toString()); + } + } + + scopeDefinedVars.put(path, definedVars); + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (int i = 0; i < treeRule.getRules().size(); i++) { + collectDefinitions(treeRule.getRules().get(i), path + "/tree/rule[" + i + "]"); + } + } + } + + private void discoverBindingsInRule( + Rule rule, + String path, + Map parentBindings, + Set ancestorVars + ) { + // Track bindings at current scope (inherits parent bindings) + Map currentBindings = new HashMap<>(parentBindings); + // Track all variables visible from ancestors (for shadowing check) + Set visibleAncestorVars = new HashSet<>(ancestorVars); + + for (int i = 0; i < rule.getConditions().size(); i++) { + Condition condition = rule.getConditions().get(i); + String condPath = path + "/cond[" + i + "]"; + + if (condition.getResult().isPresent()) { + String varName = condition.getResult().get().toString(); + LibraryFunction fn = condition.getFunction(); + String canonical = fn.canonicalize().toString(); + + // Check if this expression is already bound in parent scope + String parentVar = parentBindings.get(canonical); + if (parentVar != null) { + // Found duplicate in parent, eliminate this binding + variableRenameMap.put(varName, parentVar); + conditionsToEliminate.add(condPath); + eliminatedCount++; + LOGGER.info(String.format("Eliminating redundant binding at %s: '%s' -> '%s' for: %s", + condPath, + varName, + parentVar, + canonical)); + } else { + // Not bound in parent, add to current scope + currentBindings.put(canonical, varName); + visibleAncestorVars.add(varName); + + // Check for global consolidation opportunity + String globalVar = globalExpressionToVar.get(canonical); + if (globalVar != null && !globalVar.equals(varName)) { + // Same expression elsewhere with different name + // Check if consolidation would cause shadowing + if (!wouldCauseShadowing(globalVar, path, ancestorVars)) { + variableRenameMap.put(varName, globalVar); + consolidatedCount++; + LOGGER.info(String.format("Consolidating '%s' -> '%s' for: %s", + varName, + globalVar, + canonical)); + } else { + skippedDueToShadowing++; + LOGGER.fine(String.format("Cannot consolidate '%s' -> '%s' (would shadow) for: %s", + varName, + globalVar, + canonical)); + } + } else if (globalVar == null) { + // First time seeing this expression globally + globalExpressionToVar.put(canonical, varName); + } + } + } + } + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + for (int i = 0; i < treeRule.getRules().size(); i++) { + discoverBindingsInRule( + treeRule.getRules().get(i), + path + "/tree/rule[" + i + "]", + currentBindings, + visibleAncestorVars); + } + } + } + + private boolean wouldCauseShadowing(String varName, String currentPath, Set ancestorVars) { + // Check if using this variable name would shadow an ancestor variable + if (ancestorVars.contains(varName)) { + return true; + } + + // Check if any child scope already defines this variable + // (which would be shadowed if we use it here) + for (Map.Entry> entry : scopeDefinedVars.entrySet()) { + String scopePath = entry.getKey(); + Set scopeVars = entry.getValue(); + // Check if this scope is a descendant of current path + if (scopePath.startsWith(currentPath + "/") && scopeVars.contains(varName)) { + return true; + } + } + + return false; + } + + private Rule transformRule(Rule rule, String path) { + List transformedConditions = new ArrayList<>(); + + for (int i = 0; i < rule.getConditions().size(); i++) { + String condPath = path + "/cond[" + i + "]"; + + if (conditionsToEliminate.contains(condPath)) { + // Skip this condition entirely since it's redundant + continue; + } + + Condition condition = rule.getConditions().get(i); + transformedConditions.add(transformCondition(condition)); + } + + if (rule instanceof TreeRule) { + TreeRule treeRule = (TreeRule) rule; + return TreeRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(transformedConditions) + .treeRule(TreeRewriter.transformNestedRules(treeRule, path, this::transformRule)); + + } else if (rule instanceof EndpointRule) { + EndpointRule endpointRule = (EndpointRule) rule; + TreeRewriter rewriter = createRewriter(); + + return EndpointRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(transformedConditions) + .endpoint(rewriter.rewriteEndpoint(endpointRule.getEndpoint())); + + } else if (rule instanceof ErrorRule) { + ErrorRule errorRule = (ErrorRule) rule; + TreeRewriter rewriter = createRewriter(); + + return ErrorRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(transformedConditions) + .error(rewriter.rewrite(errorRule.getError())); + } + + return rule.withConditions(transformedConditions); + } + + private Condition transformCondition(Condition condition) { + // Rewrite any references in the function + TreeRewriter rewriter = createRewriter(); + LibraryFunction fn = condition.getFunction(); + LibraryFunction rewrittenFn = (LibraryFunction) rewriter.rewrite(fn); + + // If this condition assigns to a variable that should be renamed, + // use the canonical name instead + if (condition.getResult().isPresent()) { + String varName = condition.getResult().get().toString(); + String canonicalName = variableRenameMap.get(varName); + + if (canonicalName != null) { + // This variable is being consolidated, use the canonical name + return Condition.builder() + .fn(rewrittenFn) + .result(Identifier.of(canonicalName)) + .build(); + } + } + + // No consolidation needed, but may still need reference rewriting + if (rewrittenFn != fn) { + return condition.toBuilder().fn(rewrittenFn).build(); + } + + return condition; + } + + private TreeRewriter createRewriter() { + if (variableRenameMap.isEmpty()) { + return TreeRewriter.IDENTITY; + } + + Map replacements = new HashMap<>(); + for (Map.Entry entry : variableRenameMap.entrySet()) { + replacements.put(entry.getKey(), Expression.getReference(Identifier.of(entry.getValue()))); + } + + return TreeRewriter.forReplacements(replacements); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java index 093fa7e49a8..0b5834ea7d4 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/TestHelpers.java @@ -7,12 +7,16 @@ import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsValidHostLabel; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.ParseUrl; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Substring; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.UriEncode; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; @@ -30,17 +34,81 @@ public static LibraryFunction stringEquals(String paramName, String value) { StringLiteral.of(value)); } + public static LibraryFunction stringEquals(Expression expr1, Expression expr2) { + return StringEquals.ofExpressions(expr1, expr2); + } + public static LibraryFunction booleanEquals(String paramName, boolean value) { return BooleanEquals.ofExpressions( Expression.getReference(Identifier.of(paramName)), Literal.booleanLiteral(value)); } - public static LibraryFunction parseUrl(String urlTemplate) { - return ParseUrl.ofExpressions(Literal.stringLiteral(Template.fromString(urlTemplate))); + public static LibraryFunction booleanEquals(Expression expr, boolean value) { + return BooleanEquals.ofExpressions(expr, Literal.booleanLiteral(value)); + } + + public static LibraryFunction parseUrl(String paramName) { + return ParseUrl.ofExpressions(Expression.getReference(Identifier.of(paramName))); + } + + public static LibraryFunction getAttr(Expression expr, String path) { + return GetAttr.ofExpressions(expr, Literal.of(path)); + } + + public static LibraryFunction getAttr(String paramName, String path) { + return GetAttr.ofExpressions( + Expression.getReference(Identifier.of(paramName)), + Literal.of(path)); + } + + public static LibraryFunction substring(String paramName, int start, int stop, boolean reverse) { + return Substring.ofExpressions( + Expression.getReference(Identifier.of(paramName)), + Literal.of(start), + Literal.of(stop), + Literal.of(reverse)); + } + + public static LibraryFunction substring(Expression expr, int start, int stop, boolean reverse) { + return Substring.ofExpressions( + expr, + Literal.of(start), + Literal.of(stop), + Literal.of(reverse)); + } + + public static LibraryFunction not(Expression expr) { + return Not.ofExpressions(expr); + } + + public static LibraryFunction not(LibraryFunction fn) { + return Not.ofExpressions(fn); + } + + public static LibraryFunction isValidHostLabel(String paramName, boolean allowDots) { + return IsValidHostLabel.ofExpressions( + Expression.getReference(Identifier.of(paramName)), + Literal.of(allowDots)); + } + + public static LibraryFunction isValidHostLabel(Expression expr, boolean allowDots) { + return IsValidHostLabel.ofExpressions(expr, Literal.of(allowDots)); + } + + public static LibraryFunction uriEncode(String paramName) { + return UriEncode.ofExpressions(Expression.getReference(Identifier.of(paramName))); + } + + public static LibraryFunction uriEncode(Expression expr) { + return UriEncode.ofExpressions(expr); } public static Endpoint endpoint(String url) { return Endpoint.builder().url(Expression.of(url)).build(); } + + public static Endpoint endpoint(Expression url) { + return Endpoint.builder().url(url).build(); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrderingTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrderingTest.java similarity index 96% rename from smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrderingTest.java rename to smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrderingTest.java index fa7374c1357..bddb4a406ea 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/CfgGuidedOrderingTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrderingTest.java @@ -28,7 +28,7 @@ import software.amazon.smithy.rulesengine.logic.cfg.Cfg; import software.amazon.smithy.utils.ListUtils; -class CfgGuidedOrderingTest { +class InitialOrderingTest { @Test void testSimpleOrdering() { @@ -45,7 +45,7 @@ void testSimpleOrdering() { .build(); Cfg cfg = Cfg.from(ruleSet); - CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + InitialOrdering ordering = new InitialOrdering(cfg); List ordered = ordering.orderConditions(cfg.getConditions()); @@ -89,7 +89,7 @@ void testDependencyOrdering() { // If coalesce merged them, we'll have 1 condition. Otherwise 2. assertTrue(conditions.length >= 1 && conditions.length <= 2); - CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + InitialOrdering ordering = new InitialOrdering(cfg); List ordered = ordering.orderConditions(conditions); assertEquals(conditions.length, ordered.size()); @@ -153,7 +153,7 @@ void testGateConditionPriority() { .build(); Cfg cfg = Cfg.from(ruleSet); - CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + InitialOrdering ordering = new InitialOrdering(cfg); List ordered = ordering.orderConditions(cfg.getConditions()); @@ -198,7 +198,7 @@ void testNestedTreeOrdering() { .build(); Cfg cfg = Cfg.from(ruleSet); - CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + InitialOrdering ordering = new InitialOrdering(cfg); List ordered = ordering.orderConditions(cfg.getConditions()); @@ -240,7 +240,7 @@ void testMultipleIndependentConditions() { .build(); Cfg cfg = Cfg.from(ruleSet); - CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + InitialOrdering ordering = new InitialOrdering(cfg); List ordered = ordering.orderConditions(cfg.getConditions()); @@ -264,7 +264,7 @@ void testEmptyConditions() { .build(); Cfg cfg = Cfg.from(ruleSet); - CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + InitialOrdering ordering = new InitialOrdering(cfg); List ordered = ordering.orderConditions(cfg.getConditions()); @@ -311,7 +311,7 @@ void testIsSetGatePriority() { .build(); Cfg cfg = Cfg.from(ruleSet); - CfgGuidedOrdering ordering = new CfgGuidedOrdering(cfg); + InitialOrdering ordering = new InitialOrdering(cfg); List ordered = ordering.orderConditions(cfg.getConditions()); diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java index 225feabd34e..35d23c85a0b 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java @@ -66,7 +66,7 @@ void buildCreatesValidCfg() { assertNotNull(cfg); assertSame(root, cfg.getRoot()); - assertEquals(ruleSet, cfg.getRuleSet()); + assertEquals(ruleSet.getParameters(), cfg.getParameters()); } @Test diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java index 888d833f9d8..2aad5ab9484 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgTest.java @@ -44,7 +44,7 @@ void gettersReturnConstructorValues() { Cfg cfg = new Cfg(ruleSet, root); - assertSame(ruleSet, cfg.getRuleSet()); + assertSame(ruleSet.getParameters(), cfg.getParameters()); assertSame(root, cfg.getRoot()); } @@ -62,7 +62,6 @@ void fromCreatesSimpleCfg() { assertNotNull(cfg); assertNotNull(cfg.getRoot()); - assertEquals(ruleSet, cfg.getRuleSet()); // Root should be a result node for a simple endpoint rule assertInstanceOf(ResultNode.class, cfg.getRoot()); @@ -156,7 +155,7 @@ void iteratorVisitsAllNodes() { @Test void iteratorHandlesEmptyCfg() { CfgNode root = ResultNode.terminal(); - Cfg cfg = new Cfg(null, root); + Cfg cfg = new Cfg((EndpointRuleSet) null, root); List nodes = new ArrayList<>(); for (CfgNode node : cfg) { @@ -170,7 +169,7 @@ void iteratorHandlesEmptyCfg() { @Test void iteratorThrowsNoSuchElementException() { CfgNode root = ResultNode.terminal(); - Cfg cfg = new Cfg(null, root); + Cfg cfg = new Cfg((EndpointRuleSet) null, root); Iterator iterator = cfg.iterator(); assertTrue(iterator.hasNext()); diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransformTest.java new file mode 100644 index 00000000000..d2fbac52245 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CoalesceTransformTest.java @@ -0,0 +1,575 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.logic.cfg; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.StringEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.utils.ListUtils; + +class CoalesceTransformTest { + + @Test + void testActualCoalescing() { + // Test with substring which returns a string (has zero value "") + // This should actually coalesce + Condition checkInput = Condition.builder() + .fn(TestHelpers.isSet("Input")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.substring("Input", 0, 5, false)) + .result(Identifier.of("prefix")) + .build(); + + Condition use = Condition.builder() + .fn(StringEquals.ofExpressions( + Expression.getReference(Identifier.of("prefix")), + Literal.of("https"))) + .result(Identifier.of("isHttps")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkInput, bind, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should coalesce because substring returns string which has zero value + Rule transformedRule = transformed.getRules().get(0); + assertEquals(2, transformedRule.getConditions().size()); // isSet + coalesced + + Condition coalesced = transformedRule.getConditions().get(1); + assertTrue(coalesced.getResult().isPresent()); + assertEquals("isHttps", coalesced.getResult().get().toString()); + } + + @Test + void testSimpleBindThenUsePattern() { + // parseUrl returns a record type which doesn't have a zero value + // So it won't be coalesced + Condition checkEndpoint = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(); + + Condition use = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkEndpoint, bind, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + Rule transformedRule = transformed.getRules().get(0); + List conditions = transformedRule.getConditions(); + + // Should not coalesce because parseUrl returns a record without zero value + assertEquals(3, conditions.size()); + } + + @Test + void testDoesNotCoalesceWhenVariableUsedMultipleTimes() { + // Variable is used multiple times - should not coalesce + Condition checkEndpoint = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(); + + Condition use1 = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build(); + + Condition use2 = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "authority")) + .result(Identifier.of("authority")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkEndpoint, bind, use1, use2) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should not coalesce because 'url' is used twice + Rule transformedRule = transformed.getRules().get(0); + assertEquals(4, transformedRule.getConditions().size()); + } + + @Test + void testDoesNotCoalesceIsSetFunction() { + // isSet functions should not be coalesced + Condition bind = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition use = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .result(Identifier.of("regionIsSet")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(bind, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should coalesce because BooleanType has a zero value (false) + // The actual behavior is that it DOES coalesce boolean operations + Rule transformedRule = transformed.getRules().get(0); + assertEquals(1, transformedRule.getConditions().size()); + } + + @Test + void testMultipleCoalescesInSameRule() { + // parseUrl returns a record type which doesn't have a zero value + // So these won't be coalesced + Condition checkEndpoint1 = Condition.builder() + .fn(TestHelpers.isSet("Endpoint1")) + .build(); + + Condition checkEndpoint2 = Condition.builder() + .fn(TestHelpers.isSet("Endpoint2")) + .build(); + + Condition bind1 = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint1")) + .result(Identifier.of("url1")) + .build(); + + Condition use1 = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url1")), + "scheme")) + .result(Identifier.of("scheme1")) + .build(); + + Condition bind2 = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint2")) + .result(Identifier.of("url2")) + .build(); + + Condition use2 = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url2")), + "scheme")) + .result(Identifier.of("scheme2")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkEndpoint1, checkEndpoint2, bind1, use1, bind2, use2) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint1") + .type(ParameterType.STRING) + .build()) + .addParameter(Parameter.builder() + .name("Endpoint2") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Won't coalesce because parseUrl returns record without zero value + Rule transformedRule = transformed.getRules().get(0); + assertEquals(6, transformedRule.getConditions().size()); + } + + @Test + void testCoalesceWithStringFunctions() { + // Test coalescing with string manipulation functions + Condition checkInput = Condition.builder() + .fn(TestHelpers.isSet("Input")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.substring("Input", 0, 5, false)) + .result(Identifier.of("prefix")) + .build(); + + Condition use = Condition.builder() + .fn(StringEquals.ofExpressions( + Expression.getReference(Identifier.of("prefix")), + Literal.of("https"))) + .result(Identifier.of("isHttps")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkInput, bind, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should coalesce string functions that have zero values + Rule transformedRule = transformed.getRules().get(0); + assertEquals(2, transformedRule.getConditions().size()); // isSet + coalesced + + Condition coalesced = transformedRule.getConditions().get(1); + assertTrue(coalesced.getResult().isPresent()); + assertEquals("isHttps", coalesced.getResult().get().toString()); + } + + @Test + void testDoesNotCoalesceWhenNotImmediatelyFollowing() { + // Bind and use are not immediately following each other + Condition checkEndpoint = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(); + + Condition intermediate = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition use = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkEndpoint, bind, intermediate, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .addParameter(Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should NOT coalesce because bind and use are not adjacent + Rule transformedRule = transformed.getRules().get(0); + assertEquals(4, transformedRule.getConditions().size()); + } + + @Test + void testCoalesceCaching() { + // parseUrl returns a record type which doesn't have a zero value + // So these won't be coalesced + Condition check1 = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Rule rule1 = EndpointRule.builder() + .conditions( + check1, + Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(), + Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build()) + .endpoint(TestHelpers.endpoint("https://example1.com")); + + Condition check2 = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Rule rule2 = EndpointRule.builder() + .conditions( + check2, + Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(), + Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build()) + .endpoint(TestHelpers.endpoint("https://example2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Won't coalesce because parseUrl returns record without zero value + assertEquals(3, transformed.getRules().get(0).getConditions().size()); + assertEquals(3, transformed.getRules().get(1).getConditions().size()); + } + + @Test + void testCoalesceInErrorRule() { + // parseUrl returns a record type which doesn't have a zero value + // So it won't be coalesced + Condition checkEndpoint = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(); + + Condition use = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build(); + + Condition check = Condition.builder() + .fn(Not.ofExpressions( + StringEquals.ofExpressions( + Expression.getReference(Identifier.of("scheme")), + Literal.of("https")))) + .build(); + + Rule rule = ErrorRule.builder() + .conditions(checkEndpoint, bind, use, check) + .error(Literal.of("Endpoint must use HTTPS")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Won't coalesce because parseUrl returns record without zero value + Rule transformedRule = transformed.getRules().get(0); + assertEquals(4, transformedRule.getConditions().size()); + } + + @Test + void testCoalesceWithBooleanType() { + // Test coalescing with boolean-returning functions + Condition checkInput = Condition.builder() + .fn(TestHelpers.isSet("Input")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.isValidHostLabel("Input", false)) + .result(Identifier.of("isValid")) + .build(); + + Condition use = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("isValid")), + Literal.of(true))) + .result(Identifier.of("validLabel")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(checkInput, bind, use) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Input") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should coalesce boolean functions (they have zero value of false) + Rule transformedRule = transformed.getRules().get(0); + assertEquals(2, transformedRule.getConditions().size()); // isSet + coalesced + } + + @Test + void testDoesNotCoalesceWhenVariableUsedElsewhere() { + // Variable is used in a different branch + Condition checkEndpoint = Condition.builder() + .fn(TestHelpers.isSet("Endpoint")) + .build(); + + Condition bind = Condition.builder() + .fn(TestHelpers.parseUrl("Endpoint")) + .result(Identifier.of("url")) + .build(); + + Condition use = Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "scheme")) + .result(Identifier.of("scheme")) + .build(); + + Rule innerRule = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.getAttr( + Expression.getReference(Identifier.of("url")), + "authority")) + .result(Identifier.of("authority")) + .build()) + .endpoint(TestHelpers.endpoint("https://inner.example.com")); + + Rule treeRule = TreeRule.builder() + .conditions(checkEndpoint, bind, use) + .treeRule(innerRule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("Endpoint") + .type(ParameterType.STRING) + .build()) + .build(); + + EndpointRuleSet original = EndpointRuleSet.builder() + .parameters(params) + .addRule(treeRule) + .build(); + + EndpointRuleSet transformed = CoalesceTransform.transform(original); + + // Should NOT coalesce because 'url' is used in the inner rule + TreeRule transformedTree = (TreeRule) transformed.getRules().get(0); + assertEquals(3, transformedTree.getConditions().size()); // isSet + bind + use (not coalesced) + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java index 8edac1db7d7..29b9c87f03e 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/ReferenceRewriterTest.java @@ -34,7 +34,7 @@ void testSimpleReferenceReplacement() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("y"))); - ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); // Test rewriting a simple reference Reference original = Expression.getReference(Identifier.of("x")); @@ -49,7 +49,7 @@ void testNoRewriteNeeded() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("y"))); - ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); // Reference to "z" should not be rewritten Reference original = Expression.getReference(Identifier.of("z")); @@ -67,7 +67,7 @@ void testRewriteInStringLiteral() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("newVar"))); - ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); Expression rewritten = rewriter.rewrite(original); assertInstanceOf(StringLiteral.class, rewritten); @@ -84,7 +84,7 @@ void testRewriteInTupleLiteral() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("replaced"))); - ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); Expression rewritten = rewriter.rewrite(original); assertInstanceOf(TupleLiteral.class, rewritten); @@ -105,7 +105,7 @@ void testRewriteInRecordLiteral() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("newX"))); - ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); Expression rewritten = rewriter.rewrite(original); assertInstanceOf(RecordLiteral.class, rewritten); @@ -124,7 +124,7 @@ void testRewriteInLibraryFunction() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("replacedVar"))); - ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); Expression rewritten = rewriter.rewrite(original); assertTrue(rewritten.toString().contains("replacedVar")); @@ -142,7 +142,7 @@ void testMultipleReplacements() { replacements.put("a", Expression.getReference(Identifier.of("x"))); replacements.put("b", Expression.getReference(Identifier.of("y"))); - ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); Expression rewritten = rewriter.rewrite(original); assertTrue(rewritten.toString().contains("x")); @@ -158,7 +158,7 @@ void testNestedRewriting() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("newVar"))); - ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); Expression rewritten = rewriter.rewrite(original); assertTrue(rewritten.toString().contains("newVar")); @@ -173,7 +173,7 @@ void testStaticStringNotRewritten() { Map replacements = new HashMap<>(); replacements.put("x", Expression.getReference(Identifier.of("y"))); - ReferenceRewriter rewriter = ReferenceRewriter.forReplacements(replacements); + TreeRewriter rewriter = TreeRewriter.forReplacements(replacements); Expression rewritten = rewriter.rewrite(original); assertEquals(original, rewritten); diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java index 204483ef615..ddb1505f7dc 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysisTest.java @@ -123,6 +123,7 @@ void testMultipleBindings() { VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); assertFalse(analysis.hasSingleBinding("x")); + assertTrue(analysis.hasMultipleBindings("x")); // Not safe to inline when multiple bindings exist assertFalse(analysis.isSafeToInline("x")); @@ -324,4 +325,145 @@ void testNoVariables() { // But Region is an input parameter assertTrue(analysis.getInputParams().contains("Region")); } + + @Test + void testSameExpressionDifferentVariableNames() { + // Same expression bound to different variable names in different rules + Rule rule1 = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build()) + .endpoint(TestHelpers.endpoint("https://example1.com")); + + Rule rule2 = EndpointRule.builder() + .conditions(Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("regionExists")) + .build()) + .endpoint(TestHelpers.endpoint("https://example2.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + // Each variable has a single binding (same expression though) + assertTrue(analysis.hasSingleBinding("hasRegion")); + assertTrue(analysis.hasSingleBinding("regionExists")); + + // Neither is referenced after binding + assertEquals(0, analysis.getReferenceCount("hasRegion")); + assertEquals(0, analysis.getReferenceCount("regionExists")); + } + + @Test + void testDeeplyNestedTreeRules() { + // Multiple levels of tree rule nesting + Condition level3Define = Condition.builder() + .fn(TestHelpers.isSet("Bucket")) + .result(Identifier.of("hasBucket")) + .build(); + + Rule level3Rule = EndpointRule.builder() + .conditions(level3Define) + .endpoint(TestHelpers.endpoint("https://level3.com")); + + Condition level2Define = Condition.builder() + .fn(TestHelpers.isSet("Key")) + .result(Identifier.of("hasKey")) + .build(); + + Rule level2Rule = TreeRule.builder() + .conditions(level2Define) + .treeRule(level3Rule); + + Condition level1Define = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("hasRegion")) + .build(); + + Condition level1Use = Condition.builder() + .fn(BooleanEquals.ofExpressions( + Expression.getReference(Identifier.of("hasRegion")), + Literal.of(true))) + .build(); + + Rule level1Rule = TreeRule.builder() + .conditions(level1Define, level1Use) + .treeRule(level2Rule); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Key").type(ParameterType.STRING).build()) + .addParameter(Parameter.builder().name("Bucket").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(level1Rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertTrue(analysis.hasSingleBinding("hasRegion")); + assertTrue(analysis.hasSingleBinding("hasKey")); + assertTrue(analysis.hasSingleBinding("hasBucket")); + + assertEquals(1, analysis.getReferenceCount("hasRegion")); + assertEquals(0, analysis.getReferenceCount("hasKey")); + assertEquals(0, analysis.getReferenceCount("hasBucket")); + } + + @Test + void testUnreferencedVariable() { + // Variable that's defined but never used + Condition defineUnused = Condition.builder() + .fn(TestHelpers.isSet("Region")) + .result(Identifier.of("unused")) + .build(); + + Rule rule = EndpointRule.builder() + .conditions(defineUnused) + .endpoint(TestHelpers.endpoint("https://example.com")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .addRule(rule) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertTrue(analysis.hasSingleBinding("unused")); + assertEquals(0, analysis.getReferenceCount("unused")); + assertFalse(analysis.isSafeToInline("unused")); // Not safe because not referenced + } + + @Test + void testEmptyRuleSet() { + // Empty ruleset with just parameters + Parameters params = Parameters.builder() + .addParameter(Parameter.builder().name("Region").type(ParameterType.STRING).build()) + .build(); + + EndpointRuleSet ruleSet = EndpointRuleSet.builder() + .parameters(params) + .build(); + + VariableAnalysis analysis = VariableAnalysis.analyze(ruleSet); + + assertEquals(1, analysis.getInputParams().size()); + assertTrue(analysis.getInputParams().contains("Region")); + assertEquals(0, analysis.getReferenceCount("Region")); + } } From 5cc33c55ae4c1c0a4bddfd63b612863101326c18 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Thu, 14 Aug 2025 20:35:15 -0500 Subject: [PATCH 15/23] Add result coalescing phi nodes When we detect that a result or error are the same except for the variables used in template placeholders at the same position, we now automatically insert phi nodes using coalesce functions to reduce the amount of duplicate results. For example, this removed the result duplication from S3Express SSA variable versioning entirely, down from 158 results to 121. --- .../rulesengine/logic/cfg/CfgBuilder.java | 616 ++++++++++++++++-- .../rulesengine/logic/cfg/CfgBuilderTest.java | 243 +++++++ 2 files changed, 818 insertions(+), 41 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java index bd7f874e555..503cbf64e9f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -4,35 +4,56 @@ */ package software.amazon.smithy.rulesengine.logic.cfg; +import java.util.ArrayList; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.logging.Logger; +import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; import software.amazon.smithy.rulesengine.logic.ConditionReference; /** - * Builder for constructing Control Flow Graphs with node deduplication. + * Builder for constructing Control Flow Graphs with node deduplication and result convergence. * - *

    The builder performs simple hash-consing during top-down construction, deduplicating nodes when the same node - * would be created multiple times in the same context. + *

    This builder performs hash-consing during construction and creates convergence nodes + * for structurally similar results to minimize BDD size. */ public final class CfgBuilder { + private static final Logger LOGGER = Logger.getLogger(CfgBuilder.class.getName()); + + // Configuration constants + private static final int MAX_DIVERGENT_PATHS_FOR_CONVERGENCE = 5; + private static final int MIN_RESULTS_FOR_CONVERGENCE = 2; + final EndpointRuleSet ruleSet; - // Simple hash-consing for nodes created in the same context + // Node deduplication private final Map nodeCache = new HashMap<>(); // Condition and result canonicalization @@ -40,9 +61,17 @@ public final class CfgBuilder { private final Map resultCache = new HashMap<>(); private final Map resultNodeCache = new HashMap<>(); + // Result convergence support + private final Map resultToConvergenceNode = new HashMap<>(); + private int convergenceNodesCreated = 0; + private int phiVariableCounter = 0; + public CfgBuilder(EndpointRuleSet ruleSet) { - // Disambiguate conditions and references so variable names are globally unique. + // Apply SSA transform to ensure globally unique variable names this.ruleSet = SsaTransform.transform(ruleSet); + + // Analyze results and create convergence nodes + analyzeAndCreateConvergenceNodes(); } /** @@ -83,12 +112,23 @@ public CfgNode createCondition(ConditionReference condRef, CfgNode trueBranch, C /** * Creates a result node representing a terminal rule evaluation. * + *

    If this result is part of a convergence group, returns the shared convergence node instead. + * * @param rule the result rule (endpoint or error) - * @return a result node (cached if identical rule already seen) + * @return a result node or convergence node */ public CfgNode createResult(Rule rule) { - Rule canonical = rule.withConditions(Collections.emptyList()); - Rule interned = resultCache.computeIfAbsent(canonical, k -> k); + // Intern the result + Rule interned = intern(rule); + + // Check if this result has a convergence node + CfgNode convergenceNode = resultToConvergenceNode.get(interned); + if (convergenceNode != null) { + LOGGER.fine("Using convergence node for result: " + interned); + return convergenceNode; + } + + // Regular result node return resultNodeCache.computeIfAbsent(interned, ResultNode::new); } @@ -96,13 +136,11 @@ public CfgNode createResult(Rule rule) { * Creates a canonical condition reference, handling negation and deduplication. */ public ConditionReference createConditionReference(Condition condition) { - // Check cache first ConditionReference cached = conditionToReference.get(condition); if (cached != null) { return cached; } - // Check if it's a negation boolean negated = false; Condition canonical = condition; @@ -110,20 +148,16 @@ public ConditionReference createConditionReference(Condition condition) { negated = true; canonical = unwrapNegation(condition); - // Check if we already have the non-negated version ConditionReference existing = conditionToReference.get(canonical); if (existing != null) { - // Reuse the existing Condition, just negate the reference ConditionReference negatedReference = existing.negate(); conditionToReference.put(condition, negatedReference); return negatedReference; } } - // Canonicalize for commutative operations canonical = canonical.canonicalize(); - // Canonicalize boolean equals Condition beforeBooleanCanon = canonical; canonical = canonicalizeBooleanEquals(canonical); @@ -131,13 +165,9 @@ public ConditionReference createConditionReference(Condition condition) { negated = !negated; } - // Create the reference (possibly negated) ConditionReference reference = new ConditionReference(canonical, negated); - - // Cache the reference under the original key conditionToReference.put(condition, reference); - // Also cache under the canonical form if different if (!negated && !condition.equals(canonical)) { conditionToReference.put(canonical, reference); } @@ -145,25 +175,259 @@ public ConditionReference createConditionReference(Condition condition) { return reference; } + private void analyzeAndCreateConvergenceNodes() { + LOGGER.info("Analyzing results for convergence opportunities"); + List allResults = new ArrayList<>(); + for (Rule rule : ruleSet.getRules()) { + collectResultsFromRule(rule, allResults); + } + LOGGER.info("Found " + allResults.size() + " total results"); + + Map groups = groupResultsByStructure(allResults); + createConvergenceNodesForGroups(groups); + + LOGGER.info(String.format("Created %d convergence nodes for %d result groups", + convergenceNodesCreated, + groups.size())); + } + + private void collectResultsFromRule(Rule rule, List results) { + if (rule instanceof EndpointRule || rule instanceof ErrorRule) { + results.add(intern(rule)); + } else if (rule instanceof TreeRule) { + for (Rule nestedRule : ((TreeRule) rule).getRules()) { + collectResultsFromRule(nestedRule, results); + } + } + } + + private Rule intern(Rule rule) { + return resultCache.computeIfAbsent(canonicalizeResult(rule), k -> k); + } + + private Map groupResultsByStructure(List results) { + Map groups = new HashMap<>(); + for (Rule result : results) { + ResultSignature sig = new ResultSignature(result); + groups.computeIfAbsent(sig, k -> new ResultGroup()).add(result); + } + return groups; + } + + private void createConvergenceNodesForGroups(Map groups) { + for (ResultGroup group : groups.values()) { + if (shouldGroupResults(group)) { + createConvergenceNodeForGroup(group); + } + } + } + + private boolean shouldGroupResults(ResultGroup group) { + if (group.results.size() < MIN_RESULTS_FOR_CONVERGENCE) { + return false; + } + + int divergentCount = group.getDivergentPathCount(); + if (divergentCount == 0) { + return false; + } + + if (divergentCount > MAX_DIVERGENT_PATHS_FOR_CONVERGENCE) { + LOGGER.fine(String.format("Skipping convergence for group with %d divergent paths (perf)", + divergentCount)); + return false; + } + + return true; + } + + private void createConvergenceNodeForGroup(ResultGroup group) { + Map> divergentPaths = group.getDivergentPaths(); + Rule canonical = group.results.get(0); // already interned + + Map phiVariableMap = createPhiVariablesByPath(divergentPaths); + Rule rewrittenResult = rewriteResultWithPhiVariables(canonical, divergentPaths, phiVariableMap); + + CfgNode convergenceNode = buildConvergenceNode(rewrittenResult, divergentPaths, phiVariableMap); + + for (Rule result : group.results) { + resultToConvergenceNode.put(result, convergenceNode); // keys are the interned instances + } + + convergenceNodesCreated++; + } + + private Map createPhiVariablesByPath(Map> divergentPaths) { + List paths = new ArrayList<>(divergentPaths.keySet()); + Collections.sort(paths); + + Map phiVariableMap = new LinkedHashMap<>(); + for (LocationPath path : paths) { + phiVariableMap.put(path, "phi_result_" + (phiVariableCounter++)); + } + return phiVariableMap; // already ordered + } + + private CfgNode buildConvergenceNode( + Rule result, + Map> divergentPaths, + Map phiVariableMap + ) { + CfgNode resultNode = new ResultNode(result); + + // Apply phi nodes in the order already established by phiVariableMap + for (Map.Entry entry : phiVariableMap.entrySet()) { + Set versions = divergentPaths.get(entry.getKey()); + String phiVar = entry.getValue(); + + Condition coalesceCondition = createCoalesceCondition(phiVar, versions); + ConditionReference condRef = new ConditionReference(coalesceCondition, false); + + // Use terminal (no match) as false branch to prevent BDD optimization + // The coalesce always succeeds, so we never take the false branch + resultNode = new ConditionNode(condRef, resultNode, ResultNode.terminal()); + + // Log with deterministic ordering + LOGGER.fine(() -> { + List sortedVersions = new ArrayList<>(versions); + Collections.sort(sortedVersions); + return String.format("Created convergence: %s = coalesce(%s) for path %s", + phiVar, + String.join(",", sortedVersions), + entry.getKey()); + }); + } + + return resultNode; + } + + private Rule rewriteResultWithPhiVariables( + Rule result, + Map> divergentPaths, + Map phiVariableMap + ) { + // Build replacements for URL or error expression only (not headers/properties) + Map urlReplacements = buildPhiReplacements(divergentPaths, phiVariableMap); + + if (urlReplacements.isEmpty()) { + return result; + } + + TreeRewriter rewriter = TreeRewriter.forReplacements(urlReplacements); + + if (result instanceof EndpointRule) { + return rewriteEndpointRule((EndpointRule) result, rewriter); + } else if (result instanceof ErrorRule) { + return rewriteErrorRule((ErrorRule) result, rewriter); + } + + return result; + } + + private Map buildPhiReplacements( + Map> divergentPaths, + Map phiVariableMap + ) { + Map replacements = new HashMap<>(); + for (Map.Entry> entry : divergentPaths.entrySet()) { + LocationPath path = entry.getKey(); + String phiVar = phiVariableMap.get(path); + Expression phiRef = Expression.getReference(Identifier.of(phiVar)); + + // Map all versions at this path to the phi variable + for (String version : entry.getValue()) { + replacements.put(version, phiRef); + } + } + return replacements; + } + + private Rule rewriteEndpointRule(EndpointRule rule, TreeRewriter rewriter) { + Endpoint endpoint = rule.getEndpoint(); + Expression rewrittenUrl = rewriter.rewrite(endpoint.getUrl()); + + if (rewrittenUrl != endpoint.getUrl()) { + Endpoint rewrittenEndpoint = Endpoint.builder() + .url(rewrittenUrl) + .headers(endpoint.getHeaders()) + .properties(endpoint.getProperties()) + .build(); + + return EndpointRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(Collections.emptyList()) + .endpoint(rewrittenEndpoint); + } + return rule; + } + + private Rule rewriteErrorRule(ErrorRule rule, TreeRewriter rewriter) { + Expression rewrittenError = rewriter.rewrite(rule.getError()); + + if (rewrittenError != rule.getError()) { + return ErrorRule.builder() + .description(rule.getDocumentation().orElse(null)) + .conditions(Collections.emptyList()) + .error(rewrittenError); + } + return rule; + } + + private Condition createCoalesceCondition(String resultVar, Set versions) { + // De-duplicate and sort for determinism + List inputs = new ArrayList<>(new LinkedHashSet<>(versions)); + Collections.sort(inputs); + + if (inputs.isEmpty()) { + throw new IllegalArgumentException("Cannot create coalesce with no versions"); + } + + Expression coalesced = buildCoalesceExpressionFromPrepared(inputs); + + return Condition.builder() + .fn((LibraryFunction) coalesced) + .result(Identifier.of(resultVar)) + .build(); + } + + private Expression buildCoalesceExpressionFromPrepared(List inputs) { + Expression coalesced = Expression.getReference(Identifier.of(inputs.get(0))); + + if (inputs.size() == 1) { + // Hack: coalesce(x, x) for single element + return Coalesce.ofExpressions(coalesced, coalesced); + } + + // Build coalesce chain for multiple versions + for (int i = 1; i < inputs.size(); i++) { + Expression ref = Expression.getReference(Identifier.of(inputs.get(i))); + coalesced = Coalesce.ofExpressions(coalesced, ref); + } + return coalesced; + } + + private Rule canonicalizeResult(Rule rule) { + return rule == null ? null : rule.withConditions(Collections.emptyList()); + } + private Condition canonicalizeBooleanEquals(Condition condition) { if (!(condition.getFunction() instanceof BooleanEquals)) { return condition; } List args = condition.getFunction().getArguments(); + if (args.size() != 2 || !(args.get(0) instanceof Reference) || !(args.get(1) instanceof Literal)) { + return condition; + } - // After commutative canonicalization, if there's a reference, it's in position 0 - if (args.get(0) instanceof Reference && args.get(1) instanceof Literal) { - Reference ref = (Reference) args.get(0); - Boolean literalValue = ((Literal) args.get(1)).asBooleanLiteral().orElse(null); - - if (literalValue != null && !literalValue && ruleSet != null) { - String varName = ref.getName().toString(); - Optional param = ruleSet.getParameters().get(Identifier.of(varName)); - if (param.isPresent() && param.get().getDefault().isPresent()) { - // Convert booleanEquals(var, false) to booleanEquals(var, true) - return condition.toBuilder().fn(BooleanEquals.ofExpressions(ref, true)).build(); - } + Reference ref = (Reference) args.get(0); + Boolean literalValue = ((Literal) args.get(1)).asBooleanLiteral().orElse(null); + + if (literalValue != null && !literalValue && ruleSet != null) { + String varName = ref.getName().toString(); + Optional param = ruleSet.getParameters().get(Identifier.of(varName)); + if (param.isPresent() && param.get().getDefault().isPresent()) { + return condition.toBuilder().fn(BooleanEquals.ofExpressions(ref, true)).build(); } } @@ -171,13 +435,9 @@ private Condition canonicalizeBooleanEquals(Condition condition) { } private static boolean isNegationWrapper(Condition condition) { - if (!(condition.getFunction() instanceof Not)) { - return false; - } else if (condition.getResult().isPresent()) { - return false; - } else { - return condition.getFunction().getArguments().get(0) instanceof LibraryFunction; - } + return condition.getFunction() instanceof Not + && !condition.getResult().isPresent() + && condition.getFunction().getArguments().get(0) instanceof LibraryFunction; } private static Condition unwrapNegation(Condition negatedCondition) { @@ -186,7 +446,282 @@ private static Condition unwrapNegation(Condition negatedCondition) { .build(); } - // Signature for node deduplication during construction. + /** + * Path represents a structural location within an expression tree. + */ + private static final class LocationPath implements Comparable { + private final String key; + private final int hash; + + LocationPath(List parts) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < parts.size(); i++) { + if (i > 0) { + sb.append("/"); + } + sb.append(parts.get(i).toString()); + } + this.key = sb.toString(); + this.hash = key.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof LocationPath) { + return ((LocationPath) o).key.equals(this.key); + } + return false; + } + + @Override + public int hashCode() { + return hash; + } + + @Override + public String toString() { + return key; + } + + @Override + public int compareTo(LocationPath other) { + return this.key.compareTo(other.key); + } + } + + /** + * Signature for result grouping based on structural similarity. + */ + private static class ResultSignature { + private final String type; + private final Object urlStructure; + private final Object headersStructure; + private final Object propertiesStructure; + private final int hashCode; + + ResultSignature(Rule result) { + this.type = result instanceof EndpointRule ? "endpoint" : "error"; + + if (result instanceof EndpointRule) { + Endpoint ep = ((EndpointRule) result).getEndpoint(); + this.urlStructure = buildExpressionStructure(ep.getUrl()); + this.headersStructure = buildHeadersStructure(ep.getHeaders()); + this.propertiesStructure = buildPropertiesStructure(ep.getProperties()); + } else if (result instanceof ErrorRule) { + this.urlStructure = buildExpressionStructure(((ErrorRule) result).getError()); + this.headersStructure = null; + this.propertiesStructure = null; + } else { + this.urlStructure = null; + this.headersStructure = null; + this.propertiesStructure = null; + } + + this.hashCode = Objects.hash(type, urlStructure, headersStructure, propertiesStructure); + } + + private Object buildExpressionStructure(Expression expr) { + if (expr instanceof Reference) { + return "VAR"; + } else if (expr instanceof StringLiteral) { + return buildTemplateStructure(((StringLiteral) expr).value()); + } else if (expr instanceof LibraryFunction) { + return buildFunctionStructure((LibraryFunction) expr); + } else if (expr instanceof TupleLiteral) { + return buildTupleStructure((TupleLiteral) expr); + } else if (expr instanceof RecordLiteral) { + return buildRecordStructure((RecordLiteral) expr); + } else if (expr instanceof Literal) { + return expr.toString(); + } + return expr.getClass().getSimpleName(); + } + + private Object buildFunctionStructure(LibraryFunction fn) { + Map fnStructure = new LinkedHashMap<>(); + fnStructure.put("fn", fn.getName()); + List args = new ArrayList<>(); + for (Expression arg : fn.getArguments()) { + args.add(buildExpressionStructure(arg)); + } + fnStructure.put("args", args); + return fnStructure; + } + + private Object buildTupleStructure(TupleLiteral tuple) { + List structure = new ArrayList<>(); + for (Literal member : tuple.members()) { + structure.add(buildExpressionStructure(member)); + } + return structure; + } + + private Object buildRecordStructure(RecordLiteral record) { + Map recordStructure = new LinkedHashMap<>(); + for (Map.Entry entry : record.members().entrySet()) { + recordStructure.put(entry.getKey().toString(), buildExpressionStructure(entry.getValue())); + } + return recordStructure; + } + + private Object buildTemplateStructure(Template template) { + if (template.isStatic()) { + return template.expectLiteral(); + } + + List parts = new ArrayList<>(); + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Literal) { + parts.add(((Template.Literal) part).getValue()); + } else if (part instanceof Template.Dynamic) { + parts.add(buildExpressionStructure(((Template.Dynamic) part).toExpression())); + } + } + return parts; + } + + private Object buildHeadersStructure(Map> headers) { + if (headers.isEmpty()) { + return Collections.emptyMap(); + } + + // Sort header keys for deterministic ordering + List sortedKeys = new ArrayList<>(headers.keySet()); + Collections.sort(sortedKeys); + + Map structure = new LinkedHashMap<>(); + for (String key : sortedKeys) { + List values = new ArrayList<>(); + for (Expression expr : headers.get(key)) { + values.add(buildExpressionStructure(expr)); + } + structure.put(key, values); + } + return structure; + } + + private Object buildPropertiesStructure(Map properties) { + if (properties.isEmpty()) { + return Collections.emptyMap(); + } + + // Sort property keys by their string representation for deterministic ordering + List sortedIds = new ArrayList<>(properties.keySet()); + sortedIds.sort(Comparator.comparing(Identifier::toString)); + + Map structure = new LinkedHashMap<>(); + for (Identifier id : sortedIds) { + structure.put(id.toString(), buildExpressionStructure(properties.get(id))); + } + return structure; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ResultSignature)) { + return false; + } + ResultSignature that = (ResultSignature) o; + return type.equals(that.type) + && Objects.equals(urlStructure, that.urlStructure) + && Objects.equals(headersStructure, that.headersStructure) + && Objects.equals(propertiesStructure, that.propertiesStructure); + } + + @Override + public int hashCode() { + return hashCode; + } + } + + /** + * Group of structurally similar results. + */ + private static class ResultGroup { + private final List results = new ArrayList<>(); + private Map> divergentPaths = null; + + void add(Rule result) { + results.add(result); + divergentPaths = null; // Invalidate cache + } + + Map> getDivergentPaths() { + if (divergentPaths == null) { + divergentPaths = computeDivergentByPath(); + } + return divergentPaths; + } + + int getDivergentPathCount() { + return getDivergentPaths().size(); + } + + private Map> computeDivergentByPath() { + Map> byPath = new LinkedHashMap<>(); + + // Collect references by path for URL only (not headers/properties) + for (Rule result : results) { + if (result instanceof EndpointRule) { + Endpoint ep = ((EndpointRule) result).getEndpoint(); + collectRefsByPath(ep.getUrl(), new ArrayList<>(), byPath); + } else if (result instanceof ErrorRule) { + Expression error = ((ErrorRule) result).getError(); + collectRefsByPath(error, new ArrayList<>(), byPath); + } + } + + // Remove paths with only one variable name (no divergence) + byPath.entrySet().removeIf(e -> e.getValue().size() <= 1); + + return byPath; + } + + private void collectRefsByPath(Expression expr, List path, Map> out) { + if (expr instanceof StringLiteral) { + collectTemplateRefs((StringLiteral) expr, path, out); + } else if (expr instanceof Reference) { + LocationPath p = new LocationPath(path); + out.computeIfAbsent(p, k -> new LinkedHashSet<>()).add(((Reference) expr).getName().toString()); + } else if (expr instanceof LibraryFunction) { + collectFunctionRefs((LibraryFunction) expr, path, out); + } else { + throw new UnsupportedOperationException("Unexpected URL or error type: " + expr); + } + } + + private void collectTemplateRefs(StringLiteral str, List path, Map> out) { + Template template = str.value(); + int i = 0; + for (Template.Part part : template.getParts()) { + if (part instanceof Template.Dynamic) { + List newPath = new ArrayList<>(path); + newPath.add("T"); + newPath.add(i); + collectRefsByPath(((Template.Dynamic) part).toExpression(), newPath, out); + } + i++; + } + } + + private void collectFunctionRefs(LibraryFunction fn, List path, Map> out) { + int i = 0; + for (Expression arg : fn.getArguments()) { + List newPath = new ArrayList<>(path); + newPath.add("F"); + newPath.add(fn.getName()); + newPath.add(i++); + collectRefsByPath(arg, newPath, out); + } + } + } + + /** + * Signature for node deduplication during construction. + */ private static final class NodeSignature { private final ConditionReference condition; private final CfgNode trueBranch; @@ -197,7 +732,6 @@ private static final class NodeSignature { this.condition = condition; this.trueBranch = trueBranch; this.falseBranch = falseBranch; - // Use identity hash for branches. this.hashCode = Objects.hash( condition, System.identityHashCode(trueBranch), @@ -208,11 +742,11 @@ private static final class NodeSignature { public boolean equals(Object o) { if (this == o) { return true; - } else if (!(o instanceof NodeSignature)) { + } + if (!(o instanceof NodeSignature)) { return false; } NodeSignature that = (NodeSignature) o; - // Reference equality for children return Objects.equals(condition, that.condition) && trueBranch == that.trueBranch && falseBranch == that.falseBranch; diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java index 35d23c85a0b..f6cdb96a1c2 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java @@ -13,8 +13,14 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.rulesengine.language.syntax.Identifier; @@ -32,6 +38,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.logic.ConditionReference; import software.amazon.smithy.rulesengine.logic.TestHelpers; +import software.amazon.smithy.utils.ListUtils; class CfgBuilderTest { @@ -277,4 +284,240 @@ void createConditionReferenceIgnoresNegationWithVariableBinding() { assertInstanceOf(Not.class, ref.getCondition().getFunction()); assertEquals(negatedWithBinding, ref.getCondition()); } + + @Test + void createResultPreservesHeadersAndPropertiesInSignature() { + // Create endpoints with same URL but different headers + Map> headers1 = new HashMap<>(); + headers1.put("x-custom", Collections.singletonList(Expression.of("value1"))); + + Map> headers2 = new HashMap<>(); + headers2.put("x-custom", Collections.singletonList(Expression.of("value2"))); + + Rule rule1 = EndpointRule.builder() + .endpoint(Endpoint.builder() + .url(Expression.of("https://example.com")) + .headers(headers1) + .build()); + Rule rule2 = EndpointRule.builder() + .endpoint(Endpoint.builder() + .url(Expression.of("https://example.com")) + .headers(headers2) + .build()); + + EndpointRuleSet ruleSetWithHeaders = EndpointRuleSet.builder() + .parameters(ruleSet.getParameters()) + .rules(ListUtils.of(rule1, rule2)) + .build(); + CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithHeaders); + + CfgNode node1 = convergenceBuilder.createResult(rule1); + CfgNode node2 = convergenceBuilder.createResult(rule2); + + // Different headers mean different signatures - no convergence + assertNotSame(node1, node2); + } + + @Test + void createResultWithStructurallyIdenticalEndpointsCreatesConvergenceNode() { + // Create two rules with structurally identical endpoints but different variable names + Rule rule1 = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://{region1}.example.com")); + Rule rule2 = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://{region2}.example.com")); + + // Create parameters for the variables used in the endpoints + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("region1") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .addParameter(Parameter.builder() + .name("region2") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .build(); + + EndpointRuleSet ruleSetWithEndpoints = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithEndpoints); + + CfgNode node1 = convergenceBuilder.createResult(rule1); + CfgNode node2 = convergenceBuilder.createResult(rule2); + + // Both should return the same convergence node + assertSame(node1, node2); + assertInstanceOf(ConditionNode.class, node1); + } + + @Test + void createResultDistinguishesEndpointsWithDifferentStructure() { + // Create rules with different endpoint structures + Rule rule1 = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://{region}.example.com")); + Rule rule2 = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://example.com/{region}")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("region") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .build(); + + EndpointRuleSet ruleSetWithEndpoints = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithEndpoints); + + CfgNode node1 = convergenceBuilder.createResult(rule1); + CfgNode node2 = convergenceBuilder.createResult(rule2); + + // Different structures should not converge + assertNotSame(node1, node2); + } + + @Test + void createResultWithIdenticalErrorsCreatesConvergenceNode() { + // Create structurally identical error rules with different variable references + Rule error1 = ErrorRule.builder() + .error(Expression.of("Region {r1} is not supported")); + Rule error2 = ErrorRule.builder() + .error(Expression.of("Region {r2} is not supported")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("r1") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .addParameter(Parameter.builder() + .name("r2") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .build(); + + EndpointRuleSet ruleSetWithErrors = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(error1, error2)) + .build(); + CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithErrors); + + CfgNode node1 = convergenceBuilder.createResult(error1); + CfgNode node2 = convergenceBuilder.createResult(error2); + + // Should converge to the same node + assertSame(node1, node2); + } + + @Test + void createResultHandlesComplexTemplateConvergence() { + // Create endpoints with complex templates that are structurally identical + Rule rule1 = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://{svc}.{reg}.amazonaws.com/{path}")); + Rule rule2 = EndpointRule.builder() + .endpoint(TestHelpers.endpoint("https://{service}.{region}.amazonaws.com/{resource}")); + + Parameters params = Parameters.builder() + .addParameter(Parameter.builder() + .name("svc") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .addParameter(Parameter.builder() + .name("reg") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .addParameter(Parameter.builder() + .name("path") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .addParameter(Parameter.builder() + .name("service") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .addParameter(Parameter.builder() + .name("region") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .addParameter(Parameter.builder() + .name("resource") + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()) + .build(); + + EndpointRuleSet ruleSetWithTemplates = EndpointRuleSet.builder() + .parameters(params) + .rules(ListUtils.of(rule1, rule2)) + .build(); + CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithTemplates); + + CfgNode node1 = convergenceBuilder.createResult(rule1); + CfgNode node2 = convergenceBuilder.createResult(rule2); + + // Structurally identical templates should converge + assertSame(node1, node2); + } + + @Test + void createResultDoesNotConvergeWithTooManyDivergentPaths() { + // Create many endpoints with different variable names at multiple positions + // This should exceed MAX_DIVERGENT_PATHS_FOR_CONVERGENCE (5) + List rules = new ArrayList<>(); + Parameters.Builder paramsBuilder = Parameters.builder(); + + for (int i = 0; i < 7; i++) { + rules.add(EndpointRule.builder() + .endpoint(TestHelpers.endpoint(String.format("https://{var%d}.{reg%d}.example.com", i, i)))); + paramsBuilder.addParameter(Parameter.builder() + .name("var" + i) + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()); + paramsBuilder.addParameter(Parameter.builder() + .name("reg" + i) + .type(ParameterType.STRING) + .defaultValue(Value.stringValue("a")) + .required(true) + .build()); + } + + EndpointRuleSet ruleSetWithMany = EndpointRuleSet.builder() + .parameters(paramsBuilder.build()) + .rules(rules) + .build(); + CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithMany); + + // With too many divergent paths, convergence should be skipped + CfgNode firstNode = convergenceBuilder.createResult(rules.get(0)); + CfgNode lastNode = convergenceBuilder.createResult(rules.get(rules.size() - 1)); + + // They should still be cached as the same due to interning, + // but won't have phi node convergence due to performance limits + assertSame(firstNode, lastNode); + } } From 67a02ecf00d63db41cde8c21c5695efc6289fc5f Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Thu, 14 Aug 2025 20:40:37 -0500 Subject: [PATCH 16/23] Rename bdd trait to endpointBdd --- .../RuleSetAwsBuiltInValidator.java | 6 ++--- .../language/evaluation/RuleEvaluator.java | 4 ++-- .../language/evaluation/TestEvaluator.java | 4 ++-- .../{BddTrait.java => EndpointBddTrait.java} | 22 ++++++++--------- .../rulesengine/logic/bdd/NodeReversal.java | 4 ++-- .../logic/bdd/SiftingOptimization.java | 8 +++---- .../validators/BddTraitValidator.java | 14 +++++------ .../EndpointTestsTraitValidator.java | 4 ++-- .../RuleSetAuthSchemesValidator.java | 6 ++--- .../validators/RuleSetBuiltInValidator.java | 6 ++--- .../RuleSetParamMissingDocsValidator.java | 6 ++--- .../validators/RuleSetParameterValidator.java | 4 ++-- .../validators/RuleSetTestCaseValidator.java | 6 ++--- .../validators/RuleSetUriValidator.java | 6 ++--- ...re.amazon.smithy.model.traits.TraitService | 2 +- .../META-INF/smithy/smithy.rules.smithy | 4 ++-- .../rulesengine/logic/bdd/BddTraitTest.java | 6 ++--- .../logic/bdd/NodeReversalTest.java | 4 ++-- .../logic/bdd/SiftingOptimizationTest.java | 24 +++++++++---------- .../errorfiles/bdd/bdd-invalid-base64.errors | 2 +- .../errorfiles/bdd/bdd-invalid-base64.smithy | 4 ++-- .../bdd/bdd-invalid-node-data.errors | 2 +- .../bdd/bdd-invalid-node-data.smithy | 4 ++-- .../bdd/bdd-invalid-root-reference.errors | 2 +- .../bdd/bdd-invalid-root-reference.smithy | 4 ++-- .../traits/errorfiles/bdd/bdd-valid.errors | 2 +- .../traits/errorfiles/bdd/bdd-valid.smithy | 4 ++-- .../endpoint-tests-without-ruleset.errors | 2 +- 28 files changed, 83 insertions(+), 83 deletions(-) rename smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/{BddTrait.java => EndpointBddTrait.java} (94%) diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java index d59dae89833..51f2b80ac84 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/validators/RuleSetAwsBuiltInValidator.java @@ -14,7 +14,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.aws.language.functions.AwsBuiltIns; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.utils.SetUtils; @@ -38,8 +38,8 @@ public List validate(Model model) { validateRuleSetAwsBuiltIns(events, s, trait.getEndpointRuleSet().getParameters()); } - for (ServiceShape s : model.getServiceShapesWithTrait(BddTrait.class)) { - validateRuleSetAwsBuiltIns(events, s, s.expectTrait(BddTrait.class).getParameters()); + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointBddTrait.class)) { + validateRuleSetAwsBuiltIns(events, s, s.expectTrait(EndpointBddTrait.class).getParameters()); } return events; diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index 08dfcffda68..30372140c1f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -27,7 +27,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.RuleValueVisitor; import software.amazon.smithy.rulesengine.logic.RuleBasedConditionEvaluator; import software.amazon.smithy.rulesengine.logic.bdd.Bdd; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -57,7 +57,7 @@ public static Value evaluate(EndpointRuleSet ruleset, Map par * @param args The rule-set parameter identifiers and values to evaluate the BDD against. * @return The resulting value from the final matched rule. */ - public static Value evaluate(BddTrait trait, Map args) { + public static Value evaluate(EndpointBddTrait trait, Map args) { return evaluate(trait.getBdd(), trait.getParameters(), trait.getConditions(), trait.getResults(), args); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java index 50cefa59276..4bbc696144d 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/TestEvaluator.java @@ -13,7 +13,7 @@ import software.amazon.smithy.rulesengine.language.evaluation.value.EndpointValue; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.rulesengine.language.syntax.Identifier; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestExpectation; import software.amazon.smithy.rulesengine.traits.ExpectedEndpoint; @@ -45,7 +45,7 @@ public static void evaluate(EndpointRuleSet ruleset, EndpointTestCase testCase) * @param bdd The BDD trait to be tested. * @param testCase The test case. */ - public static void evaluate(BddTrait bdd, EndpointTestCase testCase) { + public static void evaluate(EndpointBddTrait bdd, EndpointTestCase testCase) { Value result = RuleEvaluator.evaluate(bdd, createParams(testCase)); processResult(result, testCase); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java similarity index 94% rename from smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java rename to smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java index 09315363fbd..85d1f0936c5 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/BddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java @@ -34,8 +34,8 @@ /** * Trait containing a precompiled BDD with full context for endpoint resolution. */ -public final class BddTrait extends AbstractTrait implements ToSmithyBuilder { - public static final ShapeId ID = ShapeId.from("smithy.rules#bdd"); +public final class EndpointBddTrait extends AbstractTrait implements ToSmithyBuilder { + public static final ShapeId ID = ShapeId.from("smithy.rules#endpointBdd"); private static final Set ALLOWED_PROPERTIES = SetUtils.of( "parameters", @@ -50,7 +50,7 @@ public final class BddTrait extends AbstractTrait implements ToSmithyBuilder results; private final Bdd bdd; - private BddTrait(Builder builder) { + private EndpointBddTrait(Builder builder) { super(ID, builder.getSourceLocation()); this.parameters = SmithyBuilder.requiredState("parameters", builder.parameters); this.conditions = SmithyBuilder.requiredState("conditions", builder.conditions); @@ -64,7 +64,7 @@ private BddTrait(Builder builder) { * @param cfg the control flow graph to compile * @return the BddTrait containing the compiled BDD and all context */ - public static BddTrait from(Cfg cfg) { + public static EndpointBddTrait from(Cfg cfg) { BddCompiler compiler = new BddCompiler(cfg, new BddBuilder()); Bdd bdd = compiler.compile(); @@ -122,7 +122,7 @@ public Bdd getBdd() { * @param transformer Transformer used to modify the trait. * @return the updated trait. */ - public BddTrait transform(Function transformer) { + public EndpointBddTrait transform(Function transformer) { return transformer.apply(this); } @@ -164,7 +164,7 @@ protected Node createNode() { * @param node the node to parse * @return the BddTrait */ - public static BddTrait fromNode(Node node) { + public static EndpointBddTrait fromNode(Node node) { ObjectNode obj = node.expectObjectNode(); obj.warnIfAdditionalProperties(ALLOWED_PROPERTIES); Parameters params = Parameters.fromNode(obj.expectObjectMember("parameters")); @@ -181,7 +181,7 @@ public static BddTrait fromNode(Node node) { Bdd bdd = decodeBdd(nodesBase64, nodeCount, rootRef, conditions.size(), results.size()); - BddTrait trait = builder() + EndpointBddTrait trait = builder() .sourceLocation(node) .parameters(params) .conditions(conditions) @@ -251,7 +251,7 @@ public Builder toBuilder() { /** * Builder for BddTrait. */ - public static final class Builder extends AbstractTraitBuilder { + public static final class Builder extends AbstractTraitBuilder { private Parameters parameters; private List conditions; private List results; @@ -304,8 +304,8 @@ public Builder bdd(Bdd bdd) { } @Override - public BddTrait build() { - return new BddTrait(this); + public EndpointBddTrait build() { + return new EndpointBddTrait(this); } } @@ -316,7 +316,7 @@ public Provider() { @Override public Trait createTrait(ShapeId target, Node value) { - BddTrait trait = BddTrait.fromNode(value); + EndpointBddTrait trait = EndpointBddTrait.fromNode(value); trait.setNodeCache(value); return trait; } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java index ecd58249762..8200d173995 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversal.java @@ -13,12 +13,12 @@ *

    This transformation reverses the node array (except the terminal at index 0) * and updates all references throughout the BDD to maintain correctness. */ -public final class NodeReversal implements Function { +public final class NodeReversal implements Function { private static final Logger LOGGER = Logger.getLogger(NodeReversal.class.getName()); @Override - public BddTrait apply(BddTrait trait) { + public EndpointBddTrait apply(EndpointBddTrait trait) { Bdd reversedBdd = reverse(trait.getBdd()); // Only rebuild the trait if the BDD actually changed return reversedBdd == trait.getBdd() ? trait : trait.toBuilder().bdd(reversedBdd).build(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index 7fc06152a1a..525c15e6ac1 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -30,7 +30,7 @@ *

  • Granular: Fine-tuned optimization for maximum reduction
  • * */ -public final class SiftingOptimization implements Function { +public final class SiftingOptimization implements Function { private static final Logger LOGGER = Logger.getLogger(SiftingOptimization.class.getName()); // When to use a parallel stream @@ -96,7 +96,7 @@ public static Builder builder() { } @Override - public BddTrait apply(BddTrait trait) { + public EndpointBddTrait apply(EndpointBddTrait trait) { try { return doApply(trait); } finally { @@ -104,7 +104,7 @@ public BddTrait apply(BddTrait trait) { } } - private BddTrait doApply(BddTrait trait) { + private EndpointBddTrait doApply(EndpointBddTrait trait) { LOGGER.info("Starting BDD sifting optimization"); long startTime = System.currentTimeMillis(); OptimizationState state = initializeOptimization(trait); @@ -135,7 +135,7 @@ private BddTrait doApply(BddTrait trait) { return trait.toBuilder().conditions(state.orderView).results(state.results).bdd(state.bestBdd).build(); } - private OptimizationState initializeOptimization(BddTrait trait) { + private OptimizationState initializeOptimization(EndpointBddTrait trait) { // Use the trait's existing ordering as the starting point List initialOrder = new ArrayList<>(trait.getConditions()); Condition[] order = initialOrder.toArray(new Condition[0]); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java index d6c688b51d9..40c525f7553 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/BddTraitValidator.java @@ -12,24 +12,24 @@ import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.logic.bdd.Bdd; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; public final class BddTraitValidator extends AbstractValidator { @Override public List validate(Model model) { - if (!model.isTraitApplied(BddTrait.class)) { + if (!model.isTraitApplied(EndpointBddTrait.class)) { return Collections.emptyList(); } List events = new ArrayList<>(); - for (ServiceShape service : model.getServiceShapesWithTrait(BddTrait.class)) { - validateService(events, service, service.expectTrait(BddTrait.class)); + for (ServiceShape service : model.getServiceShapesWithTrait(EndpointBddTrait.class)) { + validateService(events, service, service.expectTrait(EndpointBddTrait.class)); } return events; } - private void validateService(List events, ServiceShape service, BddTrait trait) { + private void validateService(List events, ServiceShape service, EndpointBddTrait trait) { Bdd bdd = trait.getBdd(); // Validate root reference @@ -87,11 +87,11 @@ private void validateService(List events, ServiceShape service, private void validateReference( List events, ServiceShape service, - BddTrait trait, + EndpointBddTrait trait, String context, int ref, Bdd bdd, - BddTrait bddTrait + EndpointBddTrait bddTrait ) { if (ref == 0) { events.add(error(service, trait, String.format("%s has invalid reference: 0", context))); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java index 99b6f9a983a..2cb62ecd2cc 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/EndpointTestsTraitValidator.java @@ -20,7 +20,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput; @@ -52,7 +52,7 @@ public List validate(Model model) { operationNameMap); }); - serviceShape.getTrait(BddTrait.class).ifPresent(trait -> { + serviceShape.getTrait(EndpointBddTrait.class).ifPresent(trait -> { validateEndpointRuleSet(events, model, serviceShape, trait.getParameters(), operationNameMap); }); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java index 5a109005413..cb6c600174a 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java @@ -21,7 +21,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; /** @@ -35,7 +35,7 @@ public List validate(Model model) { List events = new ArrayList<>(); for (ServiceShape serviceShape : model.getServiceShapes()) { visitRuleset(events, serviceShape, serviceShape.getTrait(EndpointRuleSetTrait.class).orElse(null)); - visitBdd(events, serviceShape, serviceShape.getTrait(BddTrait.class).orElse(null)); + visitBdd(events, serviceShape, serviceShape.getTrait(EndpointBddTrait.class).orElse(null)); } return events; } @@ -48,7 +48,7 @@ private void visitRuleset(List events, ServiceShape serviceShap } } - private void visitBdd(List events, ServiceShape serviceShape, BddTrait trait) { + private void visitBdd(List events, ServiceShape serviceShape, EndpointBddTrait trait) { if (trait != null) { for (Rule result : trait.getResults()) { if (result instanceof EndpointRule) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java index b0e75424ede..3b552dddfac 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetBuiltInValidator.java @@ -15,7 +15,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestOperationInput; @@ -29,8 +29,8 @@ public final class RuleSetBuiltInValidator extends AbstractValidator { public List validate(Model model) { List events = new ArrayList<>(); - for (ServiceShape s : model.getServiceShapesWithTrait(BddTrait.class)) { - validateParams(events, s, s.expectTrait(BddTrait.class).getParameters()); + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointBddTrait.class)) { + validateParams(events, s, s.expectTrait(EndpointBddTrait.class).getParameters()); } for (ServiceShape s : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java index 333b0893d03..e0115a00ee9 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParamMissingDocsValidator.java @@ -11,7 +11,7 @@ import software.amazon.smithy.model.validation.AbstractValidator; import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; /** @@ -23,7 +23,7 @@ public List validate(Model model) { List events = new ArrayList<>(); for (ServiceShape serviceShape : model.getServiceShapes()) { visitRuleset(events, serviceShape, serviceShape.getTrait(EndpointRuleSetTrait.class).orElse(null)); - visitBdd(events, serviceShape, serviceShape.getTrait(BddTrait.class).orElse(null)); + visitBdd(events, serviceShape, serviceShape.getTrait(EndpointBddTrait.class).orElse(null)); } return events; } @@ -34,7 +34,7 @@ private void visitRuleset(List events, ServiceShape serviceShap } } - private void visitBdd(List events, ServiceShape serviceShape, BddTrait trait) { + private void visitBdd(List events, ServiceShape serviceShape, EndpointBddTrait trait) { if (trait != null) { visitParams(events, serviceShape, trait.getParameters()); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java index 0e6268fd0fe..6484e43e090 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetParameterValidator.java @@ -25,7 +25,7 @@ import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.ClientContextParamDefinition; import software.amazon.smithy.rulesengine.traits.ClientContextParamsTrait; import software.amazon.smithy.rulesengine.traits.ContextParamTrait; @@ -49,7 +49,7 @@ public List validate(Model model) { for (ServiceShape service : model.getServiceShapes()) { EndpointRuleSetTrait epTrait = service.getTrait(EndpointRuleSetTrait.class).orElse(null); - BddTrait bddTrait = service.getTrait(BddTrait.class).orElse(null); + EndpointBddTrait bddTrait = service.getTrait(EndpointBddTrait.class).orElse(null); if (epTrait != null) { validate(model, topDownIndex, service, errors, epTrait, epTrait.getEndpointRuleSet().getParameters()); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java index ae62a0322c4..5ce90fb4735 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetTestCaseValidator.java @@ -12,7 +12,7 @@ import software.amazon.smithy.model.validation.ValidationEvent; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.rulesengine.traits.EndpointTestCase; import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; @@ -28,7 +28,7 @@ public List validate(Model model) { EndpointTestsTrait testsTrait = serviceShape.expectTrait(EndpointTestsTrait.class); if (serviceShape.hasTrait(EndpointRuleSetTrait.class)) { validate(serviceShape, testsTrait, events); - } else if (serviceShape.hasTrait(BddTrait.class)) { + } else if (serviceShape.hasTrait(EndpointBddTrait.class)) { validateBdd(serviceShape, testsTrait, events); } } @@ -49,7 +49,7 @@ private void validate(ServiceShape serviceShape, EndpointTestsTrait testsTrait, } private void validateBdd(ServiceShape serviceShape, EndpointTestsTrait testsTrait, List events) { - BddTrait trait = serviceShape.expectTrait(BddTrait.class); + EndpointBddTrait trait = serviceShape.expectTrait(EndpointBddTrait.class); for (EndpointTestCase endpointTestCase : testsTrait.getTestCases()) { try { TestEvaluator.evaluate(trait, endpointTestCase); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java index b5da5ff896e..c2dd43941da 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetUriValidator.java @@ -17,7 +17,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; -import software.amazon.smithy.rulesengine.logic.bdd.BddTrait; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -31,7 +31,7 @@ public List validate(Model model) { List events = new ArrayList<>(); for (ServiceShape serviceShape : model.getServiceShapes()) { visitRuleset(events, serviceShape, serviceShape.getTrait(EndpointRuleSetTrait.class).orElse(null)); - visitBdd(events, serviceShape, serviceShape.getTrait(BddTrait.class).orElse(null)); + visitBdd(events, serviceShape, serviceShape.getTrait(EndpointBddTrait.class).orElse(null)); } return events; } @@ -44,7 +44,7 @@ private void visitRuleset(List events, ServiceShape serviceShap } } - private void visitBdd(List events, ServiceShape serviceShape, BddTrait trait) { + private void visitBdd(List events, ServiceShape serviceShape, EndpointBddTrait trait) { if (trait != null) { for (Rule result : trait.getResults()) { if (result instanceof EndpointRule) { diff --git a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService index 354eade0bdf..f6ad2ccde3f 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService +++ b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.traits.TraitService @@ -4,4 +4,4 @@ software.amazon.smithy.rulesengine.traits.StaticContextParamsTrait$Provider software.amazon.smithy.rulesengine.traits.OperationContextParamsTrait$Provider software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait$Provider software.amazon.smithy.rulesengine.traits.EndpointTestsTrait$Provider -software.amazon.smithy.rulesengine.logic.bdd.BddTrait$Provider +software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait$Provider diff --git a/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy b/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy index 69fd058c0c1..e40049e5525 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy +++ b/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy @@ -10,7 +10,7 @@ document endpointRuleSet /// Defines an endpoint rule-set using a binary decision diagram (BDD). @unstable @trait(selector: "service") -structure bdd { +structure endpointBdd { /// A map of zero or more endpoint parameter names to their parameter configuration. @required parameters: Parameters @@ -206,7 +206,7 @@ map EndpointObjectHeaders { /// Defines endpoint test-cases for validating a client's endpoint rule-set. @unstable -@trait(selector: "service :is([trait|smithy.rules#endpointRuleSet], [trait|smithy.rules#bdd])") +@trait(selector: "service :is([trait|smithy.rules#endpointRuleSet], [trait|smithy.rules#endpointBdd])") structure endpointTests { /// The endpoint tests schema version. @required diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java index fe004d1155d..371f08643a5 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/BddTraitTest.java @@ -34,7 +34,7 @@ void testBddTraitSerialization() { Bdd bdd = createSimpleBdd(); - BddTrait original = BddTrait.builder() + EndpointBddTrait original = EndpointBddTrait.builder() .parameters(params) .conditions(ListUtils.of(cond)) .results(results) @@ -56,7 +56,7 @@ void testBddTraitSerialization() { assertEquals(1, serializedResultCount); // Deserialize from Node - BddTrait restored = BddTrait.fromNode(node); + EndpointBddTrait restored = EndpointBddTrait.fromNode(node); assertEquals(original.getParameters(), restored.getParameters()); assertEquals(original.getConditions().size(), restored.getConditions().size()); @@ -87,7 +87,7 @@ void testEmptyBddTrait() { int[] nodes = new int[] {-1, 1, -1}; Bdd bdd = new Bdd(-1, 0, 1, 1, nodes); - BddTrait trait = BddTrait.builder() + EndpointBddTrait trait = EndpointBddTrait.builder() .parameters(params) .conditions(ListUtils.of()) .results(ListUtils.of(NoMatchRule.INSTANCE)) diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java index b9d05b1f15e..8eda72256ff 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/NodeReversalTest.java @@ -188,14 +188,14 @@ void testBddTraitReversalReturnsOriginalForSmallBdd() { consumer.accept(0, 1, -1); // node 1: simple condition }); - BddTrait originalTrait = BddTrait.builder() + EndpointBddTrait originalTrait = EndpointBddTrait.builder() .parameters(Parameters.builder().build()) .conditions(new ArrayList<>()) .results(Collections.singletonList(NoMatchRule.INSTANCE)) .bdd(bdd) .build(); - BddTrait reversedTrait = reversal.apply(originalTrait); + EndpointBddTrait reversedTrait = reversal.apply(originalTrait); // Should return the exact same trait object for small BDDs assertSame(originalTrait, reversedTrait); diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java index 005985861a8..9edbe82187d 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimizationTest.java @@ -45,10 +45,10 @@ void testBasicOptimization() { .build(); Cfg cfg = Cfg.from(ruleSet); - BddTrait originalTrait = BddTrait.from(cfg); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); - BddTrait optimizedTrait = optimizer.apply(originalTrait); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); // Basic checks assertEquals(originalTrait.getConditions().size(), optimizedTrait.getConditions().size()); @@ -88,10 +88,10 @@ void testDependenciesPreserved() { .build(); Cfg cfg = Cfg.from(ruleSet); - BddTrait originalTrait = BddTrait.from(cfg); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); - BddTrait optimizedTrait = optimizer.apply(originalTrait); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); // Verify the optimizer preserved the number of conditions assertEquals(originalTrait.getConditions().size(), optimizedTrait.getConditions().size()); @@ -118,10 +118,10 @@ void testSingleCondition() { EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(rule).build(); Cfg cfg = Cfg.from(ruleSet); - BddTrait originalTrait = BddTrait.from(cfg); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); - BddTrait optimizedTrait = optimizer.apply(originalTrait); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); // Should be unchanged or very similar assertEquals(originalTrait.getBdd().getNodeCount(), optimizedTrait.getBdd().getNodeCount()); @@ -136,10 +136,10 @@ void testEmptyRuleSet() { EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).build(); Cfg cfg = Cfg.from(ruleSet); - BddTrait originalTrait = BddTrait.from(cfg); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder().cfg(cfg).build(); - BddTrait optimizedTrait = optimizer.apply(originalTrait); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); assertEquals(0, optimizedTrait.getBdd().getConditionCount()); assertEquals(originalTrait.getBdd().getResultCount(), optimizedTrait.getBdd().getResultCount()); @@ -183,13 +183,13 @@ void testLargeReduction() { .build(); Cfg cfg = Cfg.from(ruleSet); - BddTrait originalTrait = BddTrait.from(cfg); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder() .cfg(cfg) .granularEffort(100_000, 10) // Allow more aggressive optimization .build(); - BddTrait optimizedTrait = optimizer.apply(originalTrait); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); // Should maintain correctness assertEquals(originalTrait.getConditions().size(), optimizedTrait.getConditions().size()); @@ -214,14 +214,14 @@ void testNoImprovementReturnsOriginal() { EndpointRuleSet ruleSet = EndpointRuleSet.builder().parameters(params).addRule(rule).build(); Cfg cfg = Cfg.from(ruleSet); - BddTrait originalTrait = BddTrait.from(cfg); + EndpointBddTrait originalTrait = EndpointBddTrait.from(cfg); SiftingOptimization optimizer = SiftingOptimization.builder() .cfg(cfg) .coarseEffort(1, 1) // Minimal effort to likely find no improvement .build(); - BddTrait optimizedTrait = optimizer.apply(originalTrait); + EndpointBddTrait optimizedTrait = optimizer.apply(originalTrait); // For simple cases with minimal optimization effort, should return the same trait object if (optimizedTrait.getBdd().getNodeCount() == originalTrait.getBdd().getNodeCount()) { diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors index 5ead7e2607a..aa03a54143f 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.errors @@ -1 +1 @@ -[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#bdd`: Input byte array has wrong 4-byte ending unit | Model +[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#endpointBdd`: Input byte array has wrong 4-byte ending unit | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy index e99bbd51f0f..dc2a66b9cbd 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy @@ -2,9 +2,9 @@ $version: "2.0" namespace smithy.example -use smithy.rules#bdd +use smithy.rules#endpointBdd -@bdd({ +@endpointBdd({ parameters: { Region: { type: "string" diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors index 13b553277b7..7ef5d590fd7 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.errors @@ -1 +1 @@ -[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#bdd`: Expected 36 bytes for 3 nodes, but got 2 | Model +[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#endpointBdd`: Expected 36 bytes for 3 nodes, but got 2 | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy index c89c0d4967e..0b558f3f93c 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy @@ -2,9 +2,9 @@ $version: "2.0" namespace smithy.example -use smithy.rules#bdd +use smithy.rules#endpointBdd -@bdd({ +@endpointBdd({ parameters: { Region: { type: "string" diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors index 2cdf5489406..8703a938001 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.errors @@ -1 +1 @@ -[ERROR] smithy.example#InvalidRootRefService: Error creating trait `smithy.rules#bdd`: Root reference cannot be complemented: -5 | Model +[ERROR] smithy.example#InvalidRootRefService: Error creating trait `smithy.rules#endpointBdd`: Root reference cannot be complemented: -5 | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy index 5479643975c..4474897a55d 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy @@ -2,9 +2,9 @@ $version: "2.0" namespace smithy.example -use smithy.rules#bdd +use smithy.rules#endpointBdd -@bdd({ +@endpointBdd({ parameters: {} conditions: [] results: [] diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors index b72ae3c7545..19ee0420061 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.errors @@ -1,2 +1,2 @@ -[WARNING] smithy.example#ValidBddService: This shape applies a trait that is unstable: smithy.rules#bdd | UnstableTrait.smithy.rules#bdd +[WARNING] smithy.example#ValidBddService: This shape applies a trait that is unstable: smithy.rules#endpointBdd | UnstableTrait.smithy.rules#endpointBdd [WARNING] smithy.example#ValidBddService: This shape applies a trait that is unstable: smithy.rules#clientContextParams | UnstableTrait.smithy.rules#clientContextParams diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy index 43f221e8014..032117bbc36 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy @@ -2,14 +2,14 @@ $version: "2.0" namespace smithy.example -use smithy.rules#bdd use smithy.rules#clientContextParams +use smithy.rules#endpointBdd @clientContextParams( Region: {type: "string", documentation: "docs"} UseFips: {type: "boolean", documentation: "docs"} ) -@bdd({ +@endpointBdd({ "parameters": { "Region": { "required": true, diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors index b0779649d6f..788659c192c 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/endpoint-tests-without-ruleset.errors @@ -1 +1 @@ -[ERROR] smithy.example#InvalidService: Trait `smithy.rules#endpointTests` cannot be applied to `smithy.example#InvalidService`. This trait may only be applied to shapes that match the following selector: service :is([trait|smithy.rules#endpointRuleSet], [trait|smithy.rules#bdd]) | TraitTarget +[ERROR] smithy.example#InvalidService: Trait `smithy.rules#endpointTests` cannot be applied to `smithy.example#InvalidService`. This trait may only be applied to shapes that match the following selector: service :is([trait|smithy.rules#endpointRuleSet], [trait|smithy.rules#endpointBdd]) | TraitTarget From fc6d936537433b2c7e6e9fa38c3135eaaf4a897f Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Thu, 14 Aug 2025 21:05:47 -0500 Subject: [PATCH 17/23] Add version to bdd trait and syntax elements We can now track what version an endpointBdd trait uses and attach minimum version requirements to all syntax elements of the rules engine, include functions. Now the coalesce function is available since version 1.1. Next, I'll add validation to ensure the version requirements of every syntax element of an endpointRuleSet or endpointBdd meet the version of the trait. --- .../language/syntax/SyntaxElement.java | 9 +++ .../expressions/functions/Coalesce.java | 7 ++ .../logic/bdd/EndpointBddTrait.java | 42 ++++++++++++ .../smithy/rulesengine/logic/cfg/Cfg.java | 24 +++++-- .../rulesengine/logic/cfg/CfgBuilder.java | 14 ++++ .../META-INF/smithy/smithy.rules.smithy | 24 ++----- .../errorfiles/bdd/bdd-invalid-base64.smithy | 1 + .../bdd/bdd-invalid-node-data.smithy | 1 + .../bdd/bdd-invalid-root-reference.smithy | 1 + .../traits/errorfiles/bdd/bdd-valid.smithy | 1 + .../errorfiles/bdd/illegal-version.errors | 1 + .../errorfiles/bdd/illegal-version.smithy | 64 +++++++++++++++++++ 12 files changed, 168 insertions(+), 21 deletions(-) create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.smithy diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java index de0a0f623eb..2a75580944f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java @@ -22,6 +22,15 @@ */ @SmithyInternalApi public abstract class SyntaxElement implements ToCondition, ToExpression { + /** + * Get the rules engine version that this syntax element is available since. + * + * @return the version this is available since. + */ + public String availableSince() { + return "1.0"; + } + /** * Returns a BooleanEquals expression comparing this expression to the provided boolean value. * diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java index 3e8061dcb44..0042ce6f95b 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -37,6 +37,8 @@ * *

    Supports chaining: * {@code coalesce(opt1, coalesce(opt2, coalesce(opt3, default)))} + * + *

    Available since: rules engine 1.1. */ @SmithyUnstableApi public final class Coalesce extends LibraryFunction { @@ -67,6 +69,11 @@ public static Coalesce ofExpressions(ToExpression arg1, ToExpression arg2) { return DEFINITION.createFunction(FunctionNode.ofExpressions(ID, arg1, arg2)); } + @Override + public String availableSince() { + return "1.1"; + } + @Override public R accept(ExpressionVisitor visitor) { List args = getArguments(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java index 85d1f0936c5..3b114785566 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java @@ -8,6 +8,7 @@ import java.io.DataOutputStream; import java.io.IOException; import java.io.UncheckedIOException; +import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; @@ -37,7 +38,9 @@ public final class EndpointBddTrait extends AbstractTrait implements ToSmithyBuilder { public static final ShapeId ID = ShapeId.from("smithy.rules#endpointBdd"); + private static final BigDecimal MIN_VERSION = new BigDecimal("1.1"); private static final Set ALLOWED_PROPERTIES = SetUtils.of( + "version", "parameters", "conditions", "results", @@ -45,6 +48,7 @@ public final class EndpointBddTrait extends AbstractTrait implements ToSmithyBui "nodes", "nodeCount"); + private final String version; private final Parameters parameters; private final List conditions; private final List results; @@ -52,10 +56,16 @@ public final class EndpointBddTrait extends AbstractTrait implements ToSmithyBui private EndpointBddTrait(Builder builder) { super(ID, builder.getSourceLocation()); + this.version = SmithyBuilder.requiredState("version", builder.version); this.parameters = SmithyBuilder.requiredState("parameters", builder.parameters); this.conditions = SmithyBuilder.requiredState("conditions", builder.conditions); this.results = SmithyBuilder.requiredState("results", builder.results); this.bdd = SmithyBuilder.requiredState("bdd", builder.bdd); + + BigDecimal v = new BigDecimal(version); + if (v.compareTo(MIN_VERSION) < 0) { + throw new IllegalArgumentException("Rules engine version for endpointBdd trait must be >= " + MIN_VERSION); + } } /** @@ -72,7 +82,14 @@ public static EndpointBddTrait from(Cfg cfg) { throw new IllegalStateException("Mismatch between BDD var count and orderedConditions size"); } + // Automatically convert 1.0 versions of the decision tree to 1.1 for the minimum version of the BDD trait. + String version = cfg.getVersion(); + if (version.equals("1.0")) { + version = "1.1"; + } + return builder() + .version(version) .parameters(cfg.getParameters()) .conditions(compiler.getOrderedConditions()) .results(compiler.getIndexedResults()) @@ -116,6 +133,15 @@ public Bdd getBdd() { return bdd; } + /** + * Get the endpoint ruleset version. + * + * @return the rules engine version + */ + public String getVersion() { + return version; + } + /** * Transform this BDD using the given function and return the updated BddTrait. * @@ -129,6 +155,7 @@ public EndpointBddTrait transform(Function t @Override protected Node createNode() { ObjectNode.Builder builder = ObjectNode.builder(); + builder.withMember("version", version); builder.withMember("parameters", parameters.toNode()); ArrayNode.Builder conditionBuilder = ArrayNode.builder(); @@ -167,6 +194,7 @@ protected Node createNode() { public static EndpointBddTrait fromNode(Node node) { ObjectNode obj = node.expectObjectNode(); obj.warnIfAdditionalProperties(ALLOWED_PROPERTIES); + String version = obj.expectStringMember("version").getValue(); Parameters params = Parameters.fromNode(obj.expectObjectMember("parameters")); List conditions = obj.expectArrayMember("conditions").getElementsAs(Condition::fromNode); @@ -182,6 +210,7 @@ public static EndpointBddTrait fromNode(Node node) { Bdd bdd = decodeBdd(nodesBase64, nodeCount, rootRef, conditions.size(), results.size()); EndpointBddTrait trait = builder() + .version(version) .sourceLocation(node) .parameters(params) .conditions(conditions) @@ -241,6 +270,7 @@ public static Builder builder() { @Override public Builder toBuilder() { return builder() + .version(version) .sourceLocation(getSourceLocation()) .parameters(parameters) .conditions(conditions) @@ -252,6 +282,7 @@ public Builder toBuilder() { * Builder for BddTrait. */ public static final class Builder extends AbstractTraitBuilder { + private String version = "1.1"; private Parameters parameters; private List conditions; private List results; @@ -259,6 +290,17 @@ public static final class Builder extends AbstractTraitBuilder { // Lazily computed condition data private Condition[] conditions; private Map conditionToIndex; + private final String version; Cfg(EndpointRuleSet ruleSet, CfgNode root) { - this(ruleSet == null ? Parameters.builder().build() : ruleSet.getParameters(), root); + this( + ruleSet == null ? Parameters.builder().build() : ruleSet.getParameters(), + root, + ruleSet == null ? "1.1" : ruleSet.getVersion()); } - Cfg(Parameters parameters, CfgNode root) { + Cfg(Parameters parameters, CfgNode root, String version) { this.root = SmithyBuilder.requiredState("root", root); + this.version = version; this.parameters = parameters; } @@ -69,6 +75,15 @@ public static Cfg from(EndpointRuleSet ruleSet) { return builder.build(root); } + /** + * Get the endpoint ruleset version of the CFG. + * + * @return endpoint ruleset version. + */ + public String getVersion() { + return version; + } + /** * Gets all unique conditions in the CFG, in the order they were discovered. * @@ -141,13 +156,14 @@ public boolean equals(Object object) { } else if (object == null || getClass() != object.getClass()) { return false; } else { - return root.equals(((Cfg) object).root); + Cfg o = (Cfg) object; + return root.equals(o.root) && version.equals(o.version); } } @Override public int hashCode() { - return root.hashCode(); + return Objects.hash(root, version); } /** diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java index 503cbf64e9f..f28c5dbcacb 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -66,9 +66,12 @@ public final class CfgBuilder { private int convergenceNodesCreated = 0; private int phiVariableCounter = 0; + private String version = "1.1"; + public CfgBuilder(EndpointRuleSet ruleSet) { // Apply SSA transform to ensure globally unique variable names this.ruleSet = SsaTransform.transform(ruleSet); + this.version = ruleSet.getVersion(); // Analyze results and create convergence nodes analyzeAndCreateConvergenceNodes(); @@ -84,6 +87,17 @@ public Cfg build(CfgNode root) { return new Cfg(ruleSet, Objects.requireNonNull(root)); } + /** + * Set the version of the endpoint rules engine (e.g., 1.1). + * + * @param version Version to set. + * @return the builder; + */ + public CfgBuilder version(String version) { + this.version = Objects.requireNonNull(version); + return this; + } + /** * Creates a condition node, reusing existing nodes when possible. * diff --git a/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy b/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy index e40049e5525..09dc195fe4f 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy +++ b/smithy-rules-engine/src/main/resources/META-INF/smithy/smithy.rules.smithy @@ -11,6 +11,10 @@ document endpointRuleSet @unstable @trait(selector: "service") structure endpointBdd { + /// The rules engine version. Must be set to 1.1 or higher. + @required + version: String + /// A map of zero or more endpoint parameter names to their parameter configuration. @required parameters: Parameters @@ -35,15 +39,12 @@ structure endpointBdd { /// Base64-encoded array of BDD nodes representing the decision graph structure. /// + /// All integers are encoded in big-endian. + /// /// The first node (index 0) is always the terminal node `[-1, 1, -1]` and is included in the nodeCount. /// User-defined nodes start at index 1. /// - /// Zig-zag encoding transforms signed integers to unsigned: - /// - 0 -> 0, -1 → 1, 1 → 2, -2 → 3, 2 → 4, etc. - /// - Formula: `(n << 1) ^ (n >> 31)` - /// - This ensures small negative numbers use few bytes - /// - /// Each node consists of three varint-encoded integers written sequentially: + /// Each node is written one after the other and consists of three integers written sequentially: /// 1. variable index /// 2. high reference (when condition is true) /// 3. low reference (when condition is false) @@ -69,17 +70,6 @@ structure endpointBdd { /// boolean function and its complement; instead of creating separate nodes for `condition AND other` and /// `NOT(condition AND other)`, we can reuse the same nodes with complement edges. Complement edges cannot be /// used on result terminals. - /// - /// Example (before encoding): - /// ``` - /// nodes = [ - /// [ -1, 1, -1], // 0: terminal node - /// [ 0, 3, 2], // 1: if condition[0] then node 3, else node 2 - /// [ 1, 2000001, -1], // 2: if condition[1] then result[1], else FALSE - /// ] - /// ``` - /// - /// After zig-zag + varint + base64: `"AQEBAAYEBAGBwOgPAQ=="` @required nodes: String } diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy index dc2a66b9cbd..7ce9f68fabd 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-base64.smithy @@ -5,6 +5,7 @@ namespace smithy.example use smithy.rules#endpointBdd @endpointBdd({ + version: "1.1" parameters: { Region: { type: "string" diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy index 0b558f3f93c..b3e27e3b824 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-node-data.smithy @@ -5,6 +5,7 @@ namespace smithy.example use smithy.rules#endpointBdd @endpointBdd({ + version: "1.1" parameters: { Region: { type: "string" diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy index 4474897a55d..2976c0be659 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-invalid-root-reference.smithy @@ -5,6 +5,7 @@ namespace smithy.example use smithy.rules#endpointBdd @endpointBdd({ + version: "1.1" parameters: {} conditions: [] results: [] diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy index 032117bbc36..86d695be02c 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/bdd-valid.smithy @@ -10,6 +10,7 @@ use smithy.rules#endpointBdd UseFips: {type: "boolean", documentation: "docs"} ) @endpointBdd({ + version: "1.1" "parameters": { "Region": { "required": true, diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.errors new file mode 100644 index 00000000000..e18a498effe --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.errors @@ -0,0 +1 @@ +[ERROR] smithy.example#ValidBddService: Error creating trait `smithy.rules#endpointBdd`: Rules engine version for endpointBdd trait must be >= 1.1 | Model diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.smithy new file mode 100644 index 00000000000..f37ef99b450 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/traits/errorfiles/bdd/illegal-version.smithy @@ -0,0 +1,64 @@ +$version: "2.0" + +namespace smithy.example + +use smithy.rules#clientContextParams +use smithy.rules#endpointBdd + +@clientContextParams( + Region: {type: "string", documentation: "docs"} + UseFips: {type: "boolean", documentation: "docs"} +) +@endpointBdd({ + version: "1.0" + "parameters": { + "Region": { + "required": true, + "documentation": "The AWS region", + "type": "string" + }, + "UseFips": { + "required": true, + "default": false, + "documentation": "Use FIPS endpoints", + "type": "boolean" + } + }, + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + { + "ref": "UseFips" + }, + true + ] + } + ], + "results": [ + { + "conditions": [], + "endpoint": { + "url": "https://service-fips.{Region}.amazonaws.com", + "properties": {}, + "headers": {} + }, + "type": "endpoint" + }, + { + "conditions": [], + "endpoint": { + "url": "https://service.{Region}.amazonaws.com", + "properties": {}, + "headers": {} + }, + "type": "endpoint" + } + ], + "root": 2, + "nodeCount": 2, + "nodes": "/////wAAAAH/////AAAAAAX14QEF9eEC" +}) +service ValidBddService { + version: "2022-01-01" +} From 5bbd38bf8aee48f828d167c8e709b09769f6aca6 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Fri, 15 Aug 2025 15:27:10 -0500 Subject: [PATCH 18/23] Simplify and document coalesce --- .../rules-engine/standard-library.rst | 48 +++++++++++ .../expressions/functions/Coalesce.java | 80 ++++--------------- 2 files changed, 65 insertions(+), 63 deletions(-) diff --git a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst index f628d689286..1b0d3fa7da3 100644 --- a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst +++ b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst @@ -38,6 +38,54 @@ parameter is equal to the value ``false``: } +.. _rules-engine-standard-library-coalesce: + +``coalesce`` function +===================== + +Summary + Evaluates the first argument and returns the result if it is present, otherwise evaluates and returns the result + of the second argument. +Argument types + * value1: ``T`` or ``option`` + * value2: ``T`` or ``option`` +Return type + * ``coalesce(T, T)`` → ``T`` + * ``coalesce(option, T)`` → ``T`` + * ``coalesce(T, option)`` → ``T`` + * ``coalesce(option, option)`` → ``option`` +Since + 1.1 + +The ``coalesce`` function provides null-safe chaining by returning the result of the first argument if it returns a +value, otherwise returns the result of the second argument. This is particularly useful for providing default values +for optional parameters, chaining multiple optional values together, and related optimizations. + +The following example demonstrates chaining multiple ``coalesce`` calls to try several optional values +in sequence: + +.. code-block:: json + + { + "fn": "coalesce", + "argv": [ + {"ref": "customEndpoint"}, + { + "fn": "coalesce", + "argv": [ + {"ref": "regionalEndpoint"}, + {"ref": "defaultEndpoint"} + ] + } + ] + } + +.. important:: + Both arguments must be of the same type after unwrapping any optionals (types are known at compile time and do not + need to be validated at runtime). Note that the first result is returned even if it's ``false`` (coalesce is + looking for a *non-empty* value). + + .. _rules-engine-standard-library-getAttr: ``getAttr`` function diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java index 0042ce6f95b..af68b21207a 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -7,7 +7,6 @@ import java.util.Arrays; import java.util.List; import software.amazon.smithy.rulesengine.language.evaluation.Scope; -import software.amazon.smithy.rulesengine.language.evaluation.type.AnyType; import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; @@ -17,24 +16,17 @@ import software.amazon.smithy.utils.SmithyUnstableApi; /** - * A coalesce function that returns the first non-empty value, with type-safe fallback handling. + * A coalesce function that returns the first non-empty value. * At runtime, returns the left value unless it's EmptyValue, in which case returns the right value. * *

    Type checking rules: *

      *
    • {@code coalesce(T, T) => T} (same types)
    • - *
    • {@code coalesce(T, AnyType) => T} (AnyType adapts to concrete type)
    • - *
    • {@code coalesce(AnyType, T) => T} (AnyType adapts to concrete type)
    • - *
    • {@code coalesce(T, S) => S} (if T.isA(S), i.e., S is more general)
    • - *
    • {@code coalesce(T, S) => T} (if S.isA(T), i.e., T is more general)
    • - *
    • {@code coalesce(Optional, S) => common_type(T, S)} (unwraps optional)
    • - *
    • {@code coalesce(T, Optional) => common_type(T, S)} (unwraps optional)
    • - *
    • {@code coalesce(Optional, Optional) => Optional}
    • + *
    • {@code coalesce(Optional, T) => T} (unwraps optional)
    • + *
    • {@code coalesce(T, Optional) => T} (unwraps optional)
    • + *
    • {@code coalesce(Optional, Optional) => Optional}
    • *
    * - *

    Special handling for AnyType: Since AnyType can masquerade as any type, when coalescing - * with a concrete type, the concrete type is used as the result type. - * *

    Supports chaining: * {@code coalesce(opt1, coalesce(opt2, coalesce(opt3, default)))} * @@ -80,29 +72,6 @@ public R accept(ExpressionVisitor visitor) { return visitor.visitCoalesce(args.get(0), args.get(1)); } - // Type checking rules for coalesce: - // - // This function returns the first non-empty value with type-safe fallback handling. - // The type resolution follows these rules: - // - // 1. If both types are identical, use that type - // 2. Special handling for AnyType: Since AnyType.isA() always returns true (it can masquerade as any type), we - // need to handle it specially. When coalescing AnyType with a concrete type, we use the concrete type as the - // result, since AnyType can adapt to it at runtime. - // 3. For other types, we use the isA relationship to find the more general type: - // - If left.isA(right), then right is more general, use right - // - If right.isA(left), then left is more general, use left - // 4. If no type relationship exists, throw a type mismatch error - // - // The result is wrapped in Optional only if BOTH inputs are Optional, since coalesce(optional, required) - // guarantees a non-empty result. - // - // Examples: - // - coalesce(String, String) => String - // - coalesce(Optional, String) => String - // - coalesce(Optional, Optional) => Optional - // - coalesce(String, AnyType) => String (AnyType adapts) - // - coalesce(SubType, SuperType) => SuperType (more general) @Override public Type typeCheck(Scope scope) { List args = getArguments(); @@ -113,38 +82,23 @@ public Type typeCheck(Scope scope) { Type leftType = args.get(0).typeCheck(scope); Type rightType = args.get(1).typeCheck(scope); - - // Find the least upper bound (most specific common type) - Type resultType = lubForCoalesce(leftType, rightType); - - // Only return Optional if both sides can be empty - if (leftType instanceof OptionalType && rightType instanceof OptionalType) { - return Type.optionalType(resultType); + Type leftInner = getInnerType(leftType); + Type rightInner = getInnerType(rightType); + + // Both must be the same type (after unwrapping optionals) + if (!leftInner.equals(rightInner)) { + throw new IllegalArgumentException(String.format( + "Type mismatch in coalesce: %s and %s must be the same type", + leftType, + rightType)); } - return resultType; - } - - // Finds the least upper bound (LUB) for coalesce type checking. - // The LUB is the most specific type that both input types can be assigned to. - // Special handling for AnyType: it adapts to concrete types rather than dominating them. - private static Type lubForCoalesce(Type a, Type b) { - Type ai = getInnerType(a); - Type bi = getInnerType(b); - - if (ai.equals(bi)) { - return ai; - } else if (ai instanceof AnyType) { - return bi; // AnyType adapts to concrete type - } else if (bi instanceof AnyType) { - return ai; // AnyType adapts to concrete type - } else if (ai.isA(bi)) { - return bi; // bi is more general - } else if (bi.isA(ai)) { - return ai; // ai is more general + // Only return Optional if both sides are optional + if (leftType instanceof OptionalType && rightType instanceof OptionalType) { + return Type.optionalType(leftInner); } - throw new IllegalArgumentException("Type mismatch in coalesce: " + a + " and " + b + " have no common type"); + return leftInner; } private static Type getInnerType(Type t) { From 7c5ee43df69f02370d234189bcb0e2ba93ac0d6c Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Sat, 16 Aug 2025 23:56:48 -0500 Subject: [PATCH 19/23] Add version validation to rules engine This now ensures that syntax elements, functions, etc all have an availableSince version that does not exceed the version decalred on a rules engine trait. --- .../rulesengine/language/EndpointRuleSet.java | 11 ++ .../rulesengine/language/RulesVersion.java | 143 ++++++++++++++++++ .../language/syntax/SyntaxElement.java | 5 +- .../expressions/functions/Coalesce.java | 5 +- .../logic/bdd/EndpointBddTrait.java | 25 ++- .../smithy/rulesengine/logic/cfg/Cfg.java | 9 +- .../RulesEngineVersionValidator.java | 113 ++++++++++++++ ...e.amazon.smithy.model.validation.Validator | 1 + .../errorfiles/invalid/bad-version-use.errors | 3 + .../errorfiles/invalid/bad-version-use.smithy | 44 ++++++ 10 files changed, 338 insertions(+), 21 deletions(-) create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/RulesVersion.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RulesEngineVersionValidator.java create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java index 9f25f9ce426..57c1b423d22 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/EndpointRuleSet.java @@ -58,6 +58,7 @@ private static final class LazyEndpointComponentFactoryHolder { private final List rules; private final SourceLocation sourceLocation; private final String version; + private final RulesVersion rulesVersion; private EndpointRuleSet(Builder builder) { super(); @@ -65,6 +66,7 @@ private EndpointRuleSet(Builder builder) { rules = builder.rules.copy(); sourceLocation = SmithyBuilder.requiredState("source", builder.getSourceLocation()); version = SmithyBuilder.requiredState(VERSION, builder.version); + rulesVersion = RulesVersion.of(version); } /** @@ -130,6 +132,15 @@ public String getVersion() { return version; } + /** + * Get the parsed rules engine version. + * + * @return parsed version. + */ + public RulesVersion getRulesVersion() { + return rulesVersion; + } + public Type typeCheck() { return typeCheck(new Scope<>()); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/RulesVersion.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/RulesVersion.java new file mode 100644 index 00000000000..d0302067dce --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/RulesVersion.java @@ -0,0 +1,143 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language; + +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import software.amazon.smithy.utils.SmithyUnstableApi; +import software.amazon.smithy.utils.StringUtils; + +/** + * Represents the rules engine version with major and minor components. + */ +@SmithyUnstableApi +public final class RulesVersion implements Comparable { + + private static final ConcurrentHashMap CACHE = new ConcurrentHashMap<>(); + + public static final RulesVersion V1_0 = of("1.0"); + public static final RulesVersion V1_1 = of("1.1"); + + private final int major; + private final int minor; + private final String stringValue; + private final int hashCode; + + private RulesVersion(int major, int minor) { + if (major < 0 || minor < 0) { + throw new IllegalArgumentException("Version components must be non-negative"); + } + + this.major = major; + this.minor = minor; + this.stringValue = major + "." + minor; + this.hashCode = Objects.hash(major, minor); + } + + /** + * Creates a RulesVersion from a string representation. + * + * @param version the version string (e.g., "1.0", "1.2") + * @return the RulesVersion instance + * @throws IllegalArgumentException if the version string is invalid + */ + public static RulesVersion of(String version) { + return CACHE.computeIfAbsent(version, RulesVersion::parse); + } + + /** + * Creates a RulesVersion from components. + * + * @param major the major version + * @param minor the minor version + * @return the RulesVersion instance + */ + public static RulesVersion of(int major, int minor) { + String key = major + "." + minor; + return CACHE.computeIfAbsent(key, k -> new RulesVersion(major, minor)); + } + + private static RulesVersion parse(String version) { + if (StringUtils.isEmpty(version)) { + throw new IllegalArgumentException("Version string cannot be null or empty"); + } + + String[] parts = version.split("\\."); + if (parts.length < 2) { + throw new IllegalArgumentException("Invalid version: `" + version + "`. Expected format: major.minor"); + } + + try { + int major = Integer.parseInt(parts[0]); + int minor = Integer.parseInt(parts[1]); + return new RulesVersion(major, minor); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid version format: " + version, e); + } + } + + /** + * Gets the major version component. + * + * @return the major version + */ + public int getMajor() { + return major; + } + + /** + * Gets the minor version component. + * + * @return the minor version + */ + public int getMinor() { + return minor; + } + + /** + * Checks if this version is at least the specified version. + * + * @param other the version to compare against + * @return true if this version >= other + */ + public boolean isAtLeast(RulesVersion other) { + return compareTo(other) >= 0; + } + + @Override + public int compareTo(RulesVersion other) { + if (this == other) { + return 0; + } + + int result = Integer.compare(major, other.major); + if (result != 0) { + return result; + } else { + return Integer.compare(minor, other.minor); + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } else if (!(obj instanceof RulesVersion)) { + return false; + } + RulesVersion other = (RulesVersion) obj; + return major == other.major && minor == other.minor; + } + + @Override + public int hashCode() { + return hashCode; + } + + @Override + public String toString() { + return stringValue; + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java index 2a75580944f..fc798ec8301 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/SyntaxElement.java @@ -4,6 +4,7 @@ */ package software.amazon.smithy.rulesengine.language.syntax; +import software.amazon.smithy.rulesengine.language.RulesVersion; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; @@ -27,8 +28,8 @@ public abstract class SyntaxElement implements ToCondition, ToExpression { * * @return the version this is available since. */ - public String availableSince() { - return "1.0"; + public RulesVersion availableSince() { + return RulesVersion.V1_0; } /** diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java index af68b21207a..7797fe2f4be 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -6,6 +6,7 @@ import java.util.Arrays; import java.util.List; +import software.amazon.smithy.rulesengine.language.RulesVersion; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; @@ -62,8 +63,8 @@ public static Coalesce ofExpressions(ToExpression arg1, ToExpression arg2) { } @Override - public String availableSince() { - return "1.1"; + public RulesVersion availableSince() { + return RulesVersion.V1_1; } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java index 3b114785566..20dc2616867 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java @@ -8,7 +8,6 @@ import java.io.DataOutputStream; import java.io.IOException; import java.io.UncheckedIOException; -import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; @@ -23,6 +22,7 @@ import software.amazon.smithy.model.traits.AbstractTrait; import software.amazon.smithy.model.traits.AbstractTraitBuilder; import software.amazon.smithy.model.traits.Trait; +import software.amazon.smithy.rulesengine.language.RulesVersion; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; @@ -38,7 +38,7 @@ public final class EndpointBddTrait extends AbstractTrait implements ToSmithyBuilder { public static final ShapeId ID = ShapeId.from("smithy.rules#endpointBdd"); - private static final BigDecimal MIN_VERSION = new BigDecimal("1.1"); + private static final RulesVersion MIN_VERSION = RulesVersion.V1_1; private static final Set ALLOWED_PROPERTIES = SetUtils.of( "version", "parameters", @@ -48,7 +48,7 @@ public final class EndpointBddTrait extends AbstractTrait implements ToSmithyBui "nodes", "nodeCount"); - private final String version; + private final RulesVersion version; private final Parameters parameters; private final List conditions; private final List results; @@ -62,8 +62,7 @@ private EndpointBddTrait(Builder builder) { this.results = SmithyBuilder.requiredState("results", builder.results); this.bdd = SmithyBuilder.requiredState("bdd", builder.bdd); - BigDecimal v = new BigDecimal(version); - if (v.compareTo(MIN_VERSION) < 0) { + if (version.compareTo(MIN_VERSION) < 0) { throw new IllegalArgumentException("Rules engine version for endpointBdd trait must be >= " + MIN_VERSION); } } @@ -83,9 +82,9 @@ public static EndpointBddTrait from(Cfg cfg) { } // Automatically convert 1.0 versions of the decision tree to 1.1 for the minimum version of the BDD trait. - String version = cfg.getVersion(); - if (version.equals("1.0")) { - version = "1.1"; + RulesVersion version = cfg.getVersion(); + if (version.equals(RulesVersion.V1_0)) { + version = RulesVersion.V1_1; } return builder() @@ -138,7 +137,7 @@ public Bdd getBdd() { * * @return the rules engine version */ - public String getVersion() { + public RulesVersion getVersion() { return version; } @@ -155,7 +154,7 @@ public EndpointBddTrait transform(Function t @Override protected Node createNode() { ObjectNode.Builder builder = ObjectNode.builder(); - builder.withMember("version", version); + builder.withMember("version", version.toString()); builder.withMember("parameters", parameters.toNode()); ArrayNode.Builder conditionBuilder = ArrayNode.builder(); @@ -194,7 +193,7 @@ protected Node createNode() { public static EndpointBddTrait fromNode(Node node) { ObjectNode obj = node.expectObjectNode(); obj.warnIfAdditionalProperties(ALLOWED_PROPERTIES); - String version = obj.expectStringMember("version").getValue(); + RulesVersion version = RulesVersion.of(obj.expectStringMember("version").getValue()); Parameters params = Parameters.fromNode(obj.expectObjectMember("parameters")); List conditions = obj.expectArrayMember("conditions").getElementsAs(Condition::fromNode); @@ -282,7 +281,7 @@ public Builder toBuilder() { * Builder for BddTrait. */ public static final class Builder extends AbstractTraitBuilder { - private String version = "1.1"; + private RulesVersion version = RulesVersion.V1_1; private Parameters parameters; private List conditions; private List results; @@ -296,7 +295,7 @@ private Builder() {} * @param version Version to set (e.g., 1.1). * @return this builder */ - public Builder version(String version) { + public Builder version(RulesVersion version) { this.version = version; return this; } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java index 723bfb39363..3924242ef47 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/Cfg.java @@ -18,6 +18,7 @@ import java.util.Objects; import java.util.Set; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.RulesVersion; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; @@ -46,16 +47,16 @@ public final class Cfg implements Iterable { // Lazily computed condition data private Condition[] conditions; private Map conditionToIndex; - private final String version; + private final RulesVersion version; Cfg(EndpointRuleSet ruleSet, CfgNode root) { this( ruleSet == null ? Parameters.builder().build() : ruleSet.getParameters(), root, - ruleSet == null ? "1.1" : ruleSet.getVersion()); + ruleSet == null ? RulesVersion.V1_1 : ruleSet.getRulesVersion()); } - Cfg(Parameters parameters, CfgNode root, String version) { + Cfg(Parameters parameters, CfgNode root, RulesVersion version) { this.root = SmithyBuilder.requiredState("root", root); this.version = version; this.parameters = parameters; @@ -80,7 +81,7 @@ public static Cfg from(EndpointRuleSet ruleSet) { * * @return endpoint ruleset version. */ - public String getVersion() { + public RulesVersion getVersion() { return version; } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RulesEngineVersionValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RulesEngineVersionValidator.java new file mode 100644 index 00000000000..4d58c1e2417 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RulesEngineVersionValidator.java @@ -0,0 +1,113 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.validators; + +import java.util.ArrayList; +import java.util.List; +import software.amazon.smithy.model.FromSourceLocation; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.SourceLocation; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.validation.AbstractValidator; +import software.amazon.smithy.model.validation.ValidationEvent; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.syntax.SyntaxElement; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.rulesengine.logic.bdd.EndpointBddTrait; +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; + +/** + * Validates that the rules engine version of a trait only uses compatible features. + */ +public final class RulesEngineVersionValidator extends AbstractValidator { + + @Override + public List validate(Model model) { + List events = new ArrayList<>(); + + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointBddTrait.class)) { + validateBdd(events, s, s.expectTrait(EndpointBddTrait.class)); + } + + for (ServiceShape s : model.getServiceShapesWithTrait(EndpointRuleSetTrait.class)) { + validateTree(events, s, s.expectTrait(EndpointRuleSetTrait.class)); + } + + return events; + } + + private void validateBdd(List events, ServiceShape service, EndpointBddTrait trait) { + RulesVersion version = trait.getVersion(); + + for (Condition condition : trait.getConditions()) { + validateSyntaxElement(events, service, condition, version); + } + + for (Rule result : trait.getResults()) { + validateRule(events, service, result, version); + } + } + + private void validateTree(List events, ServiceShape service, EndpointRuleSetTrait trait) { + EndpointRuleSet rules = trait.getEndpointRuleSet(); + RulesVersion version = rules.getRulesVersion(); + for (Rule rule : rules.getRules()) { + validateRule(events, service, rule, version); + } + } + + private void validateRule(List events, ServiceShape service, Rule rule, RulesVersion version) { + for (Condition condition : rule.getConditions()) { + validateSyntaxElement(events, service, condition, version); + validateSyntaxElement(events, service, condition.getFunction(), version); + for (Expression arg : condition.getFunction().getArguments()) { + validateSyntaxElement(events, service, arg, version); + } + } + + if (rule instanceof TreeRule) { + for (Rule nestedRule : ((TreeRule) rule).getRules()) { + validateRule(events, service, nestedRule, version); + } + } else if (rule instanceof EndpointRule) { + EndpointRule endpointRule = (EndpointRule) rule; + validateSyntaxElement(events, service, endpointRule.getEndpoint().getUrl(), version); + for (List headerValues : endpointRule.getEndpoint().getHeaders().values()) { + for (Expression expr : headerValues) { + validateSyntaxElement(events, service, expr, version); + } + } + } else if (rule instanceof ErrorRule) { + validateSyntaxElement(events, service, ((ErrorRule) rule).getError(), version); + } + } + + private void validateSyntaxElement( + List events, + ServiceShape service, + SyntaxElement element, + RulesVersion declaredVersion + ) { + RulesVersion requiredVersion = element.availableSince(); + + if (!declaredVersion.isAtLeast(requiredVersion)) { + SourceLocation s = element instanceof FromSourceLocation + ? ((FromSourceLocation) element).getSourceLocation() + : element.toExpression().getSourceLocation(); + String msg = String.format( + "%s requires rules engine version >= %s, but ruleset declares version %s", + element.getClass().getSimpleName(), + requiredVersion, + declaredVersion); + events.add(error(service, s, msg)); + } + } +} diff --git a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator index d1243825d5b..18ef8dab55c 100644 --- a/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator +++ b/smithy-rules-engine/src/main/resources/META-INF/services/software.amazon.smithy.model.validation.Validator @@ -8,3 +8,4 @@ software.amazon.smithy.rulesengine.validators.RuleSetParamMissingDocsValidator software.amazon.smithy.rulesengine.validators.RuleSetParameterValidator software.amazon.smithy.rulesengine.validators.RuleSetTestCaseValidator software.amazon.smithy.rulesengine.validators.BddTraitValidator +software.amazon.smithy.rulesengine.validators.RulesEngineVersionValidator diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors new file mode 100644 index 00000000000..2acfc1d2cc7 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors @@ -0,0 +1,3 @@ +[WARNING] example#FizzBuzz: This shape applies a trait that is unstable: smithy.rules#clientContextParams | UnstableTrait +[WARNING] example#FizzBuzz: This shape applies a trait that is unstable: smithy.rules#endpointRuleSet | UnstableTrait +[ERROR] example#FizzBuzz: Coalesce requires rules engine version >= 1.1, but ruleset declares version 1.0 | RulesEngineVersion diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy new file mode 100644 index 00000000000..b4f85945205 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy @@ -0,0 +1,44 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet + +@clientContextParams( + foo: {type: "string", documentation: "a client string parameter"} +) +@endpointRuleSet({ + "version": "1.0", + "parameters": { + "foo": { + "type": "String", + "documentation": "docs" + } + }, + "rules": [ + { + "conditions": [ + { + "fn": "coalesce", + "argv": [ + {"ref": "foo"}, + "" + ], + "assign": "hi" + } + ], + "documentation": "base rule", + "endpoint": { + "url": "https://{hi}.amazonaws.com", + "headers": {} + }, + "type": "endpoint" + } + ] +}) +service FizzBuzz { + operations: [GetResource] +} + +operation GetResource {} From 4a8aa352ec0a875ef419102a29393a21354b3d84 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Tue, 19 Aug 2025 15:52:38 -0500 Subject: [PATCH 20/23] Use NO_MATCH rule for false nodes instead --- .../smithy/rulesengine/logic/bdd/EndpointBddTrait.java | 2 ++ .../smithy/rulesengine/logic/cfg/CfgBuilder.java | 10 ++++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java index 20dc2616867..a044a116b35 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/EndpointBddTrait.java @@ -172,6 +172,8 @@ protected Node createNode() { Rule result = results.get(i); if (result instanceof NoMatchRule) { throw new IllegalStateException("NoMatch rules can only appear at rule index 0. Found at index " + i); + } else if (result == null) { + throw new IllegalStateException("BDD result is null at index " + i); } resultBuilder.withValue(result); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java index f28c5dbcacb..f7583be6921 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -34,6 +34,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; import software.amazon.smithy.rulesengine.logic.ConditionReference; @@ -66,7 +67,7 @@ public final class CfgBuilder { private int convergenceNodesCreated = 0; private int phiVariableCounter = 0; - private String version = "1.1"; + private String version; public CfgBuilder(EndpointRuleSet ruleSet) { // Apply SSA transform to ensure globally unique variable names @@ -297,9 +298,10 @@ private CfgNode buildConvergenceNode( Condition coalesceCondition = createCoalesceCondition(phiVar, versions); ConditionReference condRef = new ConditionReference(coalesceCondition, false); - // Use terminal (no match) as false branch to prevent BDD optimization - // The coalesce always succeeds, so we never take the false branch - resultNode = new ConditionNode(condRef, resultNode, ResultNode.terminal()); + // Use NO_MATCH as false branch since coalesce should always succeed + // This ensures we get a valid result index instead of null that would result from FALSE. + CfgNode noMatch = createResult(NoMatchRule.INSTANCE); + resultNode = new ConditionNode(condRef, resultNode, noMatch); // Log with deterministic ordering LOGGER.fine(() -> { From 9986fc89a182b72795234e53e790cecc28202005 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Tue, 19 Aug 2025 17:25:52 -0500 Subject: [PATCH 21/23] Make coalesce variadic This commit adds support for variadic functions in the rules engine and makes coalesce variadic. This makes things like phi functions a shallow list of expressions rather than a massive list of nested binary expressions. This also makes it easier to optimize phi nodes without requiring peephole checks in compilers for optimization. --- .../rules-engine/standard-library.rst | 43 +++--- .../language/evaluation/RuleEvaluator.java | 12 +- .../syntax/expressions/ExpressionVisitor.java | 10 +- .../expressions/functions/Coalesce.java | 88 ++++++----- .../functions/FunctionDefinition.java | 13 ++ .../functions/LibraryFunction.java | 86 +++++++---- .../logic/bdd/InitialOrdering.java | 1 - .../rulesengine/logic/cfg/CfgBuilder.java | 33 +---- .../syntax/functions/CoalesceTest.java | 140 ++++++++++++++---- .../errorfiles/invalid/bad-version-use.errors | 2 - .../errorfiles/invalid/bad-version-use.smithy | 1 + .../valid/coalesce-three-args.errors | 0 .../valid/coalesce-three-args.smithy | 93 ++++++++++++ 13 files changed, 368 insertions(+), 154 deletions(-) create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.smithy diff --git a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst index 1b0d3fa7da3..fb27632e216 100644 --- a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst +++ b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst @@ -44,24 +44,28 @@ parameter is equal to the value ``false``: ===================== Summary - Evaluates the first argument and returns the result if it is present, otherwise evaluates and returns the result - of the second argument. + Evaluates arguments in order and returns the first non-empty result, otherwise returns the result of the last + argument. Argument types - * value1: ``T`` or ``option`` - * value2: ``T`` or ``option`` + * This function is variadic and requires two or more arguments, each of type ``T`` or ``option`` + * All arguments must have the same inner type ``T`` Return type - * ``coalesce(T, T)`` → ``T`` - * ``coalesce(option, T)`` → ``T`` - * ``coalesce(T, option)`` → ``T`` - * ``coalesce(option, option)`` → ``option`` + * ``coalesce(T, T, ...)`` → ``T`` + * ``coalesce(option, T, ...)`` → ``T`` (if any argument is non-optional) + * ``coalesce(T, option, ...)`` → ``T`` (if any argument is non-optional) + * ``coalesce(option, option, ...)`` → ``option`` (if all arguments are optional) Since 1.1 -The ``coalesce`` function provides null-safe chaining by returning the result of the first argument if it returns a -value, otherwise returns the result of the second argument. This is particularly useful for providing default values -for optional parameters, chaining multiple optional values together, and related optimizations. +The ``coalesce`` function provides null-safe chaining by evaluating arguments in order and returning the first +non-empty result. If all arguments evaluate to empty, it returns the result of the last argument. This is +particularly useful for providing default values for optional parameters, chaining multiple optional values +together, and related optimizations. -The following example demonstrates chaining multiple ``coalesce`` calls to try several optional values +The function accepts two or more arguments, all of which must have the same inner type after unwrapping any +optionals. The return type is ``option`` only if all arguments are ``option``; otherwise it returns ``T``. + +The following example demonstrates using ``coalesce`` with multiple arguments to try several optional values in sequence: .. code-block:: json @@ -70,20 +74,15 @@ in sequence: "fn": "coalesce", "argv": [ {"ref": "customEndpoint"}, - { - "fn": "coalesce", - "argv": [ - {"ref": "regionalEndpoint"}, - {"ref": "defaultEndpoint"} - ] - } + {"ref": "regionalEndpoint"}, + {"ref": "defaultEndpoint"} ] } .. important:: - Both arguments must be of the same type after unwrapping any optionals (types are known at compile time and do not - need to be validated at runtime). Note that the first result is returned even if it's ``false`` (coalesce is - looking for a *non-empty* value). + All arguments must have the same type after unwrapping any optionals (types are known at compile time and do not + need to be validated at runtime). Note that the first non-empty result is returned even if it's ``false`` + (coalesce is looking for a *non-empty* value, not a truthy value). .. _rules-engine-standard-library-getAttr: diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index 30372140c1f..a15352396c8 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -184,9 +184,15 @@ public Value visitIsSet(Expression fn) { } @Override - public Value visitCoalesce(Expression left, Expression right) { - Value leftValue = left.accept(this); - return leftValue.isEmpty() ? right.accept(this) : leftValue; + public Value visitCoalesce(List expressions) { + Value result = Value.emptyValue(); + for (Expression exp : expressions) { + result = exp.accept(this); + if (!result.isEmpty()) { + return result; + } + } + return result; } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java index 2086426a76e..1557b529b52 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java @@ -9,7 +9,6 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionDefinition; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.utils.ListUtils; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -54,12 +53,11 @@ public interface ExpressionVisitor { /** * Visits a coalesce function. * - * @param left the first value to check. - * @param right the second value to check. + * @param expressions The coalesce expressions to check. * @return the value from the visitor. */ - default R visitCoalesce(Expression left, Expression right) { - return visitLibraryFunction(Coalesce.getDefinition(), ListUtils.of(left, right)); + default R visitCoalesce(List expressions) { + return visitLibraryFunction(Coalesce.getDefinition(), expressions); } /** @@ -121,7 +119,7 @@ public R visitIsSet(Expression fn) { } @Override - public R visitCoalesce(Expression left, Expression right) { + public R visitCoalesce(List expressions) { return getDefault(); } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java index 7797fe2f4be..be91ab0d508 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -4,8 +4,9 @@ */ package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; -import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Optional; import software.amazon.smithy.rulesengine.language.RulesVersion; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; @@ -18,19 +19,17 @@ /** * A coalesce function that returns the first non-empty value. - * At runtime, returns the left value unless it's EmptyValue, in which case returns the right value. + * + *

    This variadic function requires two or more arguments. At runtime, returns the first arguments that returns a + * non-EmptyValue, otherwise returns the result of the last argument. * *

    Type checking rules: *

      - *
    • {@code coalesce(T, T) => T} (same types)
    • - *
    • {@code coalesce(Optional, T) => T} (unwraps optional)
    • - *
    • {@code coalesce(T, Optional) => T} (unwraps optional)
    • - *
    • {@code coalesce(Optional, Optional) => Optional}
    • + *
    • {@code coalesce(T, T, T) => T} (same types)
    • + *
    • {@code coalesce(Optional, T, T) => T} (any non-optional makes result non-optional)
    • + *
    • {@code coalesce(Optional, Optional, Optional) => Optional} (all optional)
    • *
    * - *

    Supports chaining: - * {@code coalesce(opt1, coalesce(opt2, coalesce(opt3, default)))} - * *

    Available since: rules engine 1.1. */ @SmithyUnstableApi @@ -52,14 +51,23 @@ public static Definition getDefinition() { } /** - * Creates a {@link Coalesce} function from the given expressions. + * Creates a {@link Coalesce} function from variadic expressions. + * + * @param args the expressions to coalesce + * @return The resulting {@link Coalesce} function. + */ + public static Coalesce ofExpressions(ToExpression... args) { + return DEFINITION.createFunction(FunctionNode.ofExpressions(ID, args)); + } + + /** + * Creates a {@link Coalesce} function from a list of expressions. * - * @param arg1 the first expression, typically optional. - * @param arg2 the second expression, used as fallback. + * @param args the expressions to coalesce * @return The resulting {@link Coalesce} function. */ - public static Coalesce ofExpressions(ToExpression arg1, ToExpression arg2) { - return DEFINITION.createFunction(FunctionNode.ofExpressions(ID, arg1, arg2)); + public static Coalesce ofExpressions(List args) { + return ofExpressions(args.toArray(new ToExpression[0])); } @Override @@ -69,37 +77,38 @@ public RulesVersion availableSince() { @Override public R accept(ExpressionVisitor visitor) { - List args = getArguments(); - return visitor.visitCoalesce(args.get(0), args.get(1)); + return visitor.visitCoalesce(getArguments()); } @Override public Type typeCheck(Scope scope) { List args = getArguments(); - - if (args.size() != 2) { - throw new IllegalArgumentException("Coalesce requires exactly 2 arguments, got " + args.size()); - } - - Type leftType = args.get(0).typeCheck(scope); - Type rightType = args.get(1).typeCheck(scope); - Type leftInner = getInnerType(leftType); - Type rightInner = getInnerType(rightType); - - // Both must be the same type (after unwrapping optionals) - if (!leftInner.equals(rightInner)) { - throw new IllegalArgumentException(String.format( - "Type mismatch in coalesce: %s and %s must be the same type", - leftType, - rightType)); + if (args.size() < 2) { + throw new IllegalArgumentException("Coalesce requires at least 2 arguments, got " + args.size()); } - // Only return Optional if both sides are optional - if (leftType instanceof OptionalType && rightType instanceof OptionalType) { - return Type.optionalType(leftInner); + // Get the first argument's type as the baseline + Type firstType = args.get(0).typeCheck(scope); + Type baseInnerType = getInnerType(firstType); + boolean hasNonOptional = !(firstType instanceof OptionalType); + + // Check all other arguments match the base type + for (int i = 1; i < args.size(); i++) { + Type argType = args.get(i).typeCheck(scope); + Type innerType = getInnerType(argType); + + if (!innerType.equals(baseInnerType)) { + throw new IllegalArgumentException(String.format( + "Type mismatch in coalesce at argument %d: expected %s but got %s", + i + 1, + baseInnerType, + innerType)); + } + + hasNonOptional = hasNonOptional || !(argType instanceof OptionalType); } - return leftInner; + return hasNonOptional ? baseInnerType : Type.optionalType(baseInnerType); } private static Type getInnerType(Type t) { @@ -119,7 +128,12 @@ public String getId() { @Override public List getArguments() { - return Arrays.asList(Type.anyType(), Type.anyType()); + return Collections.emptyList(); + } + + @Override + public Optional getVariadicArguments() { + return Optional.of(Type.anyType()); } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java index e8878c85ed6..c07cad7b7fe 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/FunctionDefinition.java @@ -5,6 +5,7 @@ package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; import java.util.List; +import java.util.Optional; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.evaluation.value.Value; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -26,6 +27,18 @@ public interface FunctionDefinition { */ List getArguments(); + /** + * Gets the type of variadic arguments if this function accepts them. + * + *

    When present, the function accepts any number of additional arguments of this type after the fixed arguments + * from getArguments(). + * + * @return the variadic argument type, or empty if not variadic + */ + default Optional getVariadicArguments() { + return Optional.empty(); + } + /** * The return type of this function definition. * @return The function return type diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java index 3f32dfd6252..c5daeab89d1 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/LibraryFunction.java @@ -8,6 +8,7 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.Set; import software.amazon.smithy.model.SourceException; import software.amazon.smithy.model.SourceLocation; @@ -105,7 +106,7 @@ public SourceLocation getSourceLocation() { protected Type typeCheckLocal(Scope scope) { RuleError.context(String.format("while typechecking the invocation of %s", definition.getId()), this, () -> { try { - checkTypeSignature(definition.getArguments(), functionNode.getArguments(), scope); + checkTypeSignature(scope); } catch (InnerParseError e) { throw new RuntimeException(e.getMessage()); } @@ -113,37 +114,64 @@ protected Type typeCheckLocal(Scope scope) { return definition.getReturnType(); } - private void checkTypeSignature(List expectedArgs, List actualArguments, Scope scope) + private void checkTypeSignature(Scope scope) throws InnerParseError { + List expectedArgs = definition.getArguments(); + Optional variadicType = definition.getVariadicArguments(); + List actualArguments = functionNode.getArguments(); + + if (variadicType.isPresent()) { + // check we have at least the fixed arguments + if (actualArguments.size() < expectedArgs.size()) { + throw new InnerParseError(String.format("Expected at least %s arguments but found %s", + expectedArgs.size(), + actualArguments.size())); + } + // check fixed arguments + for (int i = 0; i < expectedArgs.size(); i++) { + checkArgument(i, expectedArgs.get(i), actualArguments.get(i), scope); + } + // check variadic arguments + Type varType = variadicType.get(); + for (int i = expectedArgs.size(); i < actualArguments.size(); i++) { + checkArgument(i, varType, actualArguments.get(i), scope); + } + } else { + // Non-variadic, so exact count required + if (expectedArgs.size() != actualArguments.size()) { + throw new InnerParseError(String.format("Expected %s arguments but found %s", + expectedArgs.size(), + actualArguments.size())); + } + // check all positional arguments + for (int i = 0; i < expectedArgs.size(); i++) { + checkArgument(i, expectedArgs.get(i), actualArguments.get(i), scope); + } + } + } + + private void checkArgument(int index, Type expected, Expression actual, Scope scope) throws InnerParseError { - if (expectedArgs.size() != actualArguments.size()) { - throw new InnerParseError( - String.format( - "Expected %s arguments but found %s", - expectedArgs.size(), - actualArguments)); + Type actualType = actual.typeCheck(scope); + if (expected.isA(actualType)) { + return; } - for (int i = 0; i < expectedArgs.size(); i++) { - Type expected = expectedArgs.get(i); - Type actual = actualArguments.get(i).typeCheck(scope); - if (!expected.isA(actual)) { - Type optAny = Type.optionalType(Type.anyType()); - String hint = ""; - if (actual.isA(optAny) && !expected.isA(optAny) - && actual.expectOptionalType().inner().equals(expected)) { - hint = String.format( - "hint: use `assign` in a condition or `isSet(%s)` to prove that this value is non-null", - actualArguments.get(i)); - hint = StringUtils.indent(hint, 2); - } - throw new InnerParseError( - String.format( - "Unexpected type in the %s argument: Expected %s but found %s%n%s", - ordinal(i + 1), - expected, - actual, - hint)); - } + + Type optAny = Type.optionalType(Type.anyType()); + String hint = ""; + if (actualType.isA(optAny) + && !expected.isA(optAny) + && actualType.expectOptionalType().inner().equals(expected)) { + hint = String.format( + "hint: use `assign` in a condition or `isSet(%s)` to prove that this value is non-null", + actual); + hint = StringUtils.indent(hint, 2); } + + throw new InnerParseError(String.format("Unexpected type in the %s argument: Expected %s but found %s%n%s", + ordinal(index + 1), + expected, + actualType, + hint)); } private static String ordinal(int arg) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrdering.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrdering.java index 39c8aabc277..7ed2dba09df 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrdering.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/InitialOrdering.java @@ -51,7 +51,6 @@ public List orderConditions(Condition[] conditions) { long elapsed = System.currentTimeMillis() - startTime; LOGGER.info(() -> String.format("Initial ordering: %d conditions in %dms", conditions.length, elapsed)); - result.forEach(System.out::println); return result; } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java index f7583be6921..fdbf9593a8e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -15,6 +15,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.TreeSet; import java.util.logging.Logger; import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; @@ -390,36 +391,18 @@ private Rule rewriteErrorRule(ErrorRule rule, TreeRewriter rewriter) { } private Condition createCoalesceCondition(String resultVar, Set versions) { - // De-duplicate and sort for determinism - List inputs = new ArrayList<>(new LinkedHashSet<>(versions)); - Collections.sort(inputs); - - if (inputs.isEmpty()) { + if (versions.isEmpty()) { throw new IllegalArgumentException("Cannot create coalesce with no versions"); } - Expression coalesced = buildCoalesceExpressionFromPrepared(inputs); - - return Condition.builder() - .fn((LibraryFunction) coalesced) - .result(Identifier.of(resultVar)) - .build(); - } - - private Expression buildCoalesceExpressionFromPrepared(List inputs) { - Expression coalesced = Expression.getReference(Identifier.of(inputs.get(0))); - - if (inputs.size() == 1) { - // Hack: coalesce(x, x) for single element - return Coalesce.ofExpressions(coalesced, coalesced); + // TreeSet for both deduplication and deterministic ordering + Set inputs = new TreeSet<>(versions); + List refs = new ArrayList<>(inputs.size()); + for (String input : inputs) { + refs.add(Expression.getReference(Identifier.of(input))); } - // Build coalesce chain for multiple versions - for (int i = 1; i < inputs.size(); i++) { - Expression ref = Expression.getReference(Identifier.of(inputs.get(i))); - coalesced = Coalesce.ofExpressions(coalesced, ref); - } - return coalesced; + return Condition.builder().fn(Coalesce.ofExpressions(refs)).result(Identifier.of(resultVar)).build(); } private Rule canonicalizeResult(Rule rule) { diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java index 66e80ce32f4..bf6ac4bb9da 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.Arrays; import org.junit.jupiter.api.Test; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; @@ -19,7 +20,7 @@ public class CoalesceTest { @Test - void testCoalesceWithSameTypes() { + void testCoalesceWithTwoSameTypes() { Expression left = Literal.of("default"); Expression right = Literal.of("fallback"); Coalesce coalesce = Coalesce.ofExpressions(left, right); @@ -30,6 +31,34 @@ void testCoalesceWithSameTypes() { assertEquals(Type.stringType(), resultType); } + @Test + void testCoalesceWithThreeSameTypes() { + Expression first = Literal.of("first"); + Expression second = Literal.of("second"); + Expression third = Literal.of("third"); + Coalesce coalesce = Coalesce.ofExpressions(first, second, third); + + Scope scope = new Scope<>(); + Type resultType = coalesce.typeCheck(scope); + + assertEquals(Type.stringType(), resultType); + } + + @Test + void testCoalesceVariadicWithList() { + Expression first = Literal.of(1); + Expression second = Literal.of(2); + Expression third = Literal.of(3); + Expression fourth = Literal.of(4); + + Coalesce coalesce = Coalesce.ofExpressions(Arrays.asList(first, second, third, fourth)); + + Scope scope = new Scope<>(); + Type resultType = coalesce.typeCheck(scope); + + assertEquals(Type.integerType(), resultType); + } + @Test void testCoalesceWithOptionalLeft() { Expression optionalVar = Expression.getReference(Identifier.of("maybeValue")); @@ -46,7 +75,7 @@ void testCoalesceWithOptionalLeft() { } @Test - void testCoalesceWithBothOptional() { + void testCoalesceWithAllOptional() { Expression var1 = Expression.getReference(Identifier.of("maybe1")); Expression var2 = Expression.getReference(Identifier.of("maybe2")); Coalesce coalesce = Coalesce.ofExpressions(var1, var2); @@ -62,19 +91,39 @@ void testCoalesceWithBothOptional() { } @Test - void testCoalesceWithCompatibleTypes() { - // Test with optional types that should resolve to non-optional - Expression optionalString = Expression.getReference(Identifier.of("optional")); - Expression requiredString = Expression.getReference(Identifier.of("required")); - Coalesce coalesce = Coalesce.ofExpressions(optionalString, requiredString); + void testCoalesceThreeWithAllOptional() { + Expression var1 = Expression.getReference(Identifier.of("maybe1")); + Expression var2 = Expression.getReference(Identifier.of("maybe2")); + Expression var3 = Expression.getReference(Identifier.of("maybe3")); + Coalesce coalesce = Coalesce.ofExpressions(var1, var2, var3); + + Scope scope = new Scope<>(); + scope.insert("maybe1", Type.optionalType(Type.integerType())); + scope.insert("maybe2", Type.optionalType(Type.integerType())); + scope.insert("maybe3", Type.optionalType(Type.integerType())); + + Type resultType = coalesce.typeCheck(scope); + + // All optional means result is optional + assertEquals(Type.optionalType(Type.integerType()), resultType); + } + + @Test + void testCoalesceMixedOptionalAndNonOptional() { + Expression optional1 = Expression.getReference(Identifier.of("optional1")); + Expression required = Expression.getReference(Identifier.of("required")); + Expression optional2 = Expression.getReference(Identifier.of("optional2")); + + Coalesce coalesce = Coalesce.ofExpressions(optional1, required, optional2); Scope scope = new Scope<>(); - scope.insert("optional", Type.optionalType(Type.stringType())); + scope.insert("optional1", Type.optionalType(Type.stringType())); scope.insert("required", Type.stringType()); + scope.insert("optional2", Type.optionalType(Type.stringType())); Type resultType = coalesce.typeCheck(scope); - // When coalescing Optional with String, should return String + // Any non-optional in the chain makes result non-optional assertEquals(Type.stringType(), resultType); } @@ -86,51 +135,84 @@ void testCoalesceWithIncompatibleTypes() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, - () -> coalesce.typeCheck(scope)); + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); + assertTrue(ex.getMessage().contains("argument 2")); } @Test - void testCoalesceNonOptionalWithNonOptional() { - Expression int1 = Literal.of(42); - Expression int2 = Literal.of(100); - Coalesce coalesce = Coalesce.ofExpressions(int1, int2); + void testCoalesceWithIncompatibleTypesInMiddle() { + Expression int1 = Literal.of(1); + Expression int2 = Literal.of(2); + Expression string = Literal.of("oops"); + Expression int3 = Literal.of(3); + + Coalesce coalesce = Coalesce.ofExpressions(int1, int2, string, int3); Scope scope = new Scope<>(); - Type resultType = coalesce.typeCheck(scope); - // Two non-optionals of same type should return that type - assertEquals(Type.integerType(), resultType); + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); + assertTrue(ex.getMessage().contains("argument 3")); } @Test - void testCoalesceOptionalWithNonOptional() { - Expression optionalInt = Expression.getReference(Identifier.of("maybeInt")); - Expression defaultInt = Literal.of(0); - Coalesce coalesce = Coalesce.ofExpressions(optionalInt, defaultInt); + void testCoalesceWithLessThanTwoArguments() { + Expression single = Literal.of("only"); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> Coalesce.ofExpressions(single).typeCheck(new Scope<>())); + assertTrue(ex.getMessage().contains("at least 2 arguments")); + } + + @Test + void testCoalesceArrayTypes() { + Expression arr1 = Expression.getReference(Identifier.of("array1")); + Expression arr2 = Expression.getReference(Identifier.of("array2")); + Expression arr3 = Expression.getReference(Identifier.of("array3")); + Coalesce coalesce = Coalesce.ofExpressions(arr1, arr2, arr3); Scope scope = new Scope<>(); - scope.insert("maybeInt", Type.optionalType(Type.integerType())); + scope.insert("array1", Type.arrayType(Type.stringType())); + scope.insert("array2", Type.arrayType(Type.stringType())); + scope.insert("array3", Type.arrayType(Type.stringType())); Type resultType = coalesce.typeCheck(scope); - // Optional coalesced with Int should return Int - assertEquals(Type.integerType(), resultType); + assertEquals(Type.arrayType(Type.stringType()), resultType); } @Test - void testCoalesceArrayTypes() { + void testCoalesceOptionalArrayTypes() { Expression arr1 = Expression.getReference(Identifier.of("array1")); Expression arr2 = Expression.getReference(Identifier.of("array2")); Coalesce coalesce = Coalesce.ofExpressions(arr1, arr2); Scope scope = new Scope<>(); - scope.insert("array1", Type.arrayType(Type.stringType())); - scope.insert("array2", Type.arrayType(Type.stringType())); + scope.insert("array1", Type.optionalType(Type.arrayType(Type.integerType()))); + scope.insert("array2", Type.arrayType(Type.integerType())); Type resultType = coalesce.typeCheck(scope); - assertEquals(Type.arrayType(Type.stringType()), resultType); + // One non-optional makes result non-optional + assertEquals(Type.arrayType(Type.integerType()), resultType); + } + + @Test + void testCoalesceWithBooleanTypes() { + Expression bool1 = Expression.getReference(Identifier.of("bool1")); + Expression bool2 = Expression.getReference(Identifier.of("bool2")); + Expression bool3 = Expression.getReference(Identifier.of("bool3")); + + Coalesce coalesce = Coalesce.ofExpressions(bool1, bool2, bool3); + + Scope scope = new Scope<>(); + scope.insert("bool1", Type.optionalType(Type.booleanType())); + scope.insert("bool2", Type.optionalType(Type.booleanType())); + scope.insert("bool3", Type.booleanType()); + + Type resultType = coalesce.typeCheck(scope); + + assertEquals(Type.booleanType(), resultType); } } diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors index 2acfc1d2cc7..dc90845867d 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.errors @@ -1,3 +1 @@ -[WARNING] example#FizzBuzz: This shape applies a trait that is unstable: smithy.rules#clientContextParams | UnstableTrait -[WARNING] example#FizzBuzz: This shape applies a trait that is unstable: smithy.rules#endpointRuleSet | UnstableTrait [ERROR] example#FizzBuzz: Coalesce requires rules engine version >= 1.1, but ruleset declares version 1.0 | RulesEngineVersion diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy index b4f85945205..11ca6b72297 100644 --- a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/invalid/bad-version-use.smithy @@ -37,6 +37,7 @@ use smithy.rules#endpointRuleSet } ] }) +@suppress(["UnstableTrait.smithy"]) service FizzBuzz { operations: [GetResource] } diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.errors new file mode 100644 index 00000000000..e69de29bb2d diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.smithy new file mode 100644 index 00000000000..54d4b6802d8 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/coalesce-three-args.smithy @@ -0,0 +1,93 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet +use smithy.rules#endpointTests + +@clientContextParams( + bar: {type: "string", documentation: "a client string parameter"} + baz: {type: "string", documentation: "another client string parameter"} +) +@endpointRuleSet({ + version: "1.1", + parameters: { + bar: { + type: "string", + documentation: "docs" + } + baz: { + type: "string", + documentation: "docs" + } + }, + rules: [ + { + "documentation": "Template baz into URI when bar is set" + "conditions": [ + { + "fn": "coalesce" + "argv": [{"ref": "bar"}, {"ref": "baz"}, "oops"] + "assign": "hi" + } + ] + "endpoint": { + "url": "https://example.com/{hi}" + } + "type": "endpoint" + } + ] +}) +@endpointTests({ + "version": "1.0", + "testCases": [ + { + "params": { + "bar": "bar", + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com/bar" + } + } + } + { + "params": { + "baz": "baz" + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com/baz" + } + } + } + { + "params": {} + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com/oops" + } + } + } + ] +}) +@suppress(["RuleSetParameter.TestCase.Unused"]) +@suppress(["UnstableTrait.smithy"]) +service FizzBuzz { + version: "2022-01-01", + operations: [GetThing] +} + +operation GetThing { + input := {} +} From 96d73dec5768cbdd89999239e378a3ce51bb7c40 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Thu, 21 Aug 2025 17:11:36 -0500 Subject: [PATCH 22/23] Remove result coalescing Trying to coalesce SSA nodes with a result phi node is causing BDD resolution issues. Removing for now. --- .../language/evaluation/RuleEvaluator.java | 20 +- .../rulesengine/logic/cfg/CfgBuilder.java | 555 +----------------- .../rulesengine/logic/cfg/SsaTransform.java | 7 +- .../logic/cfg/VariableAnalysis.java | 22 +- .../rulesengine/logic/cfg/CfgBuilderTest.java | 174 ------ .../logic/cfg/SsaTransformTest.java | 11 +- 6 files changed, 21 insertions(+), 768 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index a15352396c8..f946e1534ac 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -137,21 +137,6 @@ public Value evaluateRuleSet(EndpointRuleSet ruleset, Map par }); } - /** - * Configure the rule evaluator with the given parameters and parameter values for manual evaluation. - * - * @param parameters Parameters of the ruleset to evaluate. - * @param parameterArguments Parameter values to evaluate the ruleset against. - * @return the updated evaluator. - */ - public RuleEvaluator withParameters(Parameters parameters, Map parameterArguments) { - for (Parameter parameter : parameters) { - parameter.getDefault().ifPresent(value -> scope.insert(parameter.getName(), value)); - } - parameterArguments.forEach(scope::insert); - return this; - } - /** * Evaluates the given condition in the current scope. * @@ -185,14 +170,13 @@ public Value visitIsSet(Expression fn) { @Override public Value visitCoalesce(List expressions) { - Value result = Value.emptyValue(); for (Expression exp : expressions) { - result = exp.accept(this); + Value result = exp.accept(this); if (!result.isEmpty()) { return result; } } - return result; + return Value.emptyValue(); } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java index fdbf9593a8e..1cadcce2525 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilder.java @@ -4,54 +4,32 @@ */ package software.amazon.smithy.rulesengine.logic.cfg; -import java.util.ArrayList; import java.util.Collections; -import java.util.Comparator; import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.Set; -import java.util.TreeSet; -import java.util.logging.Logger; -import software.amazon.smithy.rulesengine.language.Endpoint; import software.amazon.smithy.rulesengine.language.EndpointRuleSet; import software.amazon.smithy.rulesengine.language.syntax.Identifier; import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; -import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.BooleanEquals; -import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.LibraryFunction; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.RecordLiteral; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.StringLiteral; -import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.TupleLiteral; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; -import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule; -import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; -import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; import software.amazon.smithy.rulesengine.logic.ConditionReference; /** - * Builder for constructing Control Flow Graphs with node deduplication and result convergence. + * Builder for constructing Control Flow Graphs with node deduplication. * - *

    This builder performs hash-consing during construction and creates convergence nodes - * for structurally similar results to minimize BDD size. + *

    This builder performs hash-consing during construction to share identical + * subtrees and prevent exponential growth. */ public final class CfgBuilder { - private static final Logger LOGGER = Logger.getLogger(CfgBuilder.class.getName()); - - // Configuration constants - private static final int MAX_DIVERGENT_PATHS_FOR_CONVERGENCE = 5; - private static final int MIN_RESULTS_FOR_CONVERGENCE = 2; final EndpointRuleSet ruleSet; @@ -63,20 +41,9 @@ public final class CfgBuilder { private final Map resultCache = new HashMap<>(); private final Map resultNodeCache = new HashMap<>(); - // Result convergence support - private final Map resultToConvergenceNode = new HashMap<>(); - private int convergenceNodesCreated = 0; - private int phiVariableCounter = 0; - - private String version; - public CfgBuilder(EndpointRuleSet ruleSet) { // Apply SSA transform to ensure globally unique variable names this.ruleSet = SsaTransform.transform(ruleSet); - this.version = ruleSet.getVersion(); - - // Analyze results and create convergence nodes - analyzeAndCreateConvergenceNodes(); } /** @@ -89,24 +56,8 @@ public Cfg build(CfgNode root) { return new Cfg(ruleSet, Objects.requireNonNull(root)); } - /** - * Set the version of the endpoint rules engine (e.g., 1.1). - * - * @param version Version to set. - * @return the builder; - */ - public CfgBuilder version(String version) { - this.version = Objects.requireNonNull(version); - return this; - } - /** * Creates a condition node, reusing existing nodes when possible. - * - * @param condition the condition to evaluate - * @param trueBranch the node to evaluate when the condition is true - * @param falseBranch the node to evaluate when the condition is false - * @return a condition node (possibly cached) */ public CfgNode createCondition(Condition condition, CfgNode trueBranch, CfgNode falseBranch) { return createCondition(createConditionReference(condition), trueBranch, falseBranch); @@ -114,11 +65,6 @@ public CfgNode createCondition(Condition condition, CfgNode trueBranch, CfgNode /** * Creates a condition node, reusing existing nodes when possible. - * - * @param condRef the condition reference to evaluate - * @param trueBranch the node to evaluate when the condition is true - * @param falseBranch the node to evaluate when the condition is false - * @return a condition node (possibly cached) */ public CfgNode createCondition(ConditionReference condRef, CfgNode trueBranch, CfgNode falseBranch) { NodeSignature signature = new NodeSignature(condRef, trueBranch, falseBranch); @@ -127,23 +73,11 @@ public CfgNode createCondition(ConditionReference condRef, CfgNode trueBranch, C /** * Creates a result node representing a terminal rule evaluation. - * - *

    If this result is part of a convergence group, returns the shared convergence node instead. - * - * @param rule the result rule (endpoint or error) - * @return a result node or convergence node */ public CfgNode createResult(Rule rule) { // Intern the result Rule interned = intern(rule); - // Check if this result has a convergence node - CfgNode convergenceNode = resultToConvergenceNode.get(interned); - if (convergenceNode != null) { - LOGGER.fine("Using convergence node for result: " + interned); - return convergenceNode; - } - // Regular result node return resultNodeCache.computeIfAbsent(interned, ResultNode::new); } @@ -191,220 +125,10 @@ public ConditionReference createConditionReference(Condition condition) { return reference; } - private void analyzeAndCreateConvergenceNodes() { - LOGGER.info("Analyzing results for convergence opportunities"); - List allResults = new ArrayList<>(); - for (Rule rule : ruleSet.getRules()) { - collectResultsFromRule(rule, allResults); - } - LOGGER.info("Found " + allResults.size() + " total results"); - - Map groups = groupResultsByStructure(allResults); - createConvergenceNodesForGroups(groups); - - LOGGER.info(String.format("Created %d convergence nodes for %d result groups", - convergenceNodesCreated, - groups.size())); - } - - private void collectResultsFromRule(Rule rule, List results) { - if (rule instanceof EndpointRule || rule instanceof ErrorRule) { - results.add(intern(rule)); - } else if (rule instanceof TreeRule) { - for (Rule nestedRule : ((TreeRule) rule).getRules()) { - collectResultsFromRule(nestedRule, results); - } - } - } - private Rule intern(Rule rule) { return resultCache.computeIfAbsent(canonicalizeResult(rule), k -> k); } - private Map groupResultsByStructure(List results) { - Map groups = new HashMap<>(); - for (Rule result : results) { - ResultSignature sig = new ResultSignature(result); - groups.computeIfAbsent(sig, k -> new ResultGroup()).add(result); - } - return groups; - } - - private void createConvergenceNodesForGroups(Map groups) { - for (ResultGroup group : groups.values()) { - if (shouldGroupResults(group)) { - createConvergenceNodeForGroup(group); - } - } - } - - private boolean shouldGroupResults(ResultGroup group) { - if (group.results.size() < MIN_RESULTS_FOR_CONVERGENCE) { - return false; - } - - int divergentCount = group.getDivergentPathCount(); - if (divergentCount == 0) { - return false; - } - - if (divergentCount > MAX_DIVERGENT_PATHS_FOR_CONVERGENCE) { - LOGGER.fine(String.format("Skipping convergence for group with %d divergent paths (perf)", - divergentCount)); - return false; - } - - return true; - } - - private void createConvergenceNodeForGroup(ResultGroup group) { - Map> divergentPaths = group.getDivergentPaths(); - Rule canonical = group.results.get(0); // already interned - - Map phiVariableMap = createPhiVariablesByPath(divergentPaths); - Rule rewrittenResult = rewriteResultWithPhiVariables(canonical, divergentPaths, phiVariableMap); - - CfgNode convergenceNode = buildConvergenceNode(rewrittenResult, divergentPaths, phiVariableMap); - - for (Rule result : group.results) { - resultToConvergenceNode.put(result, convergenceNode); // keys are the interned instances - } - - convergenceNodesCreated++; - } - - private Map createPhiVariablesByPath(Map> divergentPaths) { - List paths = new ArrayList<>(divergentPaths.keySet()); - Collections.sort(paths); - - Map phiVariableMap = new LinkedHashMap<>(); - for (LocationPath path : paths) { - phiVariableMap.put(path, "phi_result_" + (phiVariableCounter++)); - } - return phiVariableMap; // already ordered - } - - private CfgNode buildConvergenceNode( - Rule result, - Map> divergentPaths, - Map phiVariableMap - ) { - CfgNode resultNode = new ResultNode(result); - - // Apply phi nodes in the order already established by phiVariableMap - for (Map.Entry entry : phiVariableMap.entrySet()) { - Set versions = divergentPaths.get(entry.getKey()); - String phiVar = entry.getValue(); - - Condition coalesceCondition = createCoalesceCondition(phiVar, versions); - ConditionReference condRef = new ConditionReference(coalesceCondition, false); - - // Use NO_MATCH as false branch since coalesce should always succeed - // This ensures we get a valid result index instead of null that would result from FALSE. - CfgNode noMatch = createResult(NoMatchRule.INSTANCE); - resultNode = new ConditionNode(condRef, resultNode, noMatch); - - // Log with deterministic ordering - LOGGER.fine(() -> { - List sortedVersions = new ArrayList<>(versions); - Collections.sort(sortedVersions); - return String.format("Created convergence: %s = coalesce(%s) for path %s", - phiVar, - String.join(",", sortedVersions), - entry.getKey()); - }); - } - - return resultNode; - } - - private Rule rewriteResultWithPhiVariables( - Rule result, - Map> divergentPaths, - Map phiVariableMap - ) { - // Build replacements for URL or error expression only (not headers/properties) - Map urlReplacements = buildPhiReplacements(divergentPaths, phiVariableMap); - - if (urlReplacements.isEmpty()) { - return result; - } - - TreeRewriter rewriter = TreeRewriter.forReplacements(urlReplacements); - - if (result instanceof EndpointRule) { - return rewriteEndpointRule((EndpointRule) result, rewriter); - } else if (result instanceof ErrorRule) { - return rewriteErrorRule((ErrorRule) result, rewriter); - } - - return result; - } - - private Map buildPhiReplacements( - Map> divergentPaths, - Map phiVariableMap - ) { - Map replacements = new HashMap<>(); - for (Map.Entry> entry : divergentPaths.entrySet()) { - LocationPath path = entry.getKey(); - String phiVar = phiVariableMap.get(path); - Expression phiRef = Expression.getReference(Identifier.of(phiVar)); - - // Map all versions at this path to the phi variable - for (String version : entry.getValue()) { - replacements.put(version, phiRef); - } - } - return replacements; - } - - private Rule rewriteEndpointRule(EndpointRule rule, TreeRewriter rewriter) { - Endpoint endpoint = rule.getEndpoint(); - Expression rewrittenUrl = rewriter.rewrite(endpoint.getUrl()); - - if (rewrittenUrl != endpoint.getUrl()) { - Endpoint rewrittenEndpoint = Endpoint.builder() - .url(rewrittenUrl) - .headers(endpoint.getHeaders()) - .properties(endpoint.getProperties()) - .build(); - - return EndpointRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(Collections.emptyList()) - .endpoint(rewrittenEndpoint); - } - return rule; - } - - private Rule rewriteErrorRule(ErrorRule rule, TreeRewriter rewriter) { - Expression rewrittenError = rewriter.rewrite(rule.getError()); - - if (rewrittenError != rule.getError()) { - return ErrorRule.builder() - .description(rule.getDocumentation().orElse(null)) - .conditions(Collections.emptyList()) - .error(rewrittenError); - } - return rule; - } - - private Condition createCoalesceCondition(String resultVar, Set versions) { - if (versions.isEmpty()) { - throw new IllegalArgumentException("Cannot create coalesce with no versions"); - } - - // TreeSet for both deduplication and deterministic ordering - Set inputs = new TreeSet<>(versions); - List refs = new ArrayList<>(inputs.size()); - for (String input : inputs) { - refs.add(Expression.getReference(Identifier.of(input))); - } - - return Condition.builder().fn(Coalesce.ofExpressions(refs)).result(Identifier.of(resultVar)).build(); - } - private Rule canonicalizeResult(Rule rule) { return rule == null ? null : rule.withConditions(Collections.emptyList()); } @@ -445,279 +169,6 @@ private static Condition unwrapNegation(Condition negatedCondition) { .build(); } - /** - * Path represents a structural location within an expression tree. - */ - private static final class LocationPath implements Comparable { - private final String key; - private final int hash; - - LocationPath(List parts) { - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < parts.size(); i++) { - if (i > 0) { - sb.append("/"); - } - sb.append(parts.get(i).toString()); - } - this.key = sb.toString(); - this.hash = key.hashCode(); - } - - @Override - public boolean equals(Object o) { - if (o instanceof LocationPath) { - return ((LocationPath) o).key.equals(this.key); - } - return false; - } - - @Override - public int hashCode() { - return hash; - } - - @Override - public String toString() { - return key; - } - - @Override - public int compareTo(LocationPath other) { - return this.key.compareTo(other.key); - } - } - - /** - * Signature for result grouping based on structural similarity. - */ - private static class ResultSignature { - private final String type; - private final Object urlStructure; - private final Object headersStructure; - private final Object propertiesStructure; - private final int hashCode; - - ResultSignature(Rule result) { - this.type = result instanceof EndpointRule ? "endpoint" : "error"; - - if (result instanceof EndpointRule) { - Endpoint ep = ((EndpointRule) result).getEndpoint(); - this.urlStructure = buildExpressionStructure(ep.getUrl()); - this.headersStructure = buildHeadersStructure(ep.getHeaders()); - this.propertiesStructure = buildPropertiesStructure(ep.getProperties()); - } else if (result instanceof ErrorRule) { - this.urlStructure = buildExpressionStructure(((ErrorRule) result).getError()); - this.headersStructure = null; - this.propertiesStructure = null; - } else { - this.urlStructure = null; - this.headersStructure = null; - this.propertiesStructure = null; - } - - this.hashCode = Objects.hash(type, urlStructure, headersStructure, propertiesStructure); - } - - private Object buildExpressionStructure(Expression expr) { - if (expr instanceof Reference) { - return "VAR"; - } else if (expr instanceof StringLiteral) { - return buildTemplateStructure(((StringLiteral) expr).value()); - } else if (expr instanceof LibraryFunction) { - return buildFunctionStructure((LibraryFunction) expr); - } else if (expr instanceof TupleLiteral) { - return buildTupleStructure((TupleLiteral) expr); - } else if (expr instanceof RecordLiteral) { - return buildRecordStructure((RecordLiteral) expr); - } else if (expr instanceof Literal) { - return expr.toString(); - } - return expr.getClass().getSimpleName(); - } - - private Object buildFunctionStructure(LibraryFunction fn) { - Map fnStructure = new LinkedHashMap<>(); - fnStructure.put("fn", fn.getName()); - List args = new ArrayList<>(); - for (Expression arg : fn.getArguments()) { - args.add(buildExpressionStructure(arg)); - } - fnStructure.put("args", args); - return fnStructure; - } - - private Object buildTupleStructure(TupleLiteral tuple) { - List structure = new ArrayList<>(); - for (Literal member : tuple.members()) { - structure.add(buildExpressionStructure(member)); - } - return structure; - } - - private Object buildRecordStructure(RecordLiteral record) { - Map recordStructure = new LinkedHashMap<>(); - for (Map.Entry entry : record.members().entrySet()) { - recordStructure.put(entry.getKey().toString(), buildExpressionStructure(entry.getValue())); - } - return recordStructure; - } - - private Object buildTemplateStructure(Template template) { - if (template.isStatic()) { - return template.expectLiteral(); - } - - List parts = new ArrayList<>(); - for (Template.Part part : template.getParts()) { - if (part instanceof Template.Literal) { - parts.add(((Template.Literal) part).getValue()); - } else if (part instanceof Template.Dynamic) { - parts.add(buildExpressionStructure(((Template.Dynamic) part).toExpression())); - } - } - return parts; - } - - private Object buildHeadersStructure(Map> headers) { - if (headers.isEmpty()) { - return Collections.emptyMap(); - } - - // Sort header keys for deterministic ordering - List sortedKeys = new ArrayList<>(headers.keySet()); - Collections.sort(sortedKeys); - - Map structure = new LinkedHashMap<>(); - for (String key : sortedKeys) { - List values = new ArrayList<>(); - for (Expression expr : headers.get(key)) { - values.add(buildExpressionStructure(expr)); - } - structure.put(key, values); - } - return structure; - } - - private Object buildPropertiesStructure(Map properties) { - if (properties.isEmpty()) { - return Collections.emptyMap(); - } - - // Sort property keys by their string representation for deterministic ordering - List sortedIds = new ArrayList<>(properties.keySet()); - sortedIds.sort(Comparator.comparing(Identifier::toString)); - - Map structure = new LinkedHashMap<>(); - for (Identifier id : sortedIds) { - structure.put(id.toString(), buildExpressionStructure(properties.get(id))); - } - return structure; - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof ResultSignature)) { - return false; - } - ResultSignature that = (ResultSignature) o; - return type.equals(that.type) - && Objects.equals(urlStructure, that.urlStructure) - && Objects.equals(headersStructure, that.headersStructure) - && Objects.equals(propertiesStructure, that.propertiesStructure); - } - - @Override - public int hashCode() { - return hashCode; - } - } - - /** - * Group of structurally similar results. - */ - private static class ResultGroup { - private final List results = new ArrayList<>(); - private Map> divergentPaths = null; - - void add(Rule result) { - results.add(result); - divergentPaths = null; // Invalidate cache - } - - Map> getDivergentPaths() { - if (divergentPaths == null) { - divergentPaths = computeDivergentByPath(); - } - return divergentPaths; - } - - int getDivergentPathCount() { - return getDivergentPaths().size(); - } - - private Map> computeDivergentByPath() { - Map> byPath = new LinkedHashMap<>(); - - // Collect references by path for URL only (not headers/properties) - for (Rule result : results) { - if (result instanceof EndpointRule) { - Endpoint ep = ((EndpointRule) result).getEndpoint(); - collectRefsByPath(ep.getUrl(), new ArrayList<>(), byPath); - } else if (result instanceof ErrorRule) { - Expression error = ((ErrorRule) result).getError(); - collectRefsByPath(error, new ArrayList<>(), byPath); - } - } - - // Remove paths with only one variable name (no divergence) - byPath.entrySet().removeIf(e -> e.getValue().size() <= 1); - - return byPath; - } - - private void collectRefsByPath(Expression expr, List path, Map> out) { - if (expr instanceof StringLiteral) { - collectTemplateRefs((StringLiteral) expr, path, out); - } else if (expr instanceof Reference) { - LocationPath p = new LocationPath(path); - out.computeIfAbsent(p, k -> new LinkedHashSet<>()).add(((Reference) expr).getName().toString()); - } else if (expr instanceof LibraryFunction) { - collectFunctionRefs((LibraryFunction) expr, path, out); - } else { - throw new UnsupportedOperationException("Unexpected URL or error type: " + expr); - } - } - - private void collectTemplateRefs(StringLiteral str, List path, Map> out) { - Template template = str.value(); - int i = 0; - for (Template.Part part : template.getParts()) { - if (part instanceof Template.Dynamic) { - List newPath = new ArrayList<>(path); - newPath.add("T"); - newPath.add(i); - collectRefsByPath(((Template.Dynamic) part).toExpression(), newPath, out); - } - i++; - } - } - - private void collectFunctionRefs(LibraryFunction fn, List path, Map> out) { - int i = 0; - for (Expression arg : fn.getArguments()) { - List newPath = new ArrayList<>(path); - newPath.add("F"); - newPath.add(fn.getName()); - newPath.add(i++); - collectRefsByPath(arg, newPath, out); - } - } - } - /** * Signature for node deduplication during construction. */ diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java index da6eda2f02f..209bac8dcd7 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransform.java @@ -29,8 +29,8 @@ * *

    This transformation ensures that each variable is assigned exactly once by renaming variables when they are * reassigned in different parts of the tree. For example, if variable "x" is assigned in multiple branches, they - * become "x", "x_1", "x_2", etc. Without this transform, the BDD compilation would confuse divergent paths that have - * the same variable name. + * become "x_ssa_1", "x_ssa_2", "x_ssa_3", etc. Without this transform, the BDD compilation would confuse divergent + * paths that have the same variable name. * *

    Note that this transform is only applied when the reassignment is done using different * arguments than previously seen assignments of the same variable name. @@ -59,7 +59,8 @@ private SsaTransform(VariableAnalysis variableAnalysis) { static EndpointRuleSet transform(EndpointRuleSet ruleSet) { ruleSet = VariableConsolidationTransform.transform(ruleSet); ruleSet = CoalesceTransform.transform(ruleSet); - SsaTransform ssaTransform = new SsaTransform(VariableAnalysis.analyze(ruleSet)); + VariableAnalysis variableAnalysis = VariableAnalysis.analyze(ruleSet); + SsaTransform ssaTransform = new SsaTransform(variableAnalysis); List rewrittenRules = new ArrayList<>(ruleSet.getRules().size()); for (Rule original : ruleSet.getRules()) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java index 33d995239df..7d477a72b51 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableAnalysis.java @@ -29,26 +29,24 @@ /** * Analyzes variables in an endpoint rule set, collecting bindings, reference counts, - * and other metadata needed for SSA transformation and optimization. + * and expression mappings needed for SSA transformation. */ final class VariableAnalysis { private final Set inputParams; private final Map> bindings; private final Map referenceCounts; private final Map> expressionMappings; - private final Map> expressionToVars; private VariableAnalysis( Set inputParams, Map> bindings, Map referenceCounts, - Map> expressionToVars + Map> expressionMappings ) { this.inputParams = inputParams; this.bindings = bindings; this.referenceCounts = referenceCounts; - this.expressionToVars = expressionToVars; - this.expressionMappings = createExpressionMappings(bindings); + this.expressionMappings = expressionMappings; } static VariableAnalysis analyze(EndpointRuleSet ruleSet) { @@ -63,7 +61,7 @@ static VariableAnalysis analyze(EndpointRuleSet ruleSet) { inputParameters, visitor.bindings, visitor.referenceCounts, - visitor.expressionToVars); + createExpressionMappings(visitor.bindings)); } Set getInputParams() { @@ -123,15 +121,16 @@ private static Map createMappingForVariable( Map mapping = new HashMap<>(); if (expressions.size() == 1) { + // Single binding: no SSA rename needed String expression = expressions.iterator().next(); mapping.put(expression, varName); } else { + // Multiple bindings: use SSA naming convention List sortedExpressions = new ArrayList<>(expressions); sortedExpressions.sort(String::compareTo); - for (int i = 0; i < sortedExpressions.size(); i++) { String expression = sortedExpressions.get(i); - String uniqueName = (i == 0) ? varName : varName + "_" + i; + String uniqueName = varName + "_ssa_" + (i + 1); mapping.put(expression, uniqueName); } } @@ -142,8 +141,6 @@ private static Map createMappingForVariable( private static class AnalysisVisitor { final Map> bindings = new HashMap<>(); final Map referenceCounts = new HashMap<>(); - final Map> expressionToVars = new HashMap<>(); - private final Set inputParams; AnalysisVisitor(Set inputParams) { @@ -156,13 +153,9 @@ void visitRule(Rule rule) { String varName = condition.getResult().get().toString(); LibraryFunction fn = condition.getFunction(); String expression = fn.toString(); - String canonical = fn.canonicalize().toString(); bindings.computeIfAbsent(varName, k -> new HashSet<>()) .add(expression); - - expressionToVars.computeIfAbsent(canonical, k -> new ArrayList<>()) - .add(varName); } countReferences(condition.getFunction()); @@ -194,7 +187,6 @@ private void countReferences(Expression expression) { if (expression instanceof Reference) { Reference ref = (Reference) expression; String name = ref.getName().toString(); - // Count all references, including input parameters referenceCounts.merge(name, 1, Integer::sum); } else if (expression instanceof StringLiteral) { StringLiteral str = (StringLiteral) expression; diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java index f6cdb96a1c2..8fb61f6b89a 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/CfgBuilderTest.java @@ -13,7 +13,6 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -318,44 +317,6 @@ void createResultPreservesHeadersAndPropertiesInSignature() { assertNotSame(node1, node2); } - @Test - void createResultWithStructurallyIdenticalEndpointsCreatesConvergenceNode() { - // Create two rules with structurally identical endpoints but different variable names - Rule rule1 = EndpointRule.builder() - .endpoint(TestHelpers.endpoint("https://{region1}.example.com")); - Rule rule2 = EndpointRule.builder() - .endpoint(TestHelpers.endpoint("https://{region2}.example.com")); - - // Create parameters for the variables used in the endpoints - Parameters params = Parameters.builder() - .addParameter(Parameter.builder() - .name("region1") - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()) - .addParameter(Parameter.builder() - .name("region2") - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()) - .build(); - - EndpointRuleSet ruleSetWithEndpoints = EndpointRuleSet.builder() - .parameters(params) - .rules(ListUtils.of(rule1, rule2)) - .build(); - CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithEndpoints); - - CfgNode node1 = convergenceBuilder.createResult(rule1); - CfgNode node2 = convergenceBuilder.createResult(rule2); - - // Both should return the same convergence node - assertSame(node1, node2); - assertInstanceOf(ConditionNode.class, node1); - } - @Test void createResultDistinguishesEndpointsWithDifferentStructure() { // Create rules with different endpoint structures @@ -385,139 +346,4 @@ void createResultDistinguishesEndpointsWithDifferentStructure() { // Different structures should not converge assertNotSame(node1, node2); } - - @Test - void createResultWithIdenticalErrorsCreatesConvergenceNode() { - // Create structurally identical error rules with different variable references - Rule error1 = ErrorRule.builder() - .error(Expression.of("Region {r1} is not supported")); - Rule error2 = ErrorRule.builder() - .error(Expression.of("Region {r2} is not supported")); - - Parameters params = Parameters.builder() - .addParameter(Parameter.builder() - .name("r1") - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()) - .addParameter(Parameter.builder() - .name("r2") - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()) - .build(); - - EndpointRuleSet ruleSetWithErrors = EndpointRuleSet.builder() - .parameters(params) - .rules(ListUtils.of(error1, error2)) - .build(); - CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithErrors); - - CfgNode node1 = convergenceBuilder.createResult(error1); - CfgNode node2 = convergenceBuilder.createResult(error2); - - // Should converge to the same node - assertSame(node1, node2); - } - - @Test - void createResultHandlesComplexTemplateConvergence() { - // Create endpoints with complex templates that are structurally identical - Rule rule1 = EndpointRule.builder() - .endpoint(TestHelpers.endpoint("https://{svc}.{reg}.amazonaws.com/{path}")); - Rule rule2 = EndpointRule.builder() - .endpoint(TestHelpers.endpoint("https://{service}.{region}.amazonaws.com/{resource}")); - - Parameters params = Parameters.builder() - .addParameter(Parameter.builder() - .name("svc") - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()) - .addParameter(Parameter.builder() - .name("reg") - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()) - .addParameter(Parameter.builder() - .name("path") - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()) - .addParameter(Parameter.builder() - .name("service") - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()) - .addParameter(Parameter.builder() - .name("region") - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()) - .addParameter(Parameter.builder() - .name("resource") - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()) - .build(); - - EndpointRuleSet ruleSetWithTemplates = EndpointRuleSet.builder() - .parameters(params) - .rules(ListUtils.of(rule1, rule2)) - .build(); - CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithTemplates); - - CfgNode node1 = convergenceBuilder.createResult(rule1); - CfgNode node2 = convergenceBuilder.createResult(rule2); - - // Structurally identical templates should converge - assertSame(node1, node2); - } - - @Test - void createResultDoesNotConvergeWithTooManyDivergentPaths() { - // Create many endpoints with different variable names at multiple positions - // This should exceed MAX_DIVERGENT_PATHS_FOR_CONVERGENCE (5) - List rules = new ArrayList<>(); - Parameters.Builder paramsBuilder = Parameters.builder(); - - for (int i = 0; i < 7; i++) { - rules.add(EndpointRule.builder() - .endpoint(TestHelpers.endpoint(String.format("https://{var%d}.{reg%d}.example.com", i, i)))); - paramsBuilder.addParameter(Parameter.builder() - .name("var" + i) - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()); - paramsBuilder.addParameter(Parameter.builder() - .name("reg" + i) - .type(ParameterType.STRING) - .defaultValue(Value.stringValue("a")) - .required(true) - .build()); - } - - EndpointRuleSet ruleSetWithMany = EndpointRuleSet.builder() - .parameters(paramsBuilder.build()) - .rules(rules) - .build(); - CfgBuilder convergenceBuilder = new CfgBuilder(ruleSetWithMany); - - // With too many divergent paths, convergence should be skipped - CfgNode firstNode = convergenceBuilder.createResult(rules.get(0)); - CfgNode lastNode = convergenceBuilder.createResult(rules.get(rules.size() - 1)); - - // They should still be cached as the same due to interning, - // but won't have phi node convergence due to performance limits - assertSame(firstNode, lastNode); - } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java index 8a0d2b6e14f..c5d7b5c405f 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/logic/cfg/SsaTransformTest.java @@ -77,15 +77,14 @@ void testSimpleShadowing() { EndpointRuleSet result = SsaTransform.transform(original); - // The second "temp" should be renamed List resultRules = result.getRules(); assertEquals(2, resultRules.size()); EndpointRule resultRule1 = (EndpointRule) resultRules.get(0); - assertEquals("temp", resultRule1.getConditions().get(0).getResult().get().toString()); + assertEquals("temp_ssa_1", resultRule1.getConditions().get(0).getResult().get().toString()); EndpointRule resultRule2 = (EndpointRule) resultRules.get(1); - assertEquals("temp_1", resultRule2.getConditions().get(0).getResult().get().toString()); + assertEquals("temp_ssa_2", resultRule2.getConditions().get(0).getResult().get().toString()); } @Test @@ -110,9 +109,9 @@ void testMultipleShadowsOfSameVariable() { EndpointRuleSet result = SsaTransform.transform(original); List resultRules = result.getRules(); - assertEquals("temp", (resultRules.get(0)).getConditions().get(0).getResult().get().toString()); - assertEquals("temp_1", (resultRules.get(1)).getConditions().get(0).getResult().get().toString()); - assertEquals("temp_2", (resultRules.get(2)).getConditions().get(0).getResult().get().toString()); + assertEquals("temp_ssa_1", resultRules.get(0).getConditions().get(0).getResult().get().toString()); + assertEquals("temp_ssa_2", resultRules.get(1).getConditions().get(0).getResult().get().toString()); + assertEquals("temp_ssa_3", resultRules.get(2).getConditions().get(0).getResult().get().toString()); } @Test From bed1c9a423e67711781f5564ca6248f9b6b9c2db Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Fri, 22 Aug 2025 13:27:31 -0500 Subject: [PATCH 23/23] Add BDD trait docs --- .../rules-engine/specification.rst | 188 +++++++++++++++++- 1 file changed, 183 insertions(+), 5 deletions(-) diff --git a/docs/source-2.0/additional-specs/rules-engine/specification.rst b/docs/source-2.0/additional-specs/rules-engine/specification.rst index f1861f97689..36fa421bd08 100644 --- a/docs/source-2.0/additional-specs/rules-engine/specification.rst +++ b/docs/source-2.0/additional-specs/rules-engine/specification.rst @@ -14,7 +14,7 @@ are composed of a set of *conditions*, which determine if a rule should be selected, and a result. Conditions act on the defined parameters, and allow for the modeling of statements. -When a rule’s conditions are evaluated successfully, the rule provides either a +When a rule's conditions are evaluated successfully, the rule provides either a result and its accompanying requirements or an error describing the unsupported state. Modeled endpoint errors allow for more explicit descriptions to users, such as providing errors when a service doesn't support a combination of @@ -24,8 +24,9 @@ conditions. .. smithy-trait:: smithy.rules#endpointRuleSet .. _smithy.rules#endpointRuleSet-trait: +-------------------------------------- ``smithy.rules#endpointRuleSet`` trait -====================================== +-------------------------------------- Summary Defines a rule set for deriving service endpoints at runtime. @@ -45,8 +46,7 @@ The content of the ``endpointRuleSet`` document has the following properties: - Description * - version - ``string`` - - **Required**. The rule set schema version. This specification covers - version 1.0 of the endpoint rule set. + - **Required**. The rules engine version (e.g., 1.0). * - serviceId - ``string`` - **Required**. An identifier for the corresponding service. @@ -74,6 +74,184 @@ or :ref:`error rules, ` with an empty set of conditions to provide a more meaningful default or error depending on the scenario. +.. smithy-trait:: smithy.rules#endpointBdd +.. _smithy.rules#endpointBdd-trait: + +---------------------------------- +``smithy.rules#endpointBdd`` trait +---------------------------------- + +.. warning:: Experimental + + This trait is experimental and subject to change. + +Summary + Defines a `Binary Decision Diagram (BDD) `_ representation + of endpoint rules for efficient runtime evaluation. +Trait selector + ``service`` +Value type + ``structure`` + +The ``endpointBdd`` trait provides a BDD representation of endpoint rules, optimizing runtime evaluation by +eliminating redundant condition evaluations and reducing the decision tree to a minimal directed acyclic graph. +This trait is an alternative to ``endpointRuleSet`` that trades compile-time complexity for significantly improved +runtime performance and reduced artifact sizes. + +.. note:: + + The ``endpointBdd`` trait can be generated from an ``endpointRuleSet`` trait through compilation. Services may + provide either trait, with ``endpointBdd`` preferred for production use due to its performance characteristics. + +The ``endpointBdd`` structure has the following properties: + +.. list-table:: + :header-rows: 1 + :widths: 10 30 60 + + * - Property name + - Type + - Description + * - version + - ``string`` + - **Required**. The endpoint rules engine version. Must be at least version 1.1. + * - parameters + - ``map`` of `Parameter object`_ + - **Required**. A map of zero or more endpoint parameter names to + their parameter configuration. Uses the same parameter structure as + ``endpointRuleSet``. + * - conditions + - ``array`` of `Condition object`_ + - **Required**. Array of conditions that are evaluated during BDD + traversal. Each condition is referenced by its index in this array. + * - results + - ``array`` of `Endpoint rule object`_ or `Error rule object`_ + - **Required**. Array of possible endpoint results. The implicit `NoMatchRule` at BDD reference 0 is not included + in the array. These rule objects MUST NOT contain conditions. + * - root + - ``integer`` + - **Required**. The root reference where BDD evaluation begins. + * - nodeCount + - ``integer`` + - **Required**. The total number of nodes in the BDD. Used for validation and exact-sizing arrays during + deserialization. + * - nodes + - ``string`` + - **Required**. Base64-encoded binary representation of BDD nodes. Each node is encoded as three 4-byte + integers: ``[conditionIndex, highRef, lowRef]``. + +.. _rules-engine-endpoint-bdd-node-structure: + +BDD node structure +------------------ + +Each BDD node is encoded as a triple of integers: + +* ``conditionIndex``: Zero-based index into the ``conditions`` array +* ``highRef``: Reference to follow when the condition evaluates to true +* ``lowRef``: Reference to follow when the condition evaluates to false + +The first node, index 0, is always the terminal node ``[-1, 1, -1]`` and MUST NOT be referenced directly. This node +serves as the canonical base case for BDD reduction algorithms. + +.. _rules-engine-endpoint-bdd-reference-encoding: + +Reference encoding +------------------ + +BDD references use the following encoding scheme: + +.. list-table:: + :header-rows: 1 + :widths: 20 80 + + * - Reference value + - Description + * - ``0`` + - Invalid/unused reference (never appears in valid BDDs) + * - ``1`` + - TRUE terminal (no match in endpoint resolution) + * - ``-1`` + - FALSE terminal (no match in endpoint resolution) + * - ``2, 3, 4, ...`` + - Node references (points to ``nodes[ref-1]``) + * - ``-2, -3, -4, ...`` + - Complement edges (logical NOT of the referenced node) + * - ``100000000+`` + - Result terminals (100000000 + resultIndex) + +When traversing a complement edge (negative reference), the high and low branches are swapped during evaluation. +This enables significant node sharing and BDD size reduction. + +.. _rules-engine-endpoint-bdd-binary-encoding: + +Binary node encoding +-------------------- + +Nodes are encoded as a Base64 string using binary encoding for efficiency: + +* Each node consists of three 4-byte big-endian integers +* Nodes are concatenated sequentially: ``[node0, node1, ..., nodeN-1]`` +* The resulting byte array is Base64-encoded + + +.. note:: Why binary? + + This encoding provides: + + * **Size efficiency**: smaller than an array of JSON integers, or an array of arrays of integers + * **Performance**: Direct deserialization into the target data structure (e.g., primitive arrays and integers) + * **Cleaner diffs**: BDD node changes appear as single-line modifications rather than spread over thousands + of numbers. + +.. _rules-engine-endpoint-bdd-evaluation: + +BDD evaluation +-------------- + +BDD evaluation follows these steps: + +#. Start at the root reference +#. While the reference is a node reference (not a terminal or result): + + * Extract the node index: ``nodeIndex = |ref| - 1`` + * Retrieve the node at that index + * Evaluate the condition at ``conditionIndex`` + * Determine which branch to follow: + + * If the reference is complemented (negative) AND condition is true: follow ``lowRef`` + * If the reference is complemented (negative) AND condition is false: follow ``highRef`` + * If the reference is positive AND condition is true: follow ``highRef`` + * If the reference is positive AND condition is false: follow ``lowRef`` + + * Update the reference to the chosen branch and continue + +#. When reaching a terminal or result: + + * For result references ≥ 100000000: return ``results[ref - 100000000]`` + * For terminals (1 or -1): return the ``NoMatchRule`` + +For example, a reference of 100000003 would return ``results[3]``, while a reference of 1 or -1 indicates no matching +rule was found. + +.. _rules-engine-endpoint-bdd-validation: + +Validation requirements +----------------------- + +* **Root reference**: MUST NOT be complemented (negative) +* **Reference validity**: All references MUST be valid: + + * ``0`` is forbidden + * Node references MUST point to existing nodes + * Result references MUST point to existing results + +* **Node structure**: Each node MUST be a properly formed triple +* **Condition indices**: Each node's condition index MUST be within ``[0, conditionCount)`` +* **Result structure**: The first result (index 0) implicitly represents ``NoMatchRule`` and is not serialized. + All serialized results MUST be either ``EndpointRule`` or ``ErrorRule`` objects without conditions. +* **Version requirement**: The version MUST be at least 1.1 + .. _rules-engine-endpoint-rule-set-parameter: ---------------- @@ -119,7 +297,7 @@ allow values to be bound to parameters from other locations in generated clients. Parameters MAY be annotated with the ``builtIn`` property, which designates that -the parameter should be bound to a value determined by the built-in’s name. The +the parameter should be bound to a value determined by the built-in's name. The :ref:`rules engine contains built-ins ` and the set is extensible.