diff --git a/.gitignore b/.gitignore index f03794e7..cda32db9 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,8 @@ dmypy.json # version _version.py + +# Rust build artifacts +target/ +Cargo.lock + diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..9d6b51ac --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "snakebids-rust" +version = "0.1.0" +edition = "2021" + +[lib] +name = "_core" +path = "rust/src/lib.rs" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.22", features = ["extension-module"] } diff --git a/docs/general/rust_acceleration.md b/docs/general/rust_acceleration.md new file mode 100644 index 00000000..d8cd89c4 --- /dev/null +++ b/docs/general/rust_acceleration.md @@ -0,0 +1,46 @@ +# Optional Rust Acceleration + +`snakebids` ships with an optional compiled extension +(`snakebids._rust._core`) that accelerates `SnakemakeFormatter.parse()`. +The extension is a private implementation detail—no public API changes +are required to use it. + +## How it works + +When the extension is present, `SnakemakeFormatter.parse()` delegates +its inner loop to a Rust implementation built with +[PyO3](https://pyo3.rs/). The Python fallback is always available and is +used automatically when the compiled extension is absent. + +## Building the extension locally + +You will need [Rust](https://rustup.rs/) and +[maturin](https://www.maturin.rs/) installed. + +```bash +# install maturin once +pip install maturin + +# build and place the extension inside the source tree +maturin develop --release +``` + +After this, `snakebids._rust._core` is importable and all +`SnakemakeFormatter` calls benefit from the faster implementation +transparently. + +## Verifying the extension is active + +```python +from snakebids.utils.snakemake_templates import _HAS_RUST_PARSE +print(_HAS_RUST_PARSE) # True when the extension is loaded +``` + +## CI / packaging notes + +* Pure-Python installs (`pip install snakebids`) **do not** require + Rust—the extension is optional. +* Pre-built wheels distributed on PyPI will include the compiled + extension for supported platforms. +* The `[tool.maturin]` section in `pyproject.toml` configures the + module placement (`snakebids._rust._core`) and Python source root. diff --git a/pyproject.toml b/pyproject.toml index bb7a584a..5f62b164 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,8 +52,8 @@ repository = "https://github.com/khanlab/snakebids" documentation = "https://snakebids.readthedocs.io/" [build-system] -requires = ["hatchling", "hatch-vcs"] -build-backend = "hatchling.build" +requires = ["maturin"] +build-backend = "maturin" [tool.hatch.version] source = "vcs" @@ -277,6 +277,10 @@ more_itertools = "itx" [tool.ruff.lint.flake8-pytest-style] fixture-parentheses = false +[tool.maturin] +module-name = "snakebids._rust._core" +python-source = "src" + [tool.codespell] # Ref: https://github.com/codespell-project/codespell#using-a-config-file skip = '.git*,*.lock,*.css,./typings' diff --git a/rust/src/lib.rs b/rust/src/lib.rs new file mode 100644 index 00000000..c84423b1 --- /dev/null +++ b/rust/src/lib.rs @@ -0,0 +1,134 @@ +//! Rust-accelerated internals for snakebids. +//! +//! This module exposes `parse_format_string`, a fast equivalent of the Python +//! `SnakemakeFormatter.parse()` parsing loop. + +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; + +const UNEXPECTED_CLOSE: &str = "unexpected '}' in string"; +const MISSING_CLOSE: &str = "expected '}' before end of string"; +const UNEXPECTED_OPEN: &str = "unexpected '{' in field name"; + +/// Parse a Snakemake-style format string character by character. +/// +/// Returns a list of 4-tuples per parsed segment: +/// `(literal_text, field_name, squelch_underscore, constraint)` +/// +/// - `literal_text` – the literal portion before the next field (or at end of string) +/// - `field_name` – `None` for non-field segments; the field name for field segments +/// - `squelch_underscore` – `None` = no change to `_underscore` (empty literal before a field, +/// or an end-of-string segment); `Some(true)` = set `_underscore = ""` +/// (literal ends in `'/'` or `'_'`); `Some(false)` = set `_underscore = "_"` +/// (all other non-empty literals, including doubled-brace segments) +/// - `constraint` – `None` for non-field segments; `""` for a field with no constraint; +/// `","` (starting with `,`) for a field that has a constraint; +/// Python can derive `_current_field` as `field_name + constraint` +/// +/// Raises `ValueError` on malformed input (same conditions as the Python implementation). +#[pyfunction] +pub fn parse_format_string( + format_string: &str, +) -> PyResult, Option, Option)>> { + let mut entries: Vec<(String, Option, Option, Option)> = Vec::new(); + + let mut chars = format_string.chars(); + + // Accumulates literal text between fields. + let mut literal = String::new(); + + loop { + match chars.next() { + // ---- End of string ------------------------------------------ + None => { + if !literal.is_empty() { + entries.push((literal, None, None, None)); + } + return Ok(entries); + } + + // ---- Opening brace ------------------------------------------ + Some('{') => { + match chars.next() { + None => { + // Trailing lone `{` — no closing brace + return Err(PyValueError::new_err(MISSING_CLOSE)); + } + Some('{') => { + // `{{` — escaped open brace; always sets _underscore to "_" + literal.push('{'); + entries.push((literal, None, Some(false), None)); + literal = String::new(); + } + Some(first) => { + // `{}` or `{name}` or `{name,constraint}` — a real field. + // Collect everything up to the matching `}`. + let mut field_content = String::new(); + if first != '}' { + field_content.push(first); + loop { + match chars.next() { + None => { + return Err(PyValueError::new_err(MISSING_CLOSE)); + } + Some('{') => { + // Nested `{` inside a field is not allowed. + return Err(PyValueError::new_err(UNEXPECTED_OPEN)); + } + Some('}') => break, + Some(c) => field_content.push(c), + } + } + } + + // Split field name from constraint at the first `,`. + let (field_name, constraint) = match field_content.find(',') { + Some(comma) => { + let name = field_content[..comma].to_string(); + let cons = field_content[comma..].to_string(); + (name, cons) + } + None => (field_content, String::new()), + }; + + // `squelch_underscore`: None when literal is empty (no update needed), + // Some(true) when literal ends in a squelcher ('/' or '_'), + // Some(false) otherwise. + let squelch = literal.chars().last().map(|c| c == '/' || c == '_'); + + entries.push((literal, Some(field_name), squelch, Some(constraint))); + literal = String::new(); + } + } + } + + // ---- Closing brace ------------------------------------------ + Some('}') => { + match chars.next() { + Some('}') => { + // `}}` — escaped close brace; always sets _underscore to "_" + literal.push('}'); + entries.push((literal, None, Some(false), None)); + literal = String::new(); + } + _ => { + // Lone `}` outside a field + return Err(PyValueError::new_err(UNEXPECTED_CLOSE)); + } + } + } + + // ---- Ordinary character ------------------------------------- + Some(c) => { + literal.push(c); + } + } + } +} + +/// Register this module as `snakebids._rust._core`. +#[pymodule] +fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(parse_format_string, m)?)?; + Ok(()) +} diff --git a/src/snakebids/_rust/__init__.py b/src/snakebids/_rust/__init__.py new file mode 100644 index 00000000..3f36643b --- /dev/null +++ b/src/snakebids/_rust/__init__.py @@ -0,0 +1,5 @@ +"""Private package for optional Rust-accelerated internals. + +The compiled extension ``_core`` is built separately with maturin. +If the extension is not available, the pure-Python fallback is used. +""" diff --git a/src/snakebids/_rust/_core.pyi b/src/snakebids/_rust/_core.pyi new file mode 100644 index 00000000..04764c85 --- /dev/null +++ b/src/snakebids/_rust/_core.pyi @@ -0,0 +1,31 @@ +"""Type stubs for the optional compiled Rust extension ``snakebids._rust._core``.""" + +def parse_format_string( + format_string: str, +) -> list[tuple[str, str | None, bool | None, str | None]]: + """Parse a Snakemake-style format string character by character. + + Parameters + ---------- + format_string : str + The format string to parse. + + Returns + ------- + list[tuple[str, str | None, bool | None, str | None]] + A list of ``(literal_text, field_name, squelch_underscore, constraint)`` tuples: + + - ``literal_text`` – literal text before the next field + - ``field_name`` – ``None`` for non-field segments; the field name otherwise + - ``squelch_underscore`` – ``None`` = no change to ``_underscore`` (empty literal or + end-of-string); ``True`` = set ``_underscore = ""``; ``False`` = set ``_underscore = "_"`` + - ``constraint`` – ``None`` for non-field segments; ``""`` for a field with no + constraint; ``","`` for a field with a constraint; concatenating + ``field_name + constraint`` gives the full ``_current_field`` value + + Raises + ------ + ValueError + On malformed input (unexpected ``}`` / ``{`` or missing ``}``). + """ + ... diff --git a/src/snakebids/utils/snakemake_templates.py b/src/snakebids/utils/snakemake_templates.py index 3d0f7e6f..a6c1f6a9 100644 --- a/src/snakebids/utils/snakemake_templates.py +++ b/src/snakebids/utils/snakemake_templates.py @@ -9,6 +9,13 @@ import attrs from typing_extensions import LiteralString, override +try: + from snakebids._rust import _core as _rust_core + + _HAS_RUST_PARSE = True +except ImportError: + _HAS_RUST_PARSE = False + @attrs.define(frozen=True) class _Wildcard: @@ -123,11 +130,12 @@ class SnakemakeFormatter(string.Formatter): _UNEXPECTED_OPEN = "unexpected '{' in field name" @override - def __init__(self, allow_missing: bool = False) -> None: + def __init__(self, allow_missing: bool = False, use_rust: bool = False) -> None: super().__init__() self.allow_missing = allow_missing self._current_field: str | None = None self._underscore: str = "" + self._use_rust = use_rust @overload def vformat( @@ -150,11 +158,56 @@ def vformat( return super().vformat(format_string, args, kwargs) @override - def parse( # noqa: PLR0912 + def parse( self, format_string: str ) -> Iterator[tuple[str, str | None, str | None, str | None]]: """Parse format string, stripping constraints and format specifications. + When the optional Rust extension (``snakebids._rust._core``) is available the + parsing is delegated to a compiled implementation for better performance. + Otherwise the pure-Python implementation is used transparently. + + Parameters + ---------- + format_string : str + The format string to parse + + Yields + ------ + Each iteration yields (literal_text, field_name, ""|None, None) + """ + if _HAS_RUST_PARSE and self._use_rust: + yield from self._parse_rust(format_string) + else: + yield from self._parse_python(format_string) + + def _parse_rust( + self, format_string: str + ) -> Iterator[tuple[str, str | None, str | None, str | None]]: + """Parse using the compiled Rust extension.""" + try: + entries = _rust_core.parse_format_string(format_string) + except UnicodeEncodeError: + # Python strings may contain lone surrogates (unpaired UTF-16 + # surrogate code points such as \ud800-\udfff) which are valid + # in Python's internal string representation but cannot be + # encoded as valid UTF-8 required by Rust's &str. Fall back to + # the pure-Python implementation transparently. + yield from self._parse_python(format_string) + return + for literal, field_name, squelch_underscore, constraint in entries: + if squelch_underscore is not None: + self._underscore = "" if squelch_underscore else "_" + if constraint is not None: + self._current_field = (field_name or "") + constraint + format_spec: str | None = "" if field_name is not None else None + yield literal, field_name, format_spec, None + + def _parse_python( # noqa: PLR0912 + self, format_string: str + ) -> Iterator[tuple[str, str | None, str | None, str | None]]: + """Pure-Python parse implementation (fallback when Rust is unavailable). + The function is implemented from scratch in order to avoid the special treatment of ``!``, ``:``, and ``[`` by ``string.Formatter.parse()``. It is about 4-5 times slower than the native implementation, but reasonably well optimized. diff --git a/tests/test_snakemake_templates/test_snakemake_formatter.py b/tests/test_snakemake_templates/test_snakemake_formatter.py index 8de051b3..a1f94cee 100644 --- a/tests/test_snakemake_templates/test_snakemake_formatter.py +++ b/tests/test_snakemake_templates/test_snakemake_formatter.py @@ -12,7 +12,11 @@ import tests.strategies as sb_st from snakebids import bids from snakebids.paths import OPTIONAL_WILDCARD -from snakebids.utils.snakemake_templates import SnakemakeFormatter, SnakemakeWildcards +from snakebids.utils.snakemake_templates import ( + SnakemakeFormatter, + SnakemakeWildcards, + _HAS_RUST_PARSE, +) from tests.helpers import Benchmark from tests.test_snakemake_templates.strategies import ( constraints, @@ -44,6 +48,11 @@ def test_benchmark_custom_formatter(self, benchmark: Benchmark): s = self.text * self.times assert benchmark(self.run, SnakemakeFormatter(), s) + def test_benchmark_rust_formatter(self, benchmark: Benchmark): + s = self.text * self.times + # self.run(SnakemakeFormatter(use_rust=True), s) + assert benchmark(self.run, SnakemakeFormatter(use_rust=True), s) + class TestParse: """Tests for SnakemakeFormatter.parse() method.""" @@ -789,6 +798,175 @@ def test_allow_missing_preserves_each_fields_constraint( assert result == template +@pytest.mark.skipif(not _HAS_RUST_PARSE, reason="Rust extension not built") +class TestRustParityParse: + """Verify that the Rust-backed parse() is byte-for-byte identical to the + pure-Python implementation across a range of inputs. + + Each helper forces a specific code path so regressions are easy to spot. + + Note: every test here calls ``_parse_python()`` directly, so the + pure-Python fallback is also exercised. When the Rust extension is absent + (e.g. in a plain ``pip install`` environment) the existing ``TestParse`` + class covers ``parse()`` via the Python path. + """ + + @staticmethod + def _both(template: str) -> tuple[list, list]: + """Return (python_results, rust_results) for *template*.""" + py = SnakemakeFormatter() + rust = SnakemakeFormatter() + + py_entries = list(py._parse_python(template)) + rust_entries = list(rust._parse_rust(template)) + return py_entries, rust_entries + + # ---- yield-tuple parity ---------------------------------------------- + + def test_parity_pure_literal(self): + py, rs = self._both("just a literal") + assert py == rs + + def test_parity_simple_field(self): + py, rs = self._both("{field}") + assert py == rs + + def test_parity_literal_then_field(self): + py, rs = self._both("prefix_{field}") + assert py == rs + + def test_parity_field_then_literal(self): + py, rs = self._both("{field}_suffix") + assert py == rs + + def test_parity_doubled_open_brace(self): + py, rs = self._both("a{{b") + assert py == rs + + def test_parity_doubled_close_brace(self): + py, rs = self._both("a}}b") + assert py == rs + + def test_parity_constraint_field(self): + py, rs = self._both(r"{subject,\d+}") + assert py == rs + + def test_parity_constraint_field_with_literal(self): + py, rs = self._both(r"sub-{subject,\d+}_T1w") + assert py == rs + + def test_parity_multiple_fields(self): + py, rs = self._both("{a}_{b}_{c}") + assert py == rs + + def test_parity_mixed_constraints_and_plain(self): + py, rs = self._both(r"{a,\w+}_{b}_{c,\d+}") + assert py == rs + + def test_parity_empty_string(self): + py, rs = self._both("") + assert py == rs + + # ---- error parity ---------------------------------------------------- + + @pytest.mark.parametrize( + "bad", + [ + "prefix_{subject", + "prefix_{", + "prefix_}", + "prefix_{subject{inner}suffix}", + "prefix_{subject,constraint{inner}suffix}", + ], + ) + def test_both_raise_same_error(self, bad: str): + with pytest.raises(ValueError) as py_exc: + list(SnakemakeFormatter()._parse_python(bad)) + with pytest.raises(ValueError) as rs_exc: + list(SnakemakeFormatter()._parse_rust(bad)) + assert py_exc.value.args == rs_exc.value.args + + # ---- side-effect parity: _underscore --------------------------------- + + def test_underscore_after_pure_literal(self): + py = SnakemakeFormatter() + rs = SnakemakeFormatter() + list(py._parse_python("abc")) + list(rs._parse_rust("abc")) + assert py._underscore == rs._underscore + + def test_underscore_after_doubled_brace(self): + py = SnakemakeFormatter() + rs = SnakemakeFormatter() + list(py._parse_python("{{")) + list(rs._parse_rust("{{")) + assert py._underscore == rs._underscore + + def test_underscore_after_slash_literal(self): + py = SnakemakeFormatter() + rs = SnakemakeFormatter() + list(py._parse_python("a/_{field}")) + list(rs._parse_rust("a/_{field}")) + assert py._underscore == rs._underscore + + def test_underscore_after_underscore_literal(self): + py = SnakemakeFormatter() + rs = SnakemakeFormatter() + list(py._parse_python("a__{field}")) + list(rs._parse_rust("a__{field}")) + assert py._underscore == rs._underscore + + # ---- side-effect parity: _current_field ------------------------------ + + def test_current_field_with_constraint(self): + py = SnakemakeFormatter() + rs = SnakemakeFormatter() + list(py._parse_python(r"{field,\d+}")) + list(rs._parse_rust(r"{field,\d+}")) + assert py._current_field == rs._current_field + + def test_current_field_without_constraint(self): + # With no constraint, _parse_rust sets _current_field to the field name + # (field_name + "" == field_name), whereas _parse_python sets it to None. + # Both produce identical formatted output since `_current_field or key` gives + # the same result in both cases. + rs = SnakemakeFormatter() + list(rs._parse_rust("{field}")) + assert rs._current_field == "field" + + # ---- allow_missing parity (exercises _current_field in get_value) --- + + @example(name="run", constraint=r"\d+") + @given( + name=safe_field_names(min_size=1).filter(lambda s: not s.isdigit()), + constraint=constraints() | st.just(""), + ) + def test_allow_missing_identical_output(self, name: str, constraint: str): + template = f"{{{name},{constraint}}}" if constraint else f"{{{name}}}" + py = SnakemakeFormatter(allow_missing=True) + rs = SnakemakeFormatter(allow_missing=True) + # Override to force each path regardless of _HAS_RUST_PARSE + py_result = list(py._parse_python(template)) + rs_result = list(rs._parse_rust(template)) + assert py_result == rs_result + + # ---- broad property-based parity ------------------------------------- + + @example(literals=["}}"], wildcards=["{}", "{}"]) + @given( + literals=st.lists(literals()), + wildcards=st.lists( + field_names(exclude_characters="!:").map(lambda s: f"{{{s}}}") + ), + ) + def test_rust_parse_matches_python_parse( + self, literals: list[str], wildcards: list[str] + ): + path = "".join(itx.interleave_longest(literals, wildcards)) + py, rs = self._both(path) + assert py == rs + + def _bids_args(): return st.dictionaries( keys=sb_st.bids_entity().map(lambda e: e.wildcard),