diff --git a/flappy/synth.py b/flappy/synth.py index e9f8edd..c7a7d7e 100644 --- a/flappy/synth.py +++ b/flappy/synth.py @@ -46,29 +46,34 @@ def enumerate( max_candidates: int = 3, context: Optional[dict] = None, ) -> Iterable[CandidatePlan]: - selectors = self._extract_selectors(context) + selectors, selectors_explicit = self._extract_selectors(context) plan_copy = copy.deepcopy(sketch.root) - if self._plan_compatible(plan_copy, selectors): + if self._plan_compatible(plan_copy, selectors, selectors_explicit): yield CandidatePlan(root=plan_copy, score=1.0, rationale="selectors-bound") - elif not selectors: + elif not selectors_explicit: # If no selector information is available, fall back to the raw sketch. yield CandidatePlan(root=plan_copy, score=0.0, rationale="identity-no-selectors") - def _extract_selectors(self, context: Optional[dict]) -> Sequence[str]: - if not context: - return [] + def _extract_selectors(self, context: Optional[dict]) -> tuple[List[str], bool]: + if not context or "selectors" not in context: + return [], False selectors = context.get("selectors") if isinstance(selectors, (list, tuple)): - return selectors - return [] + return list(selectors), True + return [], True - def _plan_compatible(self, node: DSLNode, selectors: Sequence[str]) -> bool: + def _plan_compatible( + self, node: DSLNode, selectors: Sequence[str], selectors_explicit: bool + ) -> bool: if node.verb in {DSLVerb.CLICK, DSLVerb.TYPE} and node.args: selector = node.args[0] - if selector and selectors and selector not in selectors: - return False + if selector: + if selectors and selector not in selectors: + return False + if selectors_explicit and not selectors: + return False for child in node.children: - if not self._plan_compatible(child, selectors): + if not self._plan_compatible(child, selectors, selectors_explicit): return False return True diff --git a/tests/test_flappy_stubs.py b/tests/test_flappy_stubs.py index e606251..f80cad6 100644 --- a/tests/test_flappy_stubs.py +++ b/tests/test_flappy_stubs.py @@ -45,6 +45,15 @@ def test_plan_synthesiser_identity(): assert proposals[0].root.to_dict() == leaf.to_dict() +def test_plan_synthesiser_empty_selectors_rejects_missing(): + leaf = dsl.make_leaf(dsl.DSLVerb.CLICK, "#missing") + sketch = synth.Sketch(root=leaf, holes=[]) + proposals = list( + synth.PlanSynthesiser().enumerate(sketch, context={"selectors": []}) + ) + assert not proposals + + def test_verifier_stub(): verifier = verify.PlanVerifier() result = verifier.verify(