Skip to content

Commit dd9771a

Browse files
Copilotpvandyken
andcommitted
Reimplement snakebids expand using SnakemakeFormatter
- Add allow_missing support to SnakemakeFormatter with constraint preservation - Store _field_constraints in parse() for use by get_value() with allow_missing - Add _expand() helper in datasets.py replacing sn_expand calls - BidsPartialComponent.expand and BidsComponentRow.expand now use _expand() - None values in extra wildcards converted to empty string - AnnotatedString flags propagated to expanded outputs (gated on snakemake import) - Add focused tests for allow_missing, None handling, and AnnotatedString propagation" Co-authored-by: pvandyken <87136354+pvandyken@users.noreply.github.com>
1 parent 74e2c9a commit dd9771a

4 files changed

Lines changed: 242 additions & 45 deletions

File tree

src/snakebids/core/datasets.py

Lines changed: 100 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,95 @@
1919
from snakebids.exceptions import DuplicateComponentError
2020
from snakebids.io.console import get_console_size
2121
from snakebids.io.printing import format_zip_lists, quote_wrap
22-
from snakebids.snakemake_compat import expand as sn_expand
22+
from snakebids.snakemake_compat import WildcardError
2323
from snakebids.types import ZipList
2424
from snakebids.utils.containers import ImmutableList, MultiSelectDict
2525
from snakebids.utils.snakemake_templates import SnakemakeFormatter
2626
from snakebids.utils.utils import get_wildcard_dict, property_alias, zip_list_eq
2727

28+
try:
29+
from snakemake.io import AnnotatedString as _AnnotatedString
30+
31+
def _get_flags(path: Any) -> dict[str, Any] | None:
32+
"""Extract flags from an AnnotatedString path for preservation during expansion."""
33+
if isinstance(path, _AnnotatedString) and path.flags:
34+
return dict(path.flags)
35+
return None
36+
37+
def _make_annotated(s: str, flags: dict[str, Any]) -> str:
38+
"""Create an AnnotatedString with the given flags copied from the template."""
39+
result = _AnnotatedString(s)
40+
result.flags.update(flags)
41+
return result
42+
43+
except ImportError:
44+
45+
def _get_flags(path: Any) -> dict[str, Any] | None: # type: ignore[misc]
46+
"""Return None when snakemake is not installed (no AnnotatedString support)."""
47+
return None
48+
49+
def _make_annotated(s: str, flags: dict[str, Any]) -> str: # type: ignore[misc]
50+
"""Return the string unchanged when snakemake is not installed."""
51+
return s
52+
53+
54+
def _expand(
55+
paths: Iterable[Any],
56+
zip_lists: ZipList,
57+
allow_missing: bool,
58+
**wildcards: Any,
59+
) -> list[str]:
60+
"""Expand template paths using SnakemakeFormatter.
61+
62+
For each template path, iterates over rows of the zip-list and optional
63+
extra wildcard combinations (product), formatting each path per row.
64+
None values in wildcards are converted to "".
65+
Output is order-preserving deduplicated.
66+
"""
67+
formatter = SnakemakeFormatter(allow_missing=allow_missing)
68+
69+
# Normalize extra wildcards: convert None (scalar or in list) → "", ensure lists
70+
extra: dict[str, list[str]] = {}
71+
for k, vals in wildcards.items():
72+
if vals is None:
73+
extra[k] = [""]
74+
else:
75+
extra[k] = ["" if v is None else v for v in itx.always_iterable(vals)]
76+
77+
# Build zip-list rows
78+
if zip_lists:
79+
rows: list[dict[str, str]] = [
80+
dict(zip(zip_lists.keys(), vals, strict=True))
81+
for vals in zip(*zip_lists.values(), strict=True)
82+
]
83+
else:
84+
rows = [{}]
85+
86+
# Compute extra wildcard combinations (product)
87+
if extra:
88+
extra_keys = list(extra.keys())
89+
extra_combos: list[tuple[str, ...]] = list(it.product(*extra.values()))
90+
else:
91+
extra_keys = []
92+
extra_combos = [()]
93+
94+
try:
95+
results: list[str] = []
96+
for path in paths:
97+
path_str = str(path)
98+
flags = _get_flags(path)
99+
for row in rows:
100+
for combo in extra_combos:
101+
kwargs = {**row, **dict(zip(extra_keys, combo, strict=True))}
102+
result: str = formatter.format(path_str, **kwargs)
103+
if flags is not None:
104+
result = _make_annotated(result, flags)
105+
results.append(result)
106+
except KeyError as e:
107+
raise WildcardError(f"No values given for wildcard {e}.") from e
108+
109+
return list(dict.fromkeys(results))
110+
28111

29112
class BidsDatasetDict(TypedDict):
30113
"""Dict equivalent of BidsInputs, for backwards-compatibility."""
@@ -122,16 +205,11 @@ def expand(
122205
Keywords not found in the path will be ignored. Keywords take values or
123206
lists of values to be expanded over the provided paths.
124207
"""
125-
return sn_expand(
208+
return _expand(
126209
list(itx.always_iterable(paths)),
127-
allow_missing=allow_missing
128-
if isinstance(allow_missing, bool)
129-
else list(itx.always_iterable(allow_missing)),
130-
**{self.entity: list(dict.fromkeys(self._data))},
131-
**{
132-
wildcard: list(itx.always_iterable(v))
133-
for wildcard, v in wildcards.items()
134-
},
210+
zip_lists={self.entity: list(dict.fromkeys(self._data))},
211+
allow_missing=bool(allow_missing),
212+
**wildcards,
135213
)
136214

137215
def filter(
@@ -376,39 +454,18 @@ def expand(
376454
Keywords not found in the path will be ignored. Keywords take values or
377455
lists of values to be expanded over the provided paths.
378456
"""
379-
380-
def sequencify(item: bool | str | Iterable[str]) -> bool | list[str]:
381-
if isinstance(item, bool):
382-
return item
383-
return list(itx.always_iterable(item))
384-
385-
allow_missing_seq = sequencify(allow_missing)
386-
if self.zip_lists:
387-
inner_expand = list(
388-
# order preserving deduplication
389-
dict.fromkeys(
390-
sn_expand(
391-
list(itx.always_iterable(paths)),
392-
zip,
393-
allow_missing=True if wildcards else allow_missing_seq,
394-
**self.zip_lists,
395-
)
396-
)
397-
)
398-
else:
399-
inner_expand = list(itx.always_iterable(paths))
400-
if not wildcards:
401-
return inner_expand
402-
403-
return sn_expand(
404-
inner_expand,
405-
allow_missing=allow_missing_seq,
406-
# Turn all the wildcard items into lists because Snakemake doesn't handle
407-
# iterables very well
408-
**{
409-
wildcard: list(itx.always_iterable(v))
410-
for wildcard, v in wildcards.items()
411-
},
457+
path_list = list(itx.always_iterable(paths))
458+
459+
# When zip_lists is empty and no extra wildcards, return paths as-is
460+
# (no formatting: avoids errors on arbitrary path text with {wildcards})
461+
if not self.zip_lists and not wildcards:
462+
return path_list
463+
464+
return _expand(
465+
path_list,
466+
zip_lists=self.zip_lists,
467+
allow_missing=bool(allow_missing),
468+
**wildcards,
412469
)
413470

414471
def filter(

src/snakebids/utils/snakemake_templates.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,11 @@ class SnakemakeFormatter(string.Formatter):
116116
_UNEXPECTED_OPEN = "unexpected '{' in field name"
117117

118118
@override
119-
def __init__(self) -> None:
119+
def __init__(self, allow_missing: bool = False) -> None:
120120
super().__init__()
121121
self.squelch_underscore = True
122+
self.allow_missing = allow_missing
123+
self._field_constraints: dict[str, str] = {}
122124

123125
@overload
124126
def vformat(
@@ -137,6 +139,7 @@ def vformat(
137139
) -> str:
138140
"""Call base vformat after resetting squelch_underscore."""
139141
self.squelch_underscore = True
142+
self._field_constraints = {}
140143
return super().vformat(format_string, args, kwargs)
141144

142145
@override
@@ -214,6 +217,8 @@ def parse( # noqa: PLR0912
214217
comma = format_string.find(",", i + 1, j)
215218
if comma != -1:
216219
field_name = format_string[i + 1 : comma]
220+
# Store full text (with constraint) for allow_missing support
221+
self._field_constraints[field_name] = format_string[i + 1 : close]
217222
else:
218223
field_name = format_string[i + 1 : j]
219224

@@ -302,8 +307,12 @@ def get_value(
302307
elif key.startswith("_") and key.endswith("_") and len(key) > 1:
303308
entity = key[1:-1]
304309

305-
# Otherwise, not a special key, so error
310+
# Otherwise, not a special key, so error (or preserve if allow_missing)
306311
else:
312+
if self.allow_missing:
313+
# Return original brace-wrapped wildcard including constraint
314+
original = self._field_constraints.get(key, key)
315+
return f"{{{original}}}"
307316
raise KeyError(key)
308317

309318
# Rule 2.5. Handle directory wildcards

tests/test_datasets.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,3 +810,84 @@ def test_order_of_selectors_is_preserved(
810810
assert tuple(dicts.zip_lists[selectors]) == tuple(
811811
itx.unique_everseen(selectors)
812812
)
813+
814+
815+
def _has_annotated_string() -> bool:
816+
try:
817+
from snakemake.io import AnnotatedString # noqa: F401
818+
819+
return True
820+
except ImportError:
821+
return False
822+
823+
824+
class TestExpandNoneHandling:
825+
"""Tests for None value conversion in expand()."""
826+
827+
def test_none_extra_wildcard_treated_as_empty_string(self):
828+
"""None values in extra wildcards are converted to empty string."""
829+
component = BidsComponent(
830+
name="test",
831+
path="sub-{subject}",
832+
zip_lists={"subject": ["001"]},
833+
)
834+
# Path with optional entity that gets a None value
835+
paths = component.expand(
836+
"sub-{subject}{_acq_}{acq}",
837+
acq=None,
838+
)
839+
# None → "" → optional entity not included
840+
assert paths == ["sub-001"]
841+
842+
def test_none_in_extra_wildcard_list(self):
843+
"""None values within a list of extra wildcards are converted to ''."""
844+
component = BidsComponent(
845+
name="test",
846+
path="sub-{subject}",
847+
zip_lists={"subject": ["001"]},
848+
)
849+
paths = component.expand(
850+
"sub-{subject}{_acq_}{acq}",
851+
acq=["mprage", None],
852+
)
853+
assert "sub-001_acq-mprage" in paths
854+
assert "sub-001" in paths
855+
856+
857+
class TestExpandAnnotatedString:
858+
"""Tests for AnnotatedString flag propagation in expand()."""
859+
860+
@pytest.mark.skipif(
861+
not _has_annotated_string(),
862+
reason="snakemake AnnotatedString not available",
863+
)
864+
def test_annotated_string_flags_propagated_to_outputs(self):
865+
"""AnnotatedString flags are preserved on expanded output paths."""
866+
from snakemake.io import AnnotatedString
867+
868+
component = BidsComponent(
869+
name="test",
870+
path="sub-{subject}",
871+
zip_lists={"subject": ["001", "002"]},
872+
)
873+
template = AnnotatedString("sub-{subject}_T1w.nii.gz")
874+
template.flags["temp"] = True
875+
876+
paths = component.expand(template)
877+
for path in paths:
878+
assert isinstance(path, AnnotatedString)
879+
assert path.flags.get("temp") is True
880+
881+
@pytest.mark.skipif(
882+
not _has_annotated_string(),
883+
reason="snakemake AnnotatedString not available",
884+
)
885+
def test_plain_string_not_annotated(self):
886+
"""Plain string templates produce plain string outputs."""
887+
component = BidsComponent(
888+
name="test",
889+
path="sub-{subject}",
890+
zip_lists={"subject": ["001"]},
891+
)
892+
paths = component.expand("sub-{subject}_T1w.nii.gz")
893+
assert type(paths[0]) is str # noqa: E721

tests/test_snakemake_templates/test_snakemake_formatter.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,3 +605,53 @@ def test_format_complex_path_with_constraints(self):
605605
acq="mprage",
606606
)
607607
assert result == "sub-001_acq-mprage_T1w.nii.gz"
608+
609+
610+
class TestAllowMissing:
611+
"""Tests for allow_missing behavior in SnakemakeFormatter."""
612+
613+
def test_allow_missing_preserves_plain_wildcard(self):
614+
"""Missing wildcard with allow_missing=True is preserved as {name}."""
615+
formatter = SnakemakeFormatter(allow_missing=True)
616+
result = formatter.format("sub-{subject}_{run}_T1w.nii.gz", subject="001")
617+
assert result == "sub-001_{run}_T1w.nii.gz"
618+
619+
def test_allow_missing_preserves_wildcard_with_constraint(self):
620+
"""Missing wildcard with constraint is preserved with original {name,constraint}."""
621+
formatter = SnakemakeFormatter(allow_missing=True)
622+
result = formatter.format(
623+
r"sub-{subject,\d+}_{run,\d+}_T1w.nii.gz", subject="001"
624+
)
625+
assert result == r"sub-001_{run,\d+}_T1w.nii.gz"
626+
627+
def test_allow_missing_false_raises_for_missing_wildcard(self):
628+
"""Missing wildcard with allow_missing=False raises KeyError."""
629+
formatter = SnakemakeFormatter(allow_missing=False)
630+
with pytest.raises(KeyError):
631+
formatter.format("sub-{subject}_{run}_T1w.nii.gz", subject="001")
632+
633+
@given(
634+
name=safe_field_names(min_size=1).filter(
635+
# Exclude numeric names (positional args, not named wildcards)
636+
lambda s: not s.isdigit()
637+
),
638+
constraint=constraints(),
639+
)
640+
def test_allow_missing_preserves_constraint_from_template(
641+
self, name: str, constraint: str
642+
):
643+
"""Constraint from template is preserved verbatim in allow_missing output."""
644+
formatter = SnakemakeFormatter(allow_missing=True)
645+
template = f"{{{name},{constraint}}}"
646+
result = formatter.format(template)
647+
assert result == template
648+
649+
def test_allow_missing_resets_between_vformat_calls(self):
650+
"""_field_constraints is reset on each vformat call."""
651+
formatter = SnakemakeFormatter(allow_missing=True)
652+
# First call with constraint
653+
r1 = formatter.format(r"sub-{subject,\d+}")
654+
assert r1 == r"sub-{subject,\d+}"
655+
# Second call without constraint: should use the new (no-constraint) form
656+
r2 = formatter.format("{subject}")
657+
assert r2 == "{subject}"

0 commit comments

Comments
 (0)