diff --git a/GOALS.md b/GOALS.md index f362c45..4ed7ae6 100644 --- a/GOALS.md +++ b/GOALS.md @@ -6,8 +6,12 @@ Be the default benchmark for new process-mining methods. Within 18 months, ## v0 success criteria - 7 datasets fetchable + hash-verified -- 5 tasks with fixed scoring scripts -- `gnn` runs end-to-end as the reference baseline +- 5 tasks with fixed scoring scripts (next-event ✅; remaining-time, outcome, + conformance, bottleneck pending) +- `gnn` runs end-to-end as the reference baseline (Markov reference ✅; + `gnn` integration pending v0.1 dataset machinery) +- End-to-end loop runs on `synthetic-toy` ✅ — split → prefixes → + predict → score, covered by `tests/test_e2e.py` ## v1 success criteria - ≥3 external groups submit to the leaderboard diff --git a/README.md b/README.md index bf05b1d..3071239 100644 --- a/README.md +++ b/README.md @@ -105,14 +105,24 @@ and cache locally. Datasets carry their original licenses (linked in ```bash pip install pm-bench -pm-bench list # available datasets -pm-bench fetch bpi2020 # download + verify hash -pm-bench split bpi2020 --task next-event > split.json +pm-bench list # available datasets +pm-bench split synthetic-toy > split.json # train/val/test case ids +pm-bench prefixes synthetic-toy \ + --split split.json --out prefixes.csv # prediction targets +pm-bench predict synthetic-toy \ + --split split.json --prefixes prefixes.csv \ + --out predictions.csv --baseline markov # reference baseline pm-bench score predictions.csv \ - --task next-event --dataset bpi2020 --split split.json -pm-bench leaderboard --task next-event # current standings + --prefixes prefixes.csv --task next-event # top-1 / top-3 ``` +The full loop (`split → prefixes → predict → score`) runs end-to-end on +`synthetic-toy` today; it's covered by `tests/test_e2e.py` and locks +the file formats the leaderboard depends on. BPI / Sepsis / Helpdesk +will use the same commands once v0.1's fetch+cache machinery lands — +4TU's interactive TOS makes the download itself a one-time manual +step, but everything downstream is automated. + The full pipeline: ```mermaid @@ -191,11 +201,14 @@ honesty. The point of the benchmark is to make the comparison real. ## ✦ Roadmap - [x] v0.0 — scaffold, dataset registry, split design +- [x] v0.0.1 — end-to-end loop on `synthetic-toy`: split → prefixes → + predict (Markov) → score, with a smoke test that locks the file + formats - [ ] v0.1 — fetch + cache + hash for all 7 datasets - [ ] v0.2 — splits: next-event, remaining-time - [ ] v0.3 — scoring scripts for all 5 tasks - [ ] v0.4 — leaderboard CI + landing page -- [ ] v0.5 — baselines: `gnn`, transformer, LSTM, Markov +- [ ] v0.5 — baselines: `gnn`, transformer, LSTM, Markov ✅ (Markov shipped) - [ ] v1.0 — first external submissions; cited in ≥1 paper ## ✦ Topics diff --git a/STATUS.md b/STATUS.md new file mode 100644 index 0000000..bc837ee --- /dev/null +++ b/STATUS.md @@ -0,0 +1,64 @@ +# Status + +_Last updated: 2026-04-30._ + +## Where we are + +Pre-v0. The end-to-end loop runs on the bundled `synthetic-toy` +dataset; the seven public datasets are still pending v0.1's fetch + +hash machinery. + +A submission today looks like: + +```bash +pm-bench split synthetic-toy > split.json +pm-bench prefixes synthetic-toy --split split.json --out prefixes.csv +pm-bench predict synthetic-toy --split split.json \ + --prefixes prefixes.csv --out predictions.csv --baseline markov +pm-bench score predictions.csv --prefixes prefixes.csv --task next-event +# → top1 0.976, top3 1.000 (Markov on synthetic-toy) +``` + +That sequence is the contract — it's what `tests/test_e2e.py` runs in +CI, and it's what the leaderboard CI will run once datasets are pinned. + +## Recently shipped + +- **End-to-end loop on synthetic-toy** (`end-to-end-loop` branch). + - `pm_bench/prefixes.py` — extract prediction targets from a split, + write/read CSV. Skips length-1 cases. + - `pm_bench/predictions.py` — predictions CSV format + (`case_id,prefix_idx,predictions`). + - `pm_bench/baselines/markov.py` — first-order Markov reference + baseline. Trained on the train partition only; falls back to + unigram for unseen last-activities. + - CLI gained `prefixes`, `predict`, `score`. The full + `split → prefixes → predict → score` loop now matches what the + README advertises. + - `tests/test_e2e.py` covers the loop end-to-end via the click + runner; format changes will trip it. +- **v0.0** (initial release): scaffold, registry, case-chrono split, + next-event scoring function, CLI `list` / `info` / `split`. + +## Next up + +- **v0.1 — dataset fetch + hash** for the seven public logs. The 4TU + portal needs interactive TOS acceptance per dataset, so the fetch + itself is a one-time manual step; the rest (cache → verify hash → + parse XES → run the same loop) is automated. This is the work that + unblocks every downstream milestone. +- **`gnn` as the second reference baseline** once v0.1 lands. `gnn`'s + v0.5 milestone is symmetrical with this — it's been waiting for a + pinned dataset registry, which `pm-bench` is meant to provide. +- Additional tasks beyond next-event (remaining-time, outcome, + conformance, bottleneck). The split + prefixes machinery is shared; + scoring is the per-task piece. + +## Known gaps + +- No `pm-bench fetch` yet. README still hints at it; the install & + use section now shows the loop that actually works (synthetic-toy + only) so the doc and the CLI line up. +- `predict` currently only knows `markov`. The `--baseline` flag is a + click choice so adding a second is a one-liner, but the second one + worth adding is `gnn`, which depends on v0.1. diff --git a/pm_bench/__init__.py b/pm_bench/__init__.py index a319afe..b53680f 100644 --- a/pm_bench/__init__.py +++ b/pm_bench/__init__.py @@ -3,6 +3,8 @@ __version__ = "0.1.0" +from pm_bench.predictions import Prediction, read_predictions_csv, write_predictions_csv +from pm_bench.prefixes import Prefix, extract_prefixes, read_prefixes_csv, write_prefixes_csv from pm_bench.registry import Dataset, get_dataset, load_registry from pm_bench.score import NextEventScore, score_next_event from pm_bench.split import Event, Split, case_chrono_split @@ -11,9 +13,16 @@ "Dataset", "Event", "NextEventScore", + "Prediction", + "Prefix", "Split", "case_chrono_split", + "extract_prefixes", "get_dataset", "load_registry", + "read_predictions_csv", + "read_prefixes_csv", "score_next_event", + "write_predictions_csv", + "write_prefixes_csv", ] diff --git a/pm_bench/baselines/__init__.py b/pm_bench/baselines/__init__.py new file mode 100644 index 0000000..891922c --- /dev/null +++ b/pm_bench/baselines/__init__.py @@ -0,0 +1,12 @@ +"""Reference baselines that ship with pm-bench. + +Baselines exist to anchor the leaderboard: a submission that loses to +the markov reference is an immediate red flag. They're deliberately +simple — no torch, no scikit-learn, no GPUs, just CPython — so anyone +can read the code and trust the number. +""" +from __future__ import annotations + +from pm_bench.baselines.markov import MarkovBaseline, predict_markov + +__all__ = ["MarkovBaseline", "predict_markov"] diff --git a/pm_bench/baselines/markov.py b/pm_bench/baselines/markov.py new file mode 100644 index 0000000..cd2771b --- /dev/null +++ b/pm_bench/baselines/markov.py @@ -0,0 +1,67 @@ +"""First-order Markov reference baseline. + +Counts (current_activity → next_activity) transitions on training cases +only, then ranks candidates by frequency. Falls back to the global +unigram distribution when a prefix ends in an activity unseen during +training. No smoothing — the leaderboard reports raw frequencies. + +Why first-order: it's the dumbest model that has any business being on +the leaderboard, and it sets the floor any "real" sequence model has to +clear. A transformer that ties or loses to first-order Markov is +broken or overfit. +""" +from __future__ import annotations + +from collections import Counter, defaultdict +from collections.abc import Iterable +from dataclasses import dataclass + +from pm_bench.predictions import Prediction +from pm_bench.prefixes import Prefix +from pm_bench.split import Activity, Event + + +@dataclass +class MarkovBaseline: + transitions: dict[Activity, Counter[Activity]] + unigram: Counter[Activity] + + def rank(self, last_activity: Activity | None) -> list[Activity]: + """Return candidate next activities, best first.""" + if last_activity is not None and last_activity in self.transitions: + counts = self.transitions[last_activity] + if counts: + return [a for a, _ in counts.most_common()] + return [a for a, _ in self.unigram.most_common()] + + +def fit_markov(events: Iterable[Event], train_case_ids: Iterable[Activity]) -> MarkovBaseline: + """Fit a first-order Markov model on the training cases only.""" + keep = set(train_case_ids) + by_case: dict[Activity, list[tuple[Activity, object]]] = {} + for case_id, activity, ts in events: + if case_id not in keep: + continue + by_case.setdefault(case_id, []).append((activity, ts)) + + transitions: dict[Activity, Counter[Activity]] = defaultdict(Counter) + unigram: Counter[Activity] = Counter() + for rows in by_case.values(): + rows.sort(key=lambda r: r[1]) + activities = [a for a, _ in rows] + for a in activities: + unigram[a] += 1 + for prev, nxt in zip(activities, activities[1:], strict=False): + transitions[prev][nxt] += 1 + + return MarkovBaseline(transitions=dict(transitions), unigram=unigram) + + +def predict_markov(model: MarkovBaseline, prefixes: Iterable[Prefix]) -> list[Prediction]: + """Score each prefix with the Markov model.""" + out: list[Prediction] = [] + for p in prefixes: + last = p.prefix[-1] if p.prefix else None + ranked = model.rank(last) + out.append(Prediction(case_id=p.case_id, prefix_idx=p.prefix_idx, ranked=tuple(ranked))) + return out diff --git a/pm_bench/cli.py b/pm_bench/cli.py index bca4a92..3af96e0 100644 --- a/pm_bench/cli.py +++ b/pm_bench/cli.py @@ -7,10 +7,29 @@ import click from pm_bench import _synth +from pm_bench.baselines.markov import fit_markov, predict_markov +from pm_bench.predictions import read_predictions_csv, write_predictions_csv +from pm_bench.prefixes import extract_prefixes, read_prefixes_csv, write_prefixes_csv from pm_bench.registry import get_dataset, load_registry +from pm_bench.score import score_next_event from pm_bench.split import case_chrono_split +def _load_events(name: str) -> list: + """Return a materialized event list for a dataset. + + v0 supports `synthetic-toy` only; other datasets exit with a clear + instruction to wait for v0.1's fetch+cache machinery. + """ + if name != "synthetic-toy": + click.echo( + f"v0 only supports 'synthetic-toy' (got {name}); see README for the v0.1 milestone", + err=True, + ) + sys.exit(1) + return list(_synth.synthetic_log()) + + @click.group() @click.version_option() def main() -> None: @@ -60,13 +79,7 @@ def split(name: str, task: str) -> None: v0 supports `synthetic-toy` only; other datasets require manual fetch. """ - if name != "synthetic-toy": - click.echo( - f"v0 only supports 'synthetic-toy' (got {name}); see README for the v0.1 milestone", - err=True, - ) - sys.exit(1) - events = list(_synth.synthetic_log()) + events = _load_events(name) s = case_chrono_split(events) click.echo( json.dumps( @@ -83,5 +96,131 @@ def split(name: str, task: str) -> None: ) +@main.command() +@click.argument("name") +@click.option( + "--split", + "split_path", + type=click.Path(exists=True, dir_okay=False), + required=True, + help="Path to a split.json emitted by `pm-bench split`.", +) +@click.option( + "--out", + "out_path", + type=click.Path(dir_okay=False), + required=True, + help="Where to write the prefixes CSV.", +) +@click.option( + "--partition", + type=click.Choice(["test", "val", "train"]), + default="test", + show_default=True, + help="Which split partition to emit prefixes for. The leaderboard scores 'test'.", +) +def prefixes(name: str, split_path: str, out_path: str, partition: str) -> None: + """Emit prediction targets (prefix → true-next) for a partition. + + The output is the truth file scoring runs against. Submissions + write a predictions.csv with the same `(case_id, prefix_idx)` keys. + """ + events = _load_events(name) + with open(split_path) as f: + split_data = json.load(f) + case_ids = split_data[partition] + n = write_prefixes_csv(extract_prefixes(events, case_ids), out_path) + click.echo(f"wrote {n} prefixes to {out_path} (partition={partition})") + + +@main.command() +@click.argument("name") +@click.option( + "--split", + "split_path", + type=click.Path(exists=True, dir_okay=False), + required=True, +) +@click.option( + "--prefixes", + "prefixes_path", + type=click.Path(exists=True, dir_okay=False), + required=True, + help="Truth file emitted by `pm-bench prefixes`.", +) +@click.option( + "--out", + "out_path", + type=click.Path(dir_okay=False), + required=True, +) +@click.option( + "--baseline", + type=click.Choice(["markov"]), + default="markov", + show_default=True, +) +def predict( + name: str, + split_path: str, + prefixes_path: str, + out_path: str, + baseline: str, +) -> None: + """Run a reference baseline and emit predictions.csv.""" + events = _load_events(name) + with open(split_path) as f: + split_data = json.load(f) + if baseline != "markov": + # click already restricts the choice; this is a guard for the future. + raise click.UsageError(f"unknown baseline: {baseline}") + model = fit_markov(events, split_data["train"]) + targets = read_prefixes_csv(prefixes_path) + preds = predict_markov(model, targets) + n = write_predictions_csv(preds, out_path) + click.echo(f"wrote {n} predictions to {out_path} (baseline={baseline})") + + +@main.command() +@click.argument("predictions_path", type=click.Path(exists=True, dir_okay=False)) +@click.option( + "--prefixes", + "prefixes_path", + type=click.Path(exists=True, dir_okay=False), + required=True, + help="Truth file from `pm-bench prefixes`.", +) +@click.option("--task", default="next-event", show_default=True) +def score(predictions_path: str, prefixes_path: str, task: str) -> None: + """Score predictions against the truth file.""" + if task != "next-event": + click.echo(f"v0 only scores 'next-event' (got {task})", err=True) + sys.exit(1) + truth_rows = read_prefixes_csv(prefixes_path) + pred_rows = read_predictions_csv(predictions_path) + pred_lookup = {(p.case_id, p.prefix_idx): p.ranked for p in pred_rows} + missing = [ + (t.case_id, t.prefix_idx) + for t in truth_rows + if (t.case_id, t.prefix_idx) not in pred_lookup + ] + if missing: + click.echo( + f"predictions.csv is missing {len(missing)} target(s); " + f"first: {missing[0]}", + err=True, + ) + sys.exit(2) + ranked = [list(pred_lookup[(t.case_id, t.prefix_idx)]) for t in truth_rows] + truth = [t.true_next for t in truth_rows] + s = score_next_event(ranked, truth) + click.echo( + json.dumps( + {"task": task, "top1": s.top1, "top3": s.top3, "n": s.n}, + indent=2, + ), + ) + + if __name__ == "__main__": main() diff --git a/pm_bench/predictions.py b/pm_bench/predictions.py new file mode 100644 index 0000000..5b2b87b --- /dev/null +++ b/pm_bench/predictions.py @@ -0,0 +1,58 @@ +"""Predictions file format. + +A submission writes one row per prediction target, joined to the truth +file (prefixes.csv) on `(case_id, prefix_idx)`: + + case_id,prefix_idx,predictions + +`predictions` is a `|`-joined ranked list of candidate next activities, +best first. Top-1 is `predictions[0]`; top-3 is `predictions[:3]`. +""" +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass + +from pm_bench.prefixes import PREFIX_SEP +from pm_bench.split import Activity, CaseId + + +@dataclass(frozen=True) +class Prediction: + case_id: CaseId + prefix_idx: int + ranked: tuple[Activity, ...] + + +def write_predictions_csv(predictions: Iterable[Prediction], path: str) -> int: + """Write predictions to a CSV file. Returns the number of rows.""" + import csv + + n = 0 + with open(path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["case_id", "prefix_idx", "predictions"]) + for p in predictions: + w.writerow([p.case_id, p.prefix_idx, PREFIX_SEP.join(p.ranked)]) + n += 1 + return n + + +def read_predictions_csv(path: str) -> list[Prediction]: + """Read a predictions CSV.""" + import csv + + out: list[Prediction] = [] + with open(path, newline="") as f: + r = csv.DictReader(f) + for row in r: + ranked_str = row["predictions"] + ranked = tuple(ranked_str.split(PREFIX_SEP)) if ranked_str else () + out.append( + Prediction( + case_id=row["case_id"], + prefix_idx=int(row["prefix_idx"]), + ranked=ranked, + ) + ) + return out diff --git a/pm_bench/prefixes.py b/pm_bench/prefixes.py new file mode 100644 index 0000000..db8d31b --- /dev/null +++ b/pm_bench/prefixes.py @@ -0,0 +1,100 @@ +"""Prefix extraction — the bridge between split and score. + +For next-event prediction, every test case of length L generates L-1 +prediction targets: prefixes of length 1..L-1, each paired with the +activity that actually came next. Per the suffix-aware split rule, +prefixes are only ever drawn from test cases (never train/val), so the +score is honest about what the model saw at training time. + +The emitted format is a CSV with columns: + + case_id,prefix_idx,prefix,true_next + +where `prefix_idx` is the (1-based) length of the prefix and `prefix` +is `|`-joined activity names. This is the lingua franca file that +predictions are written against. +""" +from __future__ import annotations + +from collections.abc import Iterable, Iterator +from dataclasses import dataclass + +from pm_bench.split import Activity, CaseId, Event + +PREFIX_SEP = "|" + + +@dataclass(frozen=True) +class Prefix: + case_id: CaseId + prefix_idx: int + prefix: tuple[Activity, ...] + true_next: Activity + + +def extract_prefixes( + events: Iterable[Event], + case_ids: Iterable[CaseId], +) -> Iterator[Prefix]: + """Yield prediction targets for the given case ids. + + Events are grouped by `case_id` and ordered by timestamp; for each + case of length L, prefixes of length 1..L-1 are yielded together + with the activity that follows. Cases of length < 2 are skipped + silently (nothing to predict). + """ + keep = set(case_ids) + by_case: dict[CaseId, list[tuple[Activity, object]]] = {} + for case_id, activity, ts in events: + if case_id not in keep: + continue + by_case.setdefault(case_id, []).append((activity, ts)) + + for case_id in keep: + rows = by_case.get(case_id) + if not rows or len(rows) < 2: + continue + rows.sort(key=lambda r: r[1]) + activities = [a for a, _ in rows] + for k in range(1, len(activities)): + yield Prefix( + case_id=case_id, + prefix_idx=k, + prefix=tuple(activities[:k]), + true_next=activities[k], + ) + + +def write_prefixes_csv(prefixes: Iterable[Prefix], path: str) -> int: + """Write prefixes to a CSV file. Returns the number of rows.""" + import csv + + n = 0 + with open(path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(["case_id", "prefix_idx", "prefix", "true_next"]) + for p in prefixes: + w.writerow([p.case_id, p.prefix_idx, PREFIX_SEP.join(p.prefix), p.true_next]) + n += 1 + return n + + +def read_prefixes_csv(path: str) -> list[Prefix]: + """Read a prefixes CSV emitted by `write_prefixes_csv`.""" + import csv + + out: list[Prefix] = [] + with open(path, newline="") as f: + r = csv.DictReader(f) + for row in r: + prefix_str = row["prefix"] + prefix = tuple(prefix_str.split(PREFIX_SEP)) if prefix_str else () + out.append( + Prefix( + case_id=row["case_id"], + prefix_idx=int(row["prefix_idx"]), + prefix=prefix, + true_next=row["true_next"], + ) + ) + return out diff --git a/tests/test_baselines.py b/tests/test_baselines.py new file mode 100644 index 0000000..429b88a --- /dev/null +++ b/tests/test_baselines.py @@ -0,0 +1,40 @@ +import datetime as dt + +from pm_bench import Prefix +from pm_bench.baselines.markov import fit_markov, predict_markov + + +def _events() -> list[tuple[str, str, dt.datetime]]: + base = dt.datetime(2024, 1, 1) + return [ + # Train: c1, c2 — pattern "a→b" 2x, "b→c" 2x. + ("c1", "a", base), + ("c1", "b", base + dt.timedelta(hours=1)), + ("c1", "c", base + dt.timedelta(hours=2)), + ("c2", "a", base), + ("c2", "b", base + dt.timedelta(hours=1)), + ("c2", "c", base + dt.timedelta(hours=2)), + # Test: c3 — same shape. + ("c3", "a", base), + ("c3", "b", base + dt.timedelta(hours=1)), + ("c3", "c", base + dt.timedelta(hours=2)), + ] + + +def test_markov_top1_perfect_on_deterministic_chain() -> None: + model = fit_markov(_events(), ["c1", "c2"]) + targets = [ + Prefix(case_id="c3", prefix_idx=1, prefix=("a",), true_next="b"), + Prefix(case_id="c3", prefix_idx=2, prefix=("a", "b"), true_next="c"), + ] + preds = predict_markov(model, targets) + assert preds[0].ranked[0] == "b" + assert preds[1].ranked[0] == "c" + + +def test_markov_falls_back_to_unigram_for_unseen_last() -> None: + model = fit_markov(_events(), ["c1", "c2"]) + targets = [Prefix(case_id="c3", prefix_idx=1, prefix=("never_seen",), true_next="b")] + preds = predict_markov(model, targets) + # Unigram is non-empty and ranked; just assert we got *some* ranked list. + assert len(preds[0].ranked) > 0 diff --git a/tests/test_e2e.py b/tests/test_e2e.py new file mode 100644 index 0000000..e545390 --- /dev/null +++ b/tests/test_e2e.py @@ -0,0 +1,74 @@ +"""End-to-end smoke: split → prefixes → predict → score on synthetic-toy. + +Locks the file format the leaderboard depends on. If this test changes +shape, leaderboard submissions break — bump the version and announce. +""" +from __future__ import annotations + +import json + +from click.testing import CliRunner + +from pm_bench.cli import main + + +def test_full_pipeline_runs_and_scores(tmp_path) -> None: + runner = CliRunner() + split_path = tmp_path / "split.json" + prefixes_path = tmp_path / "prefixes.csv" + preds_path = tmp_path / "predictions.csv" + + r = runner.invoke(main, ["split", "synthetic-toy"]) + assert r.exit_code == 0, r.output + split_path.write_text(r.output) + + r = runner.invoke( + main, + [ + "prefixes", + "synthetic-toy", + "--split", + str(split_path), + "--out", + str(prefixes_path), + ], + ) + assert r.exit_code == 0, r.output + assert prefixes_path.exists() + + r = runner.invoke( + main, + [ + "predict", + "synthetic-toy", + "--split", + str(split_path), + "--prefixes", + str(prefixes_path), + "--out", + str(preds_path), + "--baseline", + "markov", + ], + ) + assert r.exit_code == 0, r.output + + r = runner.invoke( + main, + [ + "score", + str(preds_path), + "--prefixes", + str(prefixes_path), + "--task", + "next-event", + ], + ) + assert r.exit_code == 0, r.output + result = json.loads(r.output) + assert result["task"] == "next-event" + assert result["n"] > 0 + # Synthetic-toy has tight transitions; markov should clear 50% top-1. + assert result["top1"] >= 0.5 + assert 0.0 <= result["top1"] <= 1.0 + assert result["top3"] >= result["top1"] diff --git a/tests/test_prefixes.py b/tests/test_prefixes.py new file mode 100644 index 0000000..eb8039c --- /dev/null +++ b/tests/test_prefixes.py @@ -0,0 +1,57 @@ +import datetime as dt + +from pm_bench import ( + Prefix, + extract_prefixes, + read_prefixes_csv, + write_prefixes_csv, +) + + +def _events() -> list[tuple[str, str, dt.datetime]]: + base = dt.datetime(2024, 1, 1) + return [ + ("c1", "a", base), + ("c1", "b", base + dt.timedelta(hours=1)), + ("c1", "c", base + dt.timedelta(hours=2)), + ("c2", "x", base), + ("c2", "y", base + dt.timedelta(hours=1)), + ("c3", "solo", base), # length-1 case, gets skipped + ] + + +def test_extract_prefixes_yields_n_minus_1_per_case() -> None: + out = list(extract_prefixes(_events(), ["c1", "c2", "c3"])) + # c1 → 2 targets, c2 → 1 target, c3 (len 1) → 0 + assert len(out) == 3 + + +def test_extract_prefixes_respects_chronology() -> None: + base = dt.datetime(2024, 1, 1) + shuffled = [ + ("c1", "c", base + dt.timedelta(hours=2)), + ("c1", "a", base), + ("c1", "b", base + dt.timedelta(hours=1)), + ] + out = list(extract_prefixes(shuffled, ["c1"])) + assert out[0].prefix == ("a",) + assert out[0].true_next == "b" + assert out[1].prefix == ("a", "b") + assert out[1].true_next == "c" + + +def test_extract_prefixes_filters_to_kept_cases() -> None: + out = list(extract_prefixes(_events(), ["c1"])) + assert {p.case_id for p in out} == {"c1"} + + +def test_round_trip_csv(tmp_path) -> None: + prefixes = [ + Prefix(case_id="c1", prefix_idx=1, prefix=("a",), true_next="b"), + Prefix(case_id="c1", prefix_idx=2, prefix=("a", "b"), true_next="c"), + ] + path = tmp_path / "prefixes.csv" + n = write_prefixes_csv(prefixes, str(path)) + assert n == 2 + back = read_prefixes_csv(str(path)) + assert back == prefixes