diff --git a/src/snakebids/core/datasets.py b/src/snakebids/core/datasets.py index afe8eb2c..47ebc5d0 100644 --- a/src/snakebids/core/datasets.py +++ b/src/snakebids/core/datasets.py @@ -15,11 +15,11 @@ from pvandyken.deprecated import deprecated from typing_extensions import Self, TypedDict +from snakebids.core.expanding import expand as _expand from snakebids.core.filtering import filter_list from snakebids.exceptions import DuplicateComponentError from snakebids.io.console import get_console_size from snakebids.io.printing import format_zip_lists, quote_wrap -from snakebids.snakemake_compat import expand as sn_expand from snakebids.types import ZipList from snakebids.utils.containers import ImmutableList, MultiSelectDict from snakebids.utils.snakemake_templates import MissingEntityError, SnakemakeFormatter @@ -122,16 +122,11 @@ def expand( Keywords not found in the path will be ignored. Keywords take values or lists of values to be expanded over the provided paths. """ - return sn_expand( + return _expand( list(itx.always_iterable(paths)), - allow_missing=allow_missing - if isinstance(allow_missing, bool) - else list(itx.always_iterable(allow_missing)), - **{self.entity: list(dict.fromkeys(self._data))}, - **{ - wildcard: list(itx.always_iterable(v)) - for wildcard, v in wildcards.items() - }, + zip_lists={self.entity: list(dict.fromkeys(self._data))}, + allow_missing=bool(allow_missing), + **wildcards, ) def filter( @@ -349,7 +344,7 @@ def expand( paths: Iterable[Path | str] | Path | str, /, allow_missing: bool | str | Iterable[str] = False, - **wildcards: str | Iterable[str], + **wildcards: Iterable[str | None] | str | None, ) -> list[str]: """Safely expand over given paths with component wildcards. @@ -376,39 +371,18 @@ def expand( Keywords not found in the path will be ignored. Keywords take values or lists of values to be expanded over the provided paths. """ - - def sequencify(item: bool | str | Iterable[str]) -> bool | list[str]: - if isinstance(item, bool): - return item - return list(itx.always_iterable(item)) - - allow_missing_seq = sequencify(allow_missing) - if self.zip_lists: - inner_expand = list( - # order preserving deduplication - dict.fromkeys( - sn_expand( - list(itx.always_iterable(paths)), - zip, - allow_missing=True if wildcards else allow_missing_seq, - **self.zip_lists, - ) - ) - ) - else: - inner_expand = list(itx.always_iterable(paths)) - if not wildcards: - return inner_expand - - return sn_expand( - inner_expand, - allow_missing=allow_missing_seq, - # Turn all the wildcard items into lists because Snakemake doesn't handle - # iterables very well - **{ - wildcard: list(itx.always_iterable(v)) - for wildcard, v in wildcards.items() - }, + path_list = list(itx.always_iterable(paths)) + + # When zip_lists is empty and no extra wildcards, return paths as-is + # (no formatting: avoids errors on arbitrary path text with {wildcards}) + if allow_missing and not self.zip_lists and not wildcards: + return path_list + + return _expand( + path_list, + zip_lists=self.zip_lists, + allow_missing=bool(allow_missing), + **wildcards, ) def filter( @@ -615,7 +589,7 @@ def expand( paths: Iterable[Path | str] | Path | str | None = None, /, allow_missing: bool | str | Iterable[str] = False, - **wildcards: str | Iterable[str], + **wildcards: Iterable[str | None] | str | None, ) -> list[str]: """Safely expand over given paths with component wildcards. @@ -645,7 +619,7 @@ def expand( Keywords not found in the path will be ignored. Keywords take values or lists of values to be expanded over the provided paths. """ - paths = paths or self.path + paths = self.path if paths is None else paths return super().expand(paths, allow_missing, **wildcards) @property diff --git a/src/snakebids/core/expanding.py b/src/snakebids/core/expanding.py new file mode 100644 index 00000000..e279677f --- /dev/null +++ b/src/snakebids/core/expanding.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import itertools as it +from collections.abc import Iterable +from typing import Any + +import more_itertools as itx + +from snakebids.types import ZipListLike +from snakebids.utils.snakemake_templates import MissingEntityError, SnakemakeFormatter + +try: + from snakemake.io import AnnotatedString as _AnnotatedString # type: ignore + +except ImportError: + + def _get_flags(path: Any) -> dict[str, Any] | None: + """Return None when snakemake is not installed (no AnnotatedString support).""" + return None + + def _make_annotated(s: str, flags: dict[str, Any]) -> str: + """Return the string unchanged when snakemake is not installed.""" + return s +else: + + def _get_flags(path: Any) -> dict[str, Any] | None: + """Extract flags from an AnnotatedString path.""" + if isinstance(path, _AnnotatedString) and path.flags: # type: ignore + return dict(path.flags) # type: ignore + return None + + def _make_annotated(s: str, flags: dict[str, Any]) -> str: + """Create an AnnotatedString with the given flags copied from the template.""" + result = _AnnotatedString(s) + result.flags.update(flags) # type: ignore + return result + + +def expand( + paths: Iterable[Any], + zip_lists: ZipListLike, + allow_missing: bool, + **wildcards: Any, +) -> list[str]: + """Expand template paths using SnakemakeFormatter. + + For each template path, iterates over rows of the zip-list and optional + extra wildcard combinations (product), formatting each path per row. + None values in wildcards are converted to "". + Output is order-preserving deduplicated. + """ + formatter = SnakemakeFormatter(allow_missing=allow_missing) + + # Normalize extra wildcards: convert None (scalar or in list) → "", ensure lists + extra: dict[str, list[str]] = {} + for k, vals in wildcards.items(): + if vals is None: + extra[k] = [""] + else: + extra[k] = ["" if v is None else v for v in itx.always_iterable(vals)] + + rows = list(zip(*zip_lists.values(), strict=True)) + + results: list[str] = [] + for path in paths: + path_str = str(path) + flags = _get_flags(path) + for row, *combo in it.product(rows, *extra.values()): + kwargs = { + **dict(zip(zip_lists, row, strict=True)), + **dict(zip(extra, combo, strict=True)), + } + try: + result = formatter.vformat(path_str, (), kwargs) + except MissingEntityError as err: + msg = f"no values given for wildcard {err.entity!r}." + raise KeyError(msg) from err + except KeyError as err: + msg = f"no values given for wildcard {err.args[0]!r}." + raise KeyError(msg) from err + + if flags is not None: + result = _make_annotated(result, flags) + results.append(result) + + return list(dict.fromkeys(results)) diff --git a/src/snakebids/snakemake_compat.py b/src/snakebids/snakemake_compat.py index 1952e7f4..baa173ed 100644 --- a/src/snakebids/snakemake_compat.py +++ b/src/snakebids/snakemake_compat.py @@ -9,7 +9,6 @@ from snakemake import get_argument_parser, main from snakemake.io import load_configfile -from snakemake.exceptions import WildcardError from snakemake.io import expand # Handle different snakemake versions for regex function @@ -22,7 +21,6 @@ __all__ = [ "Snakemake", - "WildcardError", "configfile", "expand", "get_argument_parser", diff --git a/src/snakebids/snakemake_compat.pyi b/src/snakebids/snakemake_compat.pyi index 57668b57..40cb62e5 100644 --- a/src/snakebids/snakemake_compat.pyi +++ b/src/snakebids/snakemake_compat.pyi @@ -8,8 +8,6 @@ from snakemake.common import configfile as configfile # type: ignore configfile: ModuleType -class WildcardError(Exception): ... - def load_configfile(configpath: str) -> dict[str, Any]: "Load a JSON or YAML configfile as a dict, then checks that it's a dict." diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 9d20061e..bbac68dd 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -2,6 +2,7 @@ import copy import functools as ft +import importlib.util import itertools as it import operator as op import re @@ -24,7 +25,6 @@ from snakebids.exceptions import DuplicateComponentError from snakebids.paths._presets import bids from snakebids.paths._utils import OPTIONAL_WILDCARD -from snakebids.snakemake_compat import WildcardError from snakebids.types import Expandable, ZipList from snakebids.utils import sb_itertools as sb_it from snakebids.utils.snakemake_io import glob_wildcards @@ -375,37 +375,28 @@ class TestExpandables: component=sb_st.expandables( restrict_patterns=True, unique=True, blacklist_entities=["extension"] ), - wildcards=st.lists( + wildcards=st.dictionaries( st.text(string.ascii_letters, min_size=1, max_size=10).filter( lambda s: s not in sb_st.valid_entities ), + st.lists( + st.text(sb_st.alphanum, min_size=1, max_size=10), + min_size=1, + max_size=3, + ), min_size=1, max_size=5, ), - data=st.data(), ) def test_expand_with_extra_args_returns_all_paths( - self, component: Expandable, wildcards: list[str], data: st.DataObject + self, component: Expandable, wildcards: dict[str, list[str]] ): - num_wildcards = len(wildcards) - values = data.draw( - st.lists( - st.lists( - st.text(sb_st.alphanum, min_size=1, max_size=10), - min_size=1, - max_size=3, - ), - min_size=num_wildcards, - max_size=num_wildcards, - ) - ) path_tpl = bids( **get_wildcard_dict(component.zip_lists), **get_wildcard_dict(wildcards), ) - wcard_dict = dict(zip(wildcards, values, strict=True)) - zlist = expand_zip_list(component.zip_lists, wcard_dict) - paths = component.expand(path_tpl, **wcard_dict) + zlist = expand_zip_list(component.zip_lists, wildcards) + paths = component.expand(path_tpl, **wildcards) assert zip_list_eq(glob_wildcards(path_tpl, paths), zlist) @given( @@ -421,7 +412,7 @@ def test_expand_over_multiple_paths(self, component: Expandable): @given( component=sb_st.expandables(restrict_patterns=True), - wildcard=st.text(string.ascii_letters, min_size=1, max_size=10).filter( + wildcard=st.text(string.ascii_letters, min_size=1).filter( lambda s: s not in sb_st.valid_entities ), ) @@ -431,11 +422,29 @@ def test_partial_expansion(self, component: Expandable, wildcard: str): ) paths = component.expand(path_tpl, allow_missing=True) for path in paths: - assert re.search(r"\{.+\}", path) + assert len(re.findall(r"\{.+\}", path)) == 1 @given( component=sb_st.expandables(restrict_patterns=True), - wildcard=st.text(string.ascii_letters, min_size=1, max_size=10).filter( + wildcard=st.text(string.ascii_letters, min_size=1).filter( + lambda s: s not in sb_st.valid_entities + ), + ) + def test_partial_expansion_with_constraint( + self, component: Expandable, wildcard: str + ): + path_tpl = bids( + "", + **get_wildcard_dict(component.zip_lists), + **{wildcard: OPTIONAL_WILDCARD}, + ) + paths = component.expand(path_tpl, allow_missing=True) + for path in paths: + assert len(re.findall(r"\{[^}]+,[^}]+\}", path)) == 2 # noqa: PLR2004 + + @given( + component=sb_st.expandables(restrict_patterns=True), + wildcard=st.text(string.ascii_letters, min_size=1).filter( lambda s: s not in sb_st.valid_entities ), ) @@ -443,7 +452,7 @@ def test_prevent_partial_expansion(self, component: Expandable, wildcard: str): path_tpl = bids( **get_wildcard_dict(component.zip_lists), **get_wildcard_dict(wildcard) ) - with pytest.raises(WildcardError): + with pytest.raises(KeyError, match="no values given for wildcard"): component.expand(path_tpl) @given(component=sb_st.expandables(restrict_patterns=True, path_safe=True)) @@ -470,10 +479,14 @@ def test_expand_preserves_entry_order(self, component: Expandable): == path ) - @given(path=st.text()) - def test_expandable_with_no_wildcards_returns_path_unaltered(self, path: str): - component = BidsPartialComponent(zip_lists={}) - assert itx.one(component.expand(path)) == path + @given( + path=st.text().map(lambda s: s.replace("{", "{{").replace("}", "}}")), + component=sb_st.expandables(), + ) + def test_expandable_with_no_wildcards_returns_path_unaltered( + self, path: str, component: BidsComponent + ): + assert itx.one(component.expand(path)) == path.format() @given(component=sb_st.expandables(min_values=0, max_values=0, path_safe=True)) def test_expandable_with_no_entries_returns_empty_list(self, component: Expandable): @@ -516,6 +529,70 @@ def test_not_expand_over_internal_path_when_novel_given( paths = component.expand(novel_path) assert not glob_wildcards(component.path, paths) + @given(component=sb_st.bids_components(), paths=st.just("") | st.lists(st.just(""))) + def test_returns_back_empty_containers( + self, component: BidsComponent, paths: str | list[str] + ): + assert component.expand(paths) == list(set(itx.always_iterable(paths))) + + +class TestExpandNoneHandling: + """Tests for None value conversion in expand().""" + + def test_none_extra_wildcard_treated_as_empty_string(self): + """None values in extra wildcards are converted to empty string.""" + component = BidsComponent( + name="test", + path="sub-{subject}", + zip_lists={"subject": ["001"]}, + ) + # Path with optional entity that gets a None value + paths = component.expand( + "sub-{subject}{_acq_}{acq}", + acq=[None, "MPRAGE"], + ) + # None → "" → optional entity not included + assert paths == ["sub-001", "sub-001_acq-MPRAGE"] + + +def _has_annotated_string(): + return bool(importlib.util.find_spec("snakemake")) + + +@pytest.mark.skipif( + not _has_annotated_string(), + reason="snakemake AnnotatedString not available", +) +class TestExpandAnnotatedString: + """Tests for AnnotatedString flag propagation in expand().""" + + def test_annotated_string_flags_propagated_to_outputs(self): + """AnnotatedString flags are preserved on expanded output paths.""" + from snakemake.io import AnnotatedString # noqa: PLC0415 # type: ignore + + component = BidsComponent( + name="test", + path="sub-{subject}", + zip_lists={"subject": ["001", "002"]}, + ) + template = AnnotatedString("sub-{subject}_T1w.nii.gz") + template.flags["temp"] = True # type: ignore + + paths = component.expand(template) + for path in paths: + assert isinstance(path, AnnotatedString) + assert path.flags.get("temp") is True # type: ignore + + def test_plain_string_not_annotated(self): + """Plain string templates produce plain string outputs.""" + component = BidsComponent( + name="test", + path="sub-{subject}", + zip_lists={"subject": ["001"]}, + ) + paths = component.expand("sub-{subject}_T1w.nii.gz") + assert type(paths[0]) is str + class TestFiltering: def get_filter_dict(