diff --git a/pyproject.toml b/pyproject.toml index d7a96faa..e23decaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,9 @@ dependencies = [ ] [project.optional-dependencies] +optuna = [ + "optuna>=4.0.0", +] # TODO: update these dependencies to the correct versions colab_evals = [ # Core ML/AI Framework diff --git a/rapidfireai/automl/__init__.py b/rapidfireai/automl/__init__.py index 83ed81d8..3f1a0f3b 100644 --- a/rapidfireai/automl/__init__.py +++ b/rapidfireai/automl/__init__.py @@ -6,6 +6,27 @@ from .random_search import RFRandomSearch from .automl_utils import get_flattened_config_leaf, get_runs +# Optuna integration (conditionally available) +try: + from .optuna_search import RFOptuna + _OPTUNA_AVAILABLE = True +except ImportError as _optuna_import_error: + + class RFOptuna: # type: ignore[misc] + """Stub so imports succeed; instantiation explains how to enable Optuna.""" + + def __new__(cls, *args, **kwargs): # noqa: ARG004 + raise ImportError( + "RFOptuna requires Optuna importable from this Python environment. " + "Install into the **same interpreter as your Jupyter kernel**, then restart the kernel:\n" + " python -m pip install optuna\n" + "Check in a notebook cell: import sys; print(sys.executable)\n" + "Original error: " + + str(_optuna_import_error) + ) from _optuna_import_error + + _OPTUNA_AVAILABLE = False + # Import fit mode configs (conditionally available) try: from .model_config import ( @@ -60,6 +81,8 @@ "get_runs", ] +__all__.append("RFOptuna") + # Conditionally add fit mode configs to __all__ if _FIT_CONFIGS_AVAILABLE: __all__.extend([ diff --git a/rapidfireai/automl/base.py b/rapidfireai/automl/base.py index e9b18d9d..b35cf159 100644 --- a/rapidfireai/automl/base.py +++ b/rapidfireai/automl/base.py @@ -1,4 +1,11 @@ -"""Base classes and configurations for AutoML algorithms.""" +"""Base class for AutoML search algorithms. + +Classes +------- +AutoMLAlgorithm + Abstract base subclassed by ``RFGridSearch``, ``RFRandomSearch``, + and ``RFOptuna``. +""" from abc import ABC, abstractmethod from typing import Any @@ -8,7 +15,35 @@ class AutoMLAlgorithm(ABC): - """Base class for AutoML algorithms.""" + """Abstract base class for AutoML search strategies. + + Parameters + ---------- + configs : + Config templates (``RFModelConfig`` for fit, dicts for evals). + Accepts a list, a ``List([...])`` wrapper, or a single object. + create_model_fn : + Legacy parameter (unused). + trainer_type : str or None + ``"SFT"`` / ``"DPO"`` / ``"GRPO"`` for fit mode, ``None`` for evals. + num_runs : int + Number of samples (used by ``RFRandomSearch``). + + Attributes + ---------- + configs : list + mode : str + ``"fit"`` or ``"evals"``. + trainer_type : str or None + num_runs : int + + Methods + ------- + get_runs(seed) -> list[dict] + Return concrete config-leaf dicts. + get_callback(**kwargs) -> ChunkCallback | ShardCallback | None + Return an optional inter-step pruning callback. + """ VALID_TRAINER_TYPES = {"SFT", "DPO", "GRPO"} @@ -19,19 +54,6 @@ def __init__( trainer_type: str | None = None, num_runs: int = 1, ): - """ - Initialize AutoML algorithm with configurations and trainer type. 
- - Args: - configs: List of configurations (RFModelConfig for fit mode, dict for evals mode) - create_model_fn: Optional function to create models (legacy parameter) - trainer_type: Trainer type ("SFT", "DPO", "GRPO") for fit mode, None for evals mode - num_runs: Number of runs for random search - - Mode detection: - - If trainer_type is provided: fit mode (requires RFModelConfig instances) - - If trainer_type is None: evals mode (requires dict instances) - """ try: self.configs = self._normalize_configs(configs) self.num_runs = num_runs @@ -87,8 +109,27 @@ def _validate_configs(self): f"If you want fit mode, provide a trainer_type." ) + def get_callback(self, **kwargs): + """Return an optional callback for inter-chunk/shard pruning decisions. + + Returns + ------- + ChunkCallback or ShardCallback or None + """ + return None + @abstractmethod def get_runs(self, seed: int) -> list[dict[str, Any]]: - """Generate hyperparameter combinations for different training configurations.""" + """Return concrete config-leaf dicts for the controller. + + Parameters + ---------- + seed : int + Non-negative random seed. + + Returns + ------- + list[dict[str, Any]] + """ if not isinstance(seed, int) or seed < 0: raise AutoMLException("seed must be a non-negative integer") diff --git a/rapidfireai/automl/callbacks.py b/rapidfireai/automl/callbacks.py new file mode 100644 index 00000000..4ff9a32f --- /dev/null +++ b/rapidfireai/automl/callbacks.py @@ -0,0 +1,118 @@ +"""Callback protocols for inter-chunk/shard decision-making during experiments. + +Classes +------- +RunDecision + Dataclass returned by ``ChunkCallback.on_chunk_complete`` (fit mode). +PipelineDecision + Dataclass returned by ``ShardCallback.on_shard_complete`` (evals mode). +ChunkCallback + Protocol for fit-mode inter-chunk pruning callbacks. +ShardCallback + Protocol for evals-mode inter-shard pruning callbacks. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal, Protocol + + +@dataclass +class RunDecision: + """Decision returned by a ``ChunkCallback`` after a fit-mode chunk completes. + + Attributes + ---------- + action : ``"continue"`` or ``"prune"`` + replacement_config : dict or None + Config-leaf dict for a replacement run, or ``None``. + """ + + action: Literal["continue", "prune"] + replacement_config: dict[str, Any] | None = None + + +@dataclass +class PipelineDecision: + """Decision returned by a ``ShardCallback`` after an evals-mode shard completes. + + Attributes + ---------- + action : ``"continue"`` or ``"prune"`` + replacement_config : dict or None + Config-leaf dict for a replacement pipeline, or ``None``. + """ + + action: Literal["continue", "prune"] + replacement_config: dict[str, Any] | None = None + + +class ChunkCallback(Protocol): + """Protocol for callbacks invoked after each chunk in fit mode. + + Call order: ``register_runs`` → ``on_chunk_complete`` (repeated) → ``finalize``. + """ + + def register_runs(self, run_id_to_config: dict[int, dict[str, Any]]) -> None: + """Map newly created DB run IDs to their config dicts.""" + ... + + def on_chunk_complete( + self, + run_id: int, + chunk_id: int, + metrics: dict[str, Any], + ) -> RunDecision: + """Evaluate a run after it finishes a chunk. + + Parameters + ---------- + run_id : int + chunk_id : int + metrics : dict[str, Any] + + Returns + ------- + RunDecision + """ + ... + + def finalize(self, final_metrics: dict[int, dict[str, Any]]) -> None: + """Called after the experiment loop ends.""" + ... 
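# --- Illustrative sketch (editorial, not part of this patch) ---
# A minimal ChunkCallback implementation showing the protocol shape. It assumes
# the controller passes flat scalar metrics such as {"eval_loss": 0.42}; real
# runs may instead log MLflow-style (step, value) histories, which the Optuna
# callbacks in optuna_search.py handle. The class name and threshold are
# hypothetical. (RunDecision is defined above; outside this module, import it
# from rapidfireai.automl.callbacks.)
class ThresholdChunkCallback:
    """Prune any run whose eval_loss exceeds a fixed cap after the first chunk."""

    def __init__(self, max_eval_loss: float = 1.5):
        self.max_eval_loss = max_eval_loss
        self._configs: dict[int, dict] = {}

    def register_runs(self, run_id_to_config):
        # Remember which config belongs to which DB run ID.
        self._configs.update(run_id_to_config)

    def on_chunk_complete(self, run_id, chunk_id, metrics):
        loss = metrics.get("eval_loss")
        if chunk_id > 0 and isinstance(loss, (int, float)) and loss > self.max_eval_loss:
            return RunDecision(action="prune")  # no replacement suggested
        return RunDecision(action="continue")

    def finalize(self, final_metrics):
        pass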
+ + +class ShardCallback(Protocol): + """Protocol for callbacks invoked after each shard in evals mode. + + Call order: ``register_pipelines`` → ``on_shard_complete`` (repeated) → ``finalize``. + """ + + def register_pipelines(self, pipeline_id_to_config: dict[int, dict[str, Any]]) -> None: + """Map newly created DB pipeline IDs to their config dicts.""" + ... + + def on_shard_complete( + self, + pipeline_id: int, + shard_id: int, + metrics: dict[str, Any], + ) -> PipelineDecision: + """Evaluate a pipeline after it finishes a shard. + + Parameters + ---------- + pipeline_id : int + shard_id : int + metrics : dict[str, Any] + + Returns + ------- + PipelineDecision + """ + ... + + def finalize(self, final_metrics: dict[int, dict[str, Any]]) -> None: + """Called after the experiment loop ends.""" + ... diff --git a/rapidfireai/automl/datatypes.py b/rapidfireai/automl/datatypes.py index 1ffa6de6..3c31edd5 100644 --- a/rapidfireai/automl/datatypes.py +++ b/rapidfireai/automl/datatypes.py @@ -1,16 +1,46 @@ -"""Contains classes for representing hyperparameter data types.""" +"""Contains classes for representing hyperparameter data types. -import random +Covers all Optuna distribution types: + +- ``Range(start, end, dtype="float")`` → ``FloatDistribution`` / ``suggest_float`` +- ``Range(start, end, dtype="float", log=True)`` → log-uniform float +- ``Range(start, end, dtype="float", step=0.1)`` → discrete float +- ``Range(start, end, dtype="int")`` → ``IntDistribution`` / ``suggest_int`` +- ``Range(start, end, dtype="int", log=True)`` → log-uniform int +- ``Range(start, end, dtype="int", step=2)`` → stepped int +- ``List([...])`` → ``CategoricalDistribution`` / ``suggest_categorical`` +""" -# TODO: need to set seed for random module. -# TODO: List.sample() will not work for nested lists. -# TODO: add support for sampling methods like 'uniform' and 'loguniform'. +import math +import random class Range: - """Represents a range of values for a hyperparameter.""" + """Represents a range of values for a hyperparameter. + + Supports uniform, log-uniform, and discrete (stepped) sampling for both + int and float dtypes — matching all variants of Optuna's + ``IntDistribution`` and ``FloatDistribution``. + + Args: + start: Lower bound (inclusive). + end: Upper bound (inclusive). + dtype: ``"int"`` or ``"float"``. Inferred from *start*/*end* types + when not provided. + log: If ``True``, sample in log-space (start and end must be > 0). + Mutually exclusive with *step*. + step: Discretisation step. When set, sampled values are multiples of + *step* starting from *start*. Mutually exclusive with *log*. + """ - def __init__(self, start, end, dtype: str | None = None): + def __init__( + self, + start, + end, + dtype: str | None = None, + log: bool = False, + step: int | float | None = None, + ): if dtype is None: self.dtype = ( "int" if isinstance(start, int) and isinstance(end, int) else "float" @@ -21,13 +51,44 @@ def __init__(self, start, end, dtype: str | None = None): self.dtype = dtype if not (isinstance(start, int | float) and isinstance(end, int | float)): raise ValueError("start and end must be either int or float.") + if log and step is not None: + raise ValueError( + "log=True and step are mutually exclusive " + "(Optuna does not support this combination either)." + ) + if log and (start <= 0 or end <= 0): + raise ValueError( + "log=True requires both start and end to be > 0." 
+ ) self.start = start self.end = end + self.log = log + self.step = step def sample(self): - """Sample a value from the range [self.start, self.end].""" + """Sample a value from the range [self.start, self.end]. + + Respects *log* (log-uniform) and *step* (discrete) settings so that + ``RFRandomSearch`` produces the same family of distributions as + Optuna's ``suggest_int`` / ``suggest_float``. + """ if self.dtype == "int": + if self.log: + log_low, log_high = math.log(self.start), math.log(self.end) + return int(round(math.exp(random.uniform(log_low, log_high)))) + if self.step is not None: + step = int(self.step) + n_steps = (self.end - self.start) // step + return self.start + random.randint(0, n_steps) * step return random.randint(self.start, self.end) + + # dtype == "float" + if self.log: + log_low, log_high = math.log(self.start), math.log(self.end) + return math.exp(random.uniform(log_low, log_high)) + if self.step is not None: + n_steps = int((self.end - self.start) / self.step) + return self.start + random.randint(0, n_steps) * self.step return random.uniform(self.start, self.end) diff --git a/rapidfireai/automl/optuna_search.py b/rapidfireai/automl/optuna_search.py new file mode 100644 index 00000000..9805b6a9 --- /dev/null +++ b/rapidfireai/automl/optuna_search.py @@ -0,0 +1,1309 @@ +"""Optuna-based hyperparameter optimization integrated with RapidFire's chunk/shard loop. + +Classes +------- +RFOptuna + User-facing ``AutoMLAlgorithm`` subclass. Drop-in replacement for + ``RFGridSearch`` / ``RFRandomSearch``. +OptunaChunkCallback + ``ChunkCallback`` implementation for fit mode — prunes/replaces runs + between training chunks. +OptunaShardCallback + ``ShardCallback`` implementation for evals mode — prunes/replaces + pipelines between evaluation shards. + +Helper functions handle search-space extraction, Optuna trial sampling, +config-leaf expansion, and metric resolution. +""" + +from __future__ import annotations + +import copy +import math +import statistics +import uuid +from dataclasses import fields, is_dataclass +from typing import Any + +import optuna + +from rapidfireai.automl.base import AutoMLAlgorithm +from rapidfireai.automl.callbacks import ( + ChunkCallback, + PipelineDecision, + RunDecision, + ShardCallback, +) +from rapidfireai.automl.datatypes import List, Range +from rapidfireai.fit.utils.exceptions import AutoMLException + +# --------------------------------------------------------------------------- +# Optuna Trial helpers (API compatibility across Optuna versions) +# --------------------------------------------------------------------------- + + +def _trial_state_from_storage(study: optuna.Study, trial: optuna.Trial) -> optuna.trial.TrialState: + """Return the stored state for *trial*. + + ``Trial`` instances returned by :meth:`~optuna.study.Study.ask` do not always + expose a ``state`` attribute (e.g. recent Optuna releases); use frozen trials + from the study storage instead. + """ + for frozen in study.get_trials(deepcopy=False): + if frozen.number == trial.number: + return frozen.state + raise AutoMLException( + f"Could not resolve Optuna trial state for trial number {trial.number}" + ) + + +# When the primary objective (e.g. eval_loss) is never logged — common on tiny +# runs where eval may not fire — try common Trainer / MLflow key aliases. 
+_OBJECTIVE_ALIAS_KEYS: dict[str, tuple[str, ...]] = { + "eval_loss": ("eval/loss", "eval-loss", "validation_loss", "train_loss", "loss"), +} + + +def _ordered_objective_keys(primary: str) -> tuple[str, ...]: + keys = [primary] + seen = {primary} + for alias in _OBJECTIVE_ALIAS_KEYS.get(primary, ()): + if alias not in seen: + seen.add(alias) + keys.append(alias) + return tuple(keys) + + +def _float_from_logged_metric_value(raw: Any) -> float | None: + """Parse a scalar from MLflow-style history or a plain numeric. Returns ``None`` on failure.""" + if raw is None: + return None + if isinstance(raw, list) and raw: + last = raw[-1] + if isinstance(last, (list, tuple)) and len(last) >= 2: + return float(last[1]) + if isinstance(last, dict) and "value" in last: + return float(last["value"]) + if isinstance(last, (int, float)): + return float(last) + return None + if isinstance(raw, dict) and "value" in raw: + return float(raw["value"]) + if isinstance(raw, (int, float)): + return float(raw) + return None + + +def _resolve_scalar_for_objective(metrics: dict[str, Any], objective_metric: str) -> float | None: + """Return a scalar for *objective_metric*, trying known aliases as fallbacks.""" + for key in _ordered_objective_keys(objective_metric): + val = _float_from_logged_metric_value(metrics.get(key)) + if val is not None: + return val + return None + + +def _resolve_metric_history(metrics: dict[str, Any], objective_metric: str) -> list[tuple[int, float]]: + """Return the full ``(step, value)`` history for the objective metric. + + Tries the primary key first, then known aliases. Returns an empty list + when no history is available. Handles MLflow-style ``[(step, value), ...]`` + lists, plain numeric scalars, and bare lists of numbers. + """ + for key in _ordered_objective_keys(objective_metric): + raw = metrics.get(key) + if raw is None: + continue + if isinstance(raw, list) and raw: + history: list[tuple[int, float]] = [] + for entry in raw: + if isinstance(entry, (list, tuple)) and len(entry) >= 2: + history.append((int(entry[0]), float(entry[1]))) + elif isinstance(entry, (int, float)): + history.append((len(history), float(entry))) + if history: + return sorted(history, key=lambda x: x[0]) + if isinstance(raw, (int, float)): + return [(0, float(raw))] + return [] + + +# --------------------------------------------------------------------------- +# Multi-objective helpers +# --------------------------------------------------------------------------- + + +def _pareto_dominates(a: list[float], b: list[float], directions: list[str]) -> bool: + """Return True if solution *a* Pareto-dominates solution *b*. + + *a* dominates *b* when it is at least as good in every objective and + strictly better in at least one. + """ + strictly_better = False + for va, vb, d in zip(a, b, directions): + if d == "minimize": + if va > vb: + return False + if va < vb: + strictly_better = True + else: + if va < vb: + return False + if va > vb: + strictly_better = True + return strictly_better + + +def _resolve_multi_objectives( + metrics: dict[str, Any], + objective_metrics: list[str], +) -> list[float] | None: + """Resolve a value for each objective metric. 
Returns ``None`` if any is missing.""" + values: list[float] = [] + for metric in objective_metrics: + v = _resolve_scalar_for_objective(metrics, metric) + if v is None: + return None + values.append(v) + return values + + +# --------------------------------------------------------------------------- +# Search-space extraction and sampling +# --------------------------------------------------------------------------- + + +def _extract_search_space( + obj: Any, + prefix: str = "", +) -> list[tuple[str, Range | List]]: + """Walk a config template and collect all Range/List parameters. + + Returns a flat list of ``(dotted_path, Range_or_List)`` tuples. The + traversal mirrors ``recursive_expand_gridsearch`` so the same config + structures that work with ``RFGridSearch`` / ``RFRandomSearch`` also work + here (including ``RFModelConfig`` dataclass templates with nested + ``peft_config`` / ``training_args`` objects). + """ + params: list[tuple[str, Range | List]] = [] + + if isinstance(obj, (Range, List)): + params.append((prefix, obj)) + elif hasattr(obj, "_user_params"): + params.extend(_extract_search_space(obj._user_params, prefix)) + elif isinstance(obj, dict): + for key, value in obj.items(): + child_prefix = f"{prefix}.{key}" if prefix else key + params.extend(_extract_search_space(value, child_prefix)) + elif is_dataclass(obj) and not isinstance(obj, type): + # RFModelConfig and other templates are dataclasses without _user_params; + # nested Range/List live under peft_config / training_args / dict fields. + for f in fields(obj): + value = getattr(obj, f.name) + child_prefix = f"{prefix}.{f.name}" if prefix else f.name + params.extend(_extract_search_space(value, child_prefix)) + # Primitive or non-searchable -- skip + return params + + +_PRIMITIVE_TYPES = (type(None), bool, int, float, str) + + +def _object_labels(objects: list[Any]) -> list[str]: + """Build concise labels showing only the attributes that differ across *objects*. + + For example, two ``RecursiveCharacterTextSplitter`` instances that only + differ in ``chunk_size`` produce:: + + ["RecursiveCharacterTextSplitter(chunk_size=256)", + "RecursiveCharacterTextSplitter(chunk_size=128)"] + + Shared defaults (``keep_separator``, ``strip_whitespace``, etc.) are omitted + so the labels stay short and meaningful in Optuna trial output. + """ + per_obj: list[tuple[str, dict[str, Any]]] = [] + for obj in objects: + attrs = {} + for key, val in sorted(vars(obj).items()): + if isinstance(val, _PRIMITIVE_TYPES): + attrs[key.lstrip("_")] = val + per_obj.append((type(obj).__name__, attrs)) + + all_keys: set[str] = set() + for _, attrs in per_obj: + all_keys.update(attrs) + + varying = { + k for k in all_keys + if len({attrs.get(k) for _, attrs in per_obj}) > 1 + } + if not varying: + varying = all_keys + + labels: list[str] = [] + for cls_name, attrs in per_obj: + parts = [f"{k}={attrs[k]!r}" for k in sorted(varying) if k in attrs] + labels.append(f"{cls_name}({', '.join(parts)})" if parts else repr(objects[len(labels)])) + return labels + + +def _suggest_value(trial: optuna.Trial, name: str, param: Range | List) -> Any: + """Use an Optuna trial to sample a single value for *param*. + + Maps ``Range`` → ``suggest_int`` / ``suggest_float`` and + ``List`` → ``suggest_categorical``. 
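    For illustration (parameter names here are hypothetical; ``trial`` can come
    from ``study.ask()`` on any Optuna study)::

        lr    = _suggest_value(trial, "learning_rate", Range(1e-5, 1e-3, log=True))
        r     = _suggest_value(trial, "lora_r", Range(4, 64, dtype="int", step=4))
        sched = _suggest_value(trial, "lr_scheduler_type", List(["cosine", "linear"]))

    These correspond to ``suggest_float(..., log=True)``, ``suggest_int(..., step=4)``,
    and ``suggest_categorical`` respectively.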
+ """ + if isinstance(param, Range): + if param.dtype == "int": + kwargs: dict[str, Any] = {} + if param.step is not None: + kwargs["step"] = int(param.step) + if param.log: + kwargs["log"] = True + return trial.suggest_int(name, int(param.start), int(param.end), **kwargs) + else: + kwargs = {} + if param.step is not None: + kwargs["step"] = float(param.step) + if param.log: + kwargs["log"] = True + return trial.suggest_float(name, float(param.start), float(param.end), **kwargs) + elif isinstance(param, List): + if all(isinstance(v, _PRIMITIVE_TYPES) for v in param.values): + return trial.suggest_categorical(name, param.values) + if all(isinstance(v, (list, tuple, dict)) for v in param.values): + labels = [repr(v) for v in param.values] + else: + labels = _object_labels(param.values) + if len(set(labels)) < len(labels): + labels = [f"{lbl}#{i}" for i, lbl in enumerate(labels)] + chosen = trial.suggest_categorical(name, labels) + return param.values[labels.index(chosen)] + raise AutoMLException(f"Unsupported search-space type: {type(param)}") + + +def _set_nested(obj: Any, dotted_path: str, value: Any) -> None: + """Set a value inside a nested dict / ``_user_params`` object by dotted path.""" + parts = dotted_path.split(".") + for part in parts[:-1]: + if hasattr(obj, "_user_params"): + obj = obj._user_params + if isinstance(obj, dict): + obj = obj[part] + else: + obj = getattr(obj, part) + + last = parts[-1] + if hasattr(obj, "_user_params"): + obj = obj._user_params + if isinstance(obj, dict): + obj[last] = value + else: + setattr(obj, last, value) + + +def _sample_from_trial( + trial: optuna.Trial, + search_space: list[tuple[str, Range | List]], + config_template: Any, + param_prefix: str = "", +) -> Any: + """Deep-copy *config_template* and replace each Range/List with a sampled value. + + *param_prefix* is prepended to Optuna parameter names (used for multi-template + namespacing so identically-named params in different templates stay distinct). + """ + config = copy.deepcopy(config_template) + for dotted_path, param in search_space: + optuna_name = f"{param_prefix}{dotted_path}" if param_prefix else dotted_path + value = _suggest_value(trial, optuna_name, param) + _set_nested(config, dotted_path, value) + return config + + +def _sample_from_trial_multi( + trial: optuna.Trial, + config_templates: list[Any], + search_spaces: list[list[tuple[str, Range | List]]], +) -> Any: + """Pick a template via Optuna categorical (if >1), then sample its search space. + + Single-template case is identical to ``_sample_from_trial`` (no extra + categorical, no parameter prefix) for full backward compatibility. 
+ """ + if len(config_templates) == 1: + return _sample_from_trial(trial, search_spaces[0], config_templates[0]) + + tidx = trial.suggest_categorical( + "_config_template_idx", list(range(len(config_templates))), + ) + return _sample_from_trial( + trial, + search_spaces[tidx], + config_templates[tidx], + param_prefix=f"_t{tidx}.", + ) + + +# --------------------------------------------------------------------------- +# Helpers to expand a sampled config template into a config leaf +# (mirrors the expansion in grid_search / random_search) +# --------------------------------------------------------------------------- + + +def _template_to_leaf_fit(config_obj: Any, trainer_type: str) -> dict[str, Any]: + """Convert a sampled ``RFModelConfig`` into a flat config-leaf dict for the controller.""" + from rapidfireai.automl.random_search import recursive_expand_randomsearch + + peft_params = ( + {} + if config_obj.peft_config is None + else recursive_expand_randomsearch(config_obj.peft_config._user_params) + ) + training_params = ( + {} + if config_obj.training_args is None + else recursive_expand_randomsearch(config_obj.training_args._user_params) + ) + model_kwargs = ( + {} + if config_obj.model_kwargs is None + else recursive_expand_randomsearch(config_obj.model_kwargs) + ) + ref_model_kwargs = ( + {} + if config_obj.ref_model_kwargs is None + else recursive_expand_randomsearch(config_obj.ref_model_kwargs) + ) + reward_funcs = ( + {} + if config_obj.reward_funcs is None + else recursive_expand_randomsearch(config_obj.reward_funcs) + ) + + excluded_attrs = { + "model_name", + "tokenizer", + "tokenizer_kwargs", + "model_type", + "model_kwargs", + "peft_config", + "training_args", + "ref_model_name", + "ref_model_type", + "ref_model_kwargs", + "reward_funcs", + "num_gpus", + } + additional_kwargs = { + k: v + for k, v in config_obj.__dict__.items() + if k not in excluded_attrs and v is not None + } + + leaf: dict[str, Any] = { + "trainer_type": trainer_type, + "training_args": training_params, + "peft_params": peft_params, + "model_name": config_obj.model_name, + "tokenizer": config_obj.tokenizer, + "tokenizer_kwargs": config_obj.tokenizer_kwargs, + "model_type": config_obj.model_type, + "model_kwargs": model_kwargs, + "additional_kwargs": additional_kwargs, + } + num_gpus = getattr(config_obj, "num_gpus", None) + if num_gpus is not None: + leaf["num_gpus"] = num_gpus + + if trainer_type == "DPO": + leaf["ref_model_config"] = { + "model_name": config_obj.ref_model_name, + "model_type": config_obj.ref_model_type, + "model_kwargs": ref_model_kwargs, + } + elif trainer_type == "GRPO": + leaf["reward_funcs"] = reward_funcs + + return leaf + + +def _template_to_leaf_evals(config_dict: dict[str, Any]) -> dict[str, Any]: + """Convert a sampled evals config dict into a config-leaf dict for the controller.""" + from rapidfireai.automl.random_search import recursive_expand_randomsearch + + pipeline_key = None + for key in ("pipeline", "vllm_config", "openai_config", "gemini_config"): + if key in config_dict: + pipeline_key = key + break + + if pipeline_key is None: + return config_dict + + pipeline = config_dict[pipeline_key] + pipeline_instance = recursive_expand_randomsearch(pipeline) + + additional = { + k: recursive_expand_randomsearch(v) + for k, v in config_dict.items() + if k not in {"pipeline", "vllm_config", "openai_config", "gemini_config"} + and v is not None + } + + return {"pipeline": pipeline_instance, **additional} + + +# 
--------------------------------------------------------------------------- +# Sampler / pruner factories +# --------------------------------------------------------------------------- + +_SAMPLERS: dict[str, Any] = { + "tpe": lambda seed: optuna.samplers.TPESampler(seed=seed), + "cmaes": lambda seed: optuna.samplers.CmaEsSampler(seed=seed), + "random": lambda seed: optuna.samplers.RandomSampler(seed=seed), +} + +_PRUNERS: dict[str, Any] = { + "median": lambda n_startup: optuna.pruners.MedianPruner(n_startup_trials=n_startup), + "hyperband": lambda n_startup: optuna.pruners.HyperbandPruner(), +} + + +# --------------------------------------------------------------------------- +# Optuna callback implementations +# --------------------------------------------------------------------------- + + +class OptunaChunkCallback: + """``ChunkCallback`` implementation for Optuna-based pruning in fit mode. + + Created by :meth:`RFOptuna.get_callback`. After each training chunk the + controller calls ``on_chunk_complete`` which reports metrics to Optuna + and returns a ``RunDecision`` (continue / prune with optional replacement). + + Parameters + ---------- + study : optuna.Study + search_spaces : list[list[tuple[str, Range | List]]] + Per-template search spaces. + config_templates : list[Any] + Original ``RFModelConfig`` template objects. + trainer_type : str + ``"SFT"`` / ``"DPO"`` / ``"GRPO"``. + budget : int + Max total trials (initial + replacements). + objective_metric : str + Primary metric key (e.g. ``"eval_loss"``). + granularity : str + ``"chunk"`` or ``"epoch"``. + num_chunks : int or None + Total chunks per epoch; required when ``granularity="epoch"``. + objective_metrics : list[str] or None + All metric keys (multi-objective). + directions : list[str] or None + ``"minimize"`` / ``"maximize"`` per metric. + + Methods + ------- + on_chunk_complete(run_id, chunk_id, metrics) -> RunDecision + Evaluate a run after a chunk. + finalize(final_metrics) + Tell remaining RUNNING trials their final objective values. + _remap_pending_trial(db_run_id) + Swap a placeholder key with the real DB run ID after replacement. 
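    Examples
    --------
    Illustrative controller-side flow. ``search`` is an ``RFOptuna`` instance,
    ``run_ids`` / ``new_db_run_id`` are DB identifiers supplied by the
    controller, and the metric values are made up::

        leaves = search.get_runs(seed=42)
        cb = search.get_callback(num_chunks=4)      # returns this class in fit mode
        search.bind_initial_trials(run_ids)         # same order as ``leaves``

        decision = cb.on_chunk_complete(run_ids[0], chunk_id=0,
                                        metrics={"eval_loss": [(50, 1.32)]})
        if decision.action == "prune" and decision.replacement_config is not None:
            # register the replacement run in the DB, then bind its new ID
            cb._remap_pending_trial(new_db_run_id)
        cb.finalize({rid: {"eval_loss": 0.91} for rid in run_ids})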
+ """ + + def __init__( + self, + study: optuna.Study, + search_spaces: list[list[tuple[str, Range | List]]], + config_templates: list[Any], + trainer_type: str, + budget: int, + objective_metric: str, + granularity: str = "chunk", + num_chunks: int | None = None, + *, + objective_metrics: list[str] | None = None, + directions: list[str] | None = None, + ): + if granularity not in ("chunk", "epoch"): + raise AutoMLException( + f"granularity must be 'chunk' or 'epoch', got '{granularity}'" + ) + if granularity == "epoch" and (num_chunks is None or num_chunks < 1): + raise AutoMLException( + "num_chunks must be a positive integer when granularity='epoch'" + ) + + self._study = study + self._search_spaces = search_spaces + self._config_templates = config_templates + self._trainer_type = trainer_type + self._budget = budget + self._objective_metric = objective_metric + self._objective_metrics = objective_metrics or [objective_metric] + self._directions = directions or ["minimize"] + self._is_multi_objective = len(self._objective_metrics) > 1 + self._granularity = granularity + self._num_chunks = num_chunks + self._trials: dict[int, optuna.trial.Trial] = {} + self._spawned = 0 + self._last_reported_step: dict[int, int] = {} + self._chunks_since_last_eval: dict[int, int] = {} + self._multi_intermediates: dict[int, dict[int, list[float]]] = {} + self._pruned_run_ids: set[int] = set() + + # -- bookkeeping kept by RFOptuna before handing off -- + + def _set_initial_trials(self, trial_map: dict[int, optuna.trial.Trial], spawned: int) -> None: + """Populate the ``run_id → trial`` mapping and set the spawned count.""" + self._trials.update(trial_map) + self._spawned = spawned + + # -- ChunkCallback protocol -- + + def register_runs(self, run_id_to_config: dict[int, dict[str, Any]]) -> None: + """No-op — initial mapping is handled via ``_set_initial_trials``.""" + pass + + def on_chunk_complete( + self, + run_id: int, + chunk_id: int, + metrics: dict[str, Any], + ) -> RunDecision: + """Evaluate a run after a training chunk. + + Parameters + ---------- + run_id : int + DB run identifier. + chunk_id : int + Zero-based chunk index. + metrics : dict[str, Any] + Metric values (flat scalars, MLflow step histories, or + dict-wrapped values). 
+ + Returns + ------- + RunDecision + """ + trial = self._trials.get(run_id) + if trial is None: + return RunDecision(action="continue") + + if self._is_multi_objective: + return self._on_chunk_complete_multi(run_id, chunk_id, metrics, trial) + + history = _resolve_metric_history(metrics, self._objective_metric) + if not history: + return RunDecision(action="continue") + + last_reported = self._last_reported_step.get(run_id, -1) + for step, value in history: + if step > last_reported: + trial.report(value, step=step) + self._last_reported_step[run_id] = step + + if self._granularity == "epoch": + self._chunks_since_last_eval[run_id] = ( + self._chunks_since_last_eval.get(run_id, 0) + 1 + ) + if self._chunks_since_last_eval[run_id] < self._num_chunks: + return RunDecision(action="continue") + self._chunks_since_last_eval[run_id] = 0 + + if trial.should_prune() or self._should_prune_concurrent(trial): + self._study.tell(trial, state=optuna.trial.TrialState.PRUNED) + replacement = self._maybe_suggest_replacement() + return RunDecision(action="prune", replacement_config=replacement) + + return RunDecision(action="continue") + + def _on_chunk_complete_multi( + self, + run_id: int, + chunk_id: int, + metrics: dict[str, Any], + trial: optuna.Trial, + ) -> RunDecision: + """Multi-objective variant of on_chunk_complete. + + Optuna's built-in pruners and ``trial.report()`` don't support + multi-objective studies, so we track intermediate values ourselves + and use Pareto-dominance-based pruning. + """ + values = _resolve_multi_objectives(metrics, self._objective_metrics) + if values is None: + return RunDecision(action="continue") + + intermediates = self._multi_intermediates.setdefault(run_id, {}) + intermediates[chunk_id] = values + + if self._granularity == "epoch": + self._chunks_since_last_eval[run_id] = ( + self._chunks_since_last_eval.get(run_id, 0) + 1 + ) + if self._chunks_since_last_eval[run_id] < self._num_chunks: + return RunDecision(action="continue") + self._chunks_since_last_eval[run_id] = 0 + + if self._should_prune_pareto(run_id, chunk_id): + self._pruned_run_ids.add(run_id) + self._study.tell(trial, state=optuna.trial.TrialState.PRUNED) + replacement = self._maybe_suggest_replacement() + return RunDecision(action="prune", replacement_config=replacement) + + return RunDecision(action="continue") + + def finalize(self, final_metrics: dict[int, dict[str, Any]]) -> None: + """Tell all remaining RUNNING trials their final objective values. + + Parameters + ---------- + final_metrics : dict[int, dict[str, Any]] + ``run_id → final metrics dict``. + """ + for run_id, trial in self._trials.items(): + if not isinstance(run_id, int): + continue + if _trial_state_from_storage(self._study, trial) == optuna.trial.TrialState.RUNNING: + run_metrics = final_metrics.get(run_id, {}) + if self._is_multi_objective: + values = _resolve_multi_objectives(run_metrics, self._objective_metrics) + if values is not None: + self._study.tell(trial, values=values) + else: + self._study.tell(trial, state=optuna.trial.TrialState.FAIL) + else: + value = self._resolve_metric(run_metrics) + if value is not None: + self._study.tell(trial, values=value) + else: + self._study.tell(trial, state=optuna.trial.TrialState.FAIL) + + # -- internals -- + + def _should_prune_pareto(self, run_id: int, step: int) -> bool: + """Pareto-dominance pruning for multi-objective studies. 
+ + A run is pruned if it is Pareto-dominated by more than half the + *active* (non-pruned) peers at the current step — analogous to + single-objective median pruning. Already-pruned runs are excluded + so their ghost values don't block every subsequent trial. + """ + current_vals = self._multi_intermediates.get(run_id, {}).get(step) + if current_vals is None: + return False + + dominating_peers = 0 + total_peers = 0 + for other_id, other_steps in self._multi_intermediates.items(): + if other_id == run_id: + continue + if other_id in self._pruned_run_ids: + continue + if step not in other_steps: + continue + total_peers += 1 + if _pareto_dominates(other_steps[step], current_vals, self._directions): + dominating_peers += 1 + + if total_peers == 0: + return False + return dominating_peers > total_peers / 2 + + def _should_prune_concurrent(self, trial: optuna.Trial) -> bool: + """Concurrent-aware pruning that compares intermediate values across + ALL trials (RUNNING + COMPLETE). + + Optuna's built-in pruners (MedianPruner, etc.) only compare against + COMPLETE trials, but in RapidFire's concurrent chunk loop every trial + stays RUNNING until ``finalize()``, so the built-in pruner never has + reference data. This method supplements ``trial.should_prune()`` by + checking intermediate values from all peers regardless of state. + """ + all_frozen = self._study.get_trials(deepcopy=False) + + current = None + for ft in all_frozen: + if ft.number == trial.number: + current = ft + break + if current is None or not current.intermediate_values: + return False + + last_step = max(current.intermediate_values.keys()) + values = [v for v in current.intermediate_values.values() if not math.isnan(v)] + if not values: + return False + + minimize = self._study.direction == optuna.study.StudyDirection.MINIMIZE + best_current = min(values) if minimize else max(values) + + peer_values = [] + for ft in all_frozen: + if ft.number == trial.number: + continue + if last_step in ft.intermediate_values: + v = ft.intermediate_values[last_step] + if not math.isnan(v): + peer_values.append(v) + + if not peer_values: + return False + + median_val = statistics.median(peer_values) + if minimize: + return best_current > median_val + return best_current < median_val + + def _resolve_metric(self, metrics: dict[str, Any]) -> float | None: + """Extract the objective metric value from a metrics dict. + + Supports both flat dicts (``{"eval_loss": 0.5}``) and MLflow-style + histories (``{"eval_loss": [(step, value), ...]}``) by taking the + last recorded value. If the primary objective is missing, tries aliases + (e.g. ``eval_loss`` → ``train_loss``) so small SFT runs still finalize. 
+ """ + return _resolve_scalar_for_objective(metrics, self._objective_metric) + + def _maybe_suggest_replacement(self) -> dict[str, Any] | None: + """Ask Optuna for a new trial and return a config leaf, or ``None`` if budget exhausted.""" + if self._spawned >= self._budget: + return None + + new_trial = self._study.ask() + config_obj = _sample_from_trial_multi( + new_trial, self._config_templates, self._search_spaces, + ) + leaf = _template_to_leaf_fit(config_obj, self._trainer_type) + + placeholder_id = f"_optuna_pending_{uuid.uuid4().hex[:8]}" + self._trials[placeholder_id] = new_trial + self._spawned += 1 + return leaf + + def _remap_pending_trial(self, db_run_id: int) -> None: + """Replace a placeholder trial key with the real DB run ID after replacement.""" + pending = [k for k in self._trials if isinstance(k, str) and k.startswith("_optuna_pending_")] + if pending: + trial = self._trials.pop(pending[0]) + self._trials[db_run_id] = trial + + +class OptunaShardCallback: + """``ShardCallback`` implementation for Optuna-based pruning in evals mode. + + Evals-mode counterpart of :class:`OptunaChunkCallback`. + + Parameters + ---------- + study : optuna.Study + search_spaces : list[list[tuple[str, Range | List]]] + Per-template search spaces. + config_templates : list[dict[str, Any]] + Original evals config template dicts. + budget : int + Max total trials (initial + replacements). + objective_metric : str + Primary metric key. + objective_metrics : list[str] or None + All metric keys (multi-objective). + directions : list[str] or None + ``"minimize"`` / ``"maximize"`` per metric. + + Methods + ------- + on_shard_complete(pipeline_id, shard_id, metrics) -> PipelineDecision + Evaluate a pipeline after a shard. + finalize(final_metrics) + Tell remaining RUNNING trials their final objective values. + _remap_pending_trial(db_pipeline_id) + Swap a placeholder key with the real DB pipeline ID. + """ + + def __init__( + self, + study: optuna.Study, + search_spaces: list[list[tuple[str, Range | List]]], + config_templates: list[dict[str, Any]], + budget: int, + objective_metric: str, + *, + objective_metrics: list[str] | None = None, + directions: list[str] | None = None, + ): + self._study = study + self._search_spaces = search_spaces + self._config_templates = config_templates + self._budget = budget + self._objective_metric = objective_metric + self._objective_metrics = objective_metrics or [objective_metric] + self._directions = directions or ["minimize"] + self._is_multi_objective = len(self._objective_metrics) > 1 + self._trials: dict[int, optuna.trial.Trial] = {} + self._spawned = 0 + self._multi_intermediates: dict[int, dict[int, list[float]]] = {} + self._pruned_run_ids: set[int] = set() + + def _set_initial_trials(self, trial_map: dict[int, optuna.trial.Trial], spawned: int) -> None: + """Populate the pipeline_id → trial mapping from the initial batch.""" + self._trials.update(trial_map) + self._spawned = spawned + + # -- ShardCallback protocol -- + + def register_pipelines(self, pipeline_id_to_config: dict[int, dict[str, Any]]) -> None: + """No-op — initial mapping is handled via ``_set_initial_trials``.""" + pass + + def on_shard_complete( + self, + pipeline_id: int, + shard_id: int, + metrics: dict[str, Any], + ) -> PipelineDecision: + """Evaluate a pipeline after an evaluation shard. + + Parameters + ---------- + pipeline_id : int + DB pipeline identifier. + shard_id : int + Zero-based shard index. + metrics : dict[str, Any] + Cumulative aggregated metrics up to this shard. 
+ + Returns + ------- + PipelineDecision + """ + trial = self._trials.get(pipeline_id) + if trial is None: + return PipelineDecision(action="continue") + + if self._is_multi_objective: + values = _resolve_multi_objectives(metrics, self._objective_metrics) + if values is None: + return PipelineDecision(action="continue") + intermediates = self._multi_intermediates.setdefault(pipeline_id, {}) + intermediates[shard_id] = values + if self._should_prune_pareto(pipeline_id, shard_id): + self._pruned_run_ids.add(pipeline_id) + self._study.tell(trial, state=optuna.trial.TrialState.PRUNED) + replacement = self._maybe_suggest_replacement() + return PipelineDecision(action="prune", replacement_config=replacement) + return PipelineDecision(action="continue") + + metric_value = self._resolve_metric(metrics) + if metric_value is None: + return PipelineDecision(action="continue") + + trial.report(metric_value, step=shard_id) + + if trial.should_prune() or self._should_prune_concurrent(trial): + self._study.tell(trial, state=optuna.trial.TrialState.PRUNED) + replacement = self._maybe_suggest_replacement() + return PipelineDecision(action="prune", replacement_config=replacement) + + return PipelineDecision(action="continue") + + def finalize(self, final_metrics: dict[int, dict[str, Any]]) -> None: + """Tell all remaining RUNNING trials their final objective values. + + Parameters + ---------- + final_metrics : dict[int, dict[str, Any]] + ``pipeline_id → final metrics dict``. + """ + for pipeline_id, trial in self._trials.items(): + if not isinstance(pipeline_id, int): + continue + if _trial_state_from_storage(self._study, trial) == optuna.trial.TrialState.RUNNING: + pm = final_metrics.get(pipeline_id, {}) + if self._is_multi_objective: + values = _resolve_multi_objectives(pm, self._objective_metrics) + if values is not None: + self._study.tell(trial, values=values) + else: + self._study.tell(trial, state=optuna.trial.TrialState.FAIL) + else: + value = self._resolve_metric(pm) + if value is not None: + self._study.tell(trial, values=value) + else: + self._study.tell(trial, state=optuna.trial.TrialState.FAIL) + + # -- internals -- + + def _should_prune_pareto(self, pipeline_id: int, step: int) -> bool: + """Pareto-dominance pruning for multi-objective studies. + + Only compares against active (non-pruned) peers so ghost values + from already-pruned pipelines don't block subsequent trials. 
+ """ + current_vals = self._multi_intermediates.get(pipeline_id, {}).get(step) + if current_vals is None: + return False + + dominating_peers = 0 + total_peers = 0 + for other_id, other_steps in self._multi_intermediates.items(): + if other_id == pipeline_id: + continue + if other_id in self._pruned_run_ids: + continue + if step not in other_steps: + continue + total_peers += 1 + if _pareto_dominates(other_steps[step], current_vals, self._directions): + dominating_peers += 1 + + if total_peers == 0: + return False + return dominating_peers > total_peers / 2 + + def _should_prune_concurrent(self, trial: optuna.Trial) -> bool: + """Same concurrent-aware pruning as OptunaChunkCallback.""" + all_frozen = self._study.get_trials(deepcopy=False) + + current = None + for ft in all_frozen: + if ft.number == trial.number: + current = ft + break + if current is None or not current.intermediate_values: + return False + + last_step = max(current.intermediate_values.keys()) + current_value = current.intermediate_values[last_step] + if math.isnan(current_value): + return True + + peer_values = [] + for ft in all_frozen: + if ft.number == trial.number: + continue + if last_step in ft.intermediate_values: + v = ft.intermediate_values[last_step] + if not math.isnan(v): + peer_values.append(v) + + if not peer_values: + return False + + median_val = statistics.median(peer_values) + minimize = self._study.direction == optuna.study.StudyDirection.MINIMIZE + if minimize: + return current_value > median_val + return current_value < median_val + + def _resolve_metric(self, metrics: dict[str, Any]) -> float | None: + """Extract the objective metric value from a metrics dict.""" + direct = _resolve_scalar_for_objective(metrics, self._objective_metric) + if direct is not None: + return direct + raw = metrics.get(self._objective_metric) + if raw is None: + for key, val in metrics.items(): + if isinstance(val, dict) and "value" in val: + if key.lower().replace("_", "").replace(" ", "") == self._objective_metric.lower().replace("_", "").replace(" ", ""): + return float(val["value"]) + return None + if isinstance(raw, dict) and "value" in raw: + return float(raw["value"]) + if isinstance(raw, (int, float)): + return float(raw) + return None + + def _maybe_suggest_replacement(self) -> dict[str, Any] | None: + """Ask Optuna for a new trial and return an evals config leaf, or ``None`` if budget exhausted.""" + if self._spawned >= self._budget: + return None + + new_trial = self._study.ask() + config_dict = _sample_from_trial_multi( + new_trial, self._config_templates, self._search_spaces, + ) + leaf = _template_to_leaf_evals(config_dict) + + placeholder_id = f"_optuna_pending_{uuid.uuid4().hex[:8]}" + self._trials[placeholder_id] = new_trial + self._spawned += 1 + return leaf + + def _remap_pending_trial(self, db_pipeline_id: int) -> None: + """Replace a placeholder trial key with the real DB pipeline ID after replacement.""" + pending = [k for k in self._trials if isinstance(k, str) and k.startswith("_optuna_pending_")] + if pending: + trial = self._trials.pop(pending[0]) + self._trials[db_pipeline_id] = trial + + +# --------------------------------------------------------------------------- +# RFOptuna — user-facing AutoMLAlgorithm +# --------------------------------------------------------------------------- + + +class RFOptuna(AutoMLAlgorithm): + """Optuna-powered hyperparameter search for RapidFire AI. + + Drop-in replacement for ``RFGridSearch`` / ``RFRandomSearch`` that uses + Optuna's ask-and-tell API. 
Supports single and multi-objective + optimisation, adaptive pruning, and budget-controlled trial replacement. + + When a run is pruned (stopped early due to poor intermediate metrics), + Optuna automatically generates a replacement config via ``study.ask()`` + so the GPU slot is reused with a better-informed suggestion. This + continues until ``budget`` total trials have been created. + + Parameters + ---------- + configs : + One or more config templates containing ``Range`` / ``List`` + search-space definitions. Accepts a plain list, a ``List([...])`` + wrapper, or a single template. When multiple templates are + provided, Optuna treats the template choice as a categorical + hyperparameter. + trainer_type : str or None + ``"SFT"`` / ``"DPO"`` / ``"GRPO"`` for fit mode, ``None`` for evals + mode. + n_initial : int + Number of configs to generate up-front via ``study.ask()``. + budget : int + Maximum total trials (initial + replacements). Clamped to + ``max(budget, n_initial)``. Set ``budget == n_initial`` to disable + replacement. + objective : str + ``"minimize:eval_loss"`` or ``"maximize:accuracy"`` for + single-objective. ``"maximize:rougeL,maximize:bleu"`` + (comma-separated) for multi-objective. + sampler : str + ``"tpe"`` (default), ``"cmaes"``, or ``"random"``. + pruner : str or None + ``"median"`` (default), ``"hyperband"``, or ``None``. Ignored for + multi-objective studies. + seed : int + Random seed for the Optuna sampler. + granularity : str + ``"chunk"`` (default) or ``"epoch"``. Controls when pruning is + evaluated in fit mode. Ignored in evals mode. + + Methods + ------- + get_runs(seed=42) -> list[dict] + Create the Optuna study and sample ``n_initial`` config leaves. + get_callback(num_chunks=None) -> OptunaChunkCallback | OptunaShardCallback | None + Return the callback wired to the study. Call after ``get_runs()``. + bind_initial_trials(ordered_ids) + Map DB run/pipeline IDs to the Optuna trials from ``get_runs()``. 
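    Examples
    --------
    Illustrative only; ``my_model_config`` and ``eval_cfg`` are assumed to be
    templates containing ``Range`` / ``List`` fields::

        search = RFOptuna(
            configs=[my_model_config],
            trainer_type="SFT",
            n_initial=8,
            budget=24,
            objective="minimize:eval_loss",
            sampler="tpe",
            pruner="median",
            granularity="chunk",
        )
        leaves = search.get_runs(seed=42)   # 8 config-leaf dicts

    Evals mode with two objectives (no trainer_type, comma-separated objective)::

        search = RFOptuna(
            configs=[eval_cfg],
            trainer_type=None,
            objective="maximize:rougeL,maximize:bleu",
            pruner=None,
        )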
+ """ + + def __init__( + self, + configs=None, + trainer_type: str | None = None, + n_initial: int = 16, + budget: int = 40, + objective: str = "minimize:eval_loss", + sampler: str = "tpe", + pruner: str | None = "median", + seed: int = 42, + granularity: str = "chunk", + ): + if granularity not in ("chunk", "epoch"): + raise AutoMLException( + f"granularity must be 'chunk' or 'epoch', got '{granularity}'" + ) + + self.n_initial = n_initial + self.budget = max(budget, n_initial) + self.objective = objective + self.sampler_name = sampler.lower() + self.pruner_name = pruner.lower() if pruner else None + self._seed = seed + self._granularity = granularity + + self._study: optuna.Study | None = None + self._callback: OptunaChunkCallback | OptunaShardCallback | None = None + self._config_templates: list[Any] = [] + self._search_spaces: list[list[tuple[str, Range | List]]] = [] + self._initial_trials: list[optuna.trial.Trial] = [] + + # Parse objective(s) — supports single or comma-separated multi-objective + objectives = [o.strip() for o in objective.split(",")] + self._directions: list[str] = [] + self._objective_metrics: list[str] = [] + for obj_str in objectives: + parts = obj_str.split(":", 1) + if len(parts) != 2 or parts[0] not in ("minimize", "maximize"): + raise AutoMLException( + f"Each objective must be 'minimize:' or " + f"'maximize:', got '{obj_str}'" + ) + self._directions.append(parts[0]) + self._objective_metrics.append(parts[1]) + self._is_multi_objective = len(self._objective_metrics) > 1 + self._direction = self._directions[0] + self._objective_metric = self._objective_metrics[0] + + super().__init__( + configs=configs, + trainer_type=trainer_type, + num_runs=n_initial, + ) + + # -- AutoMLAlgorithm interface -- + + def get_runs(self, seed: int = 42) -> list[dict[str, Any]]: + """Create the Optuna study and sample ``n_initial`` config leaves. + + Parameters + ---------- + seed : int + Fallback seed (instance-level ``seed`` takes precedence). + + Returns + ------- + list[dict[str, Any]] + One config-leaf dict per initial trial. + + Raises + ------ + AutoMLException + If no config templates or no ``Range`` / ``List`` parameters + are found. + """ + if not isinstance(seed, int) or seed < 0: + raise AutoMLException("seed must be a non-negative integer") + + effective_seed = self._seed if self._seed is not None else seed + + if self._is_multi_objective: + self._study = optuna.create_study( + directions=self._directions, + sampler=self._create_sampler(effective_seed), + ) + else: + self._study = optuna.create_study( + direction=self._direction, + sampler=self._create_sampler(effective_seed), + pruner=self._create_pruner(), + ) + optuna.logging.set_verbosity(optuna.logging.WARNING) + + if not self.configs: + raise AutoMLException("At least one config template is required") + + self._config_templates = list(self.configs) + self._search_spaces = [_extract_search_space(t) for t in self._config_templates] + + if not any(self._search_spaces): + raise AutoMLException( + "No Range or List parameters found in any config template. " + "Use Range(...) and List([...]) to define the search space." 
+ ) + + runs: list[dict[str, Any]] = [] + self._initial_trials = [] + + for _ in range(self.n_initial): + trial = self._study.ask() + self._initial_trials.append(trial) + + sampled = _sample_from_trial_multi( + trial, self._config_templates, self._search_spaces, + ) + + if self.mode == "fit": + leaf = _template_to_leaf_fit(sampled, self.trainer_type) + else: + leaf = _template_to_leaf_evals(sampled) + + runs.append(leaf) + + return runs + + def get_callback(self, num_chunks: int | None = None) -> OptunaChunkCallback | OptunaShardCallback | None: + """Return the callback for inter-chunk/shard pruning. Call after ``get_runs()``. + + Parameters + ---------- + num_chunks : int or None + Total chunks per epoch. Only used when ``granularity="epoch"`` + in fit mode so the callback can detect epoch boundaries. + + Returns + ------- + OptunaChunkCallback or OptunaShardCallback or None + """ + if self._study is None: + return None + + if self.mode == "fit": + cb = OptunaChunkCallback( + study=self._study, + search_spaces=self._search_spaces, + config_templates=self._config_templates, + trainer_type=self.trainer_type, + budget=self.budget, + objective_metric=self._objective_metric, + granularity=self._granularity, + num_chunks=num_chunks, + objective_metrics=self._objective_metrics, + directions=self._directions, + ) + else: + cb = OptunaShardCallback( + study=self._study, + search_spaces=self._search_spaces, + config_templates=self._config_templates, + budget=self.budget, + objective_metric=self._objective_metric, + objective_metrics=self._objective_metrics, + directions=self._directions, + ) + + self._callback = cb + return cb + + def bind_initial_trials(self, ordered_ids: list[int]) -> None: + """Map DB run/pipeline IDs to the Optuna trials from ``get_runs()``. + + Parameters + ---------- + ordered_ids : list[int] + DB IDs in the same order as the config leaves from ``get_runs()``. + """ + if self._callback is None: + return + trial_map = {} + for db_id, trial in zip(ordered_ids, self._initial_trials, strict=False): + trial_map[db_id] = trial + self._callback._set_initial_trials(trial_map, spawned=len(self._initial_trials)) + + # -- internal helpers -- + + def _create_sampler(self, seed: int) -> optuna.samplers.BaseSampler: + factory = _SAMPLERS.get(self.sampler_name) + if factory is None: + raise AutoMLException( + f"Unknown sampler '{self.sampler_name}'. " + f"Choose from: {', '.join(_SAMPLERS)}" + ) + return factory(seed) + + def _create_pruner(self) -> optuna.pruners.BasePruner: + if self.pruner_name is None: + return optuna.pruners.NopPruner() + factory = _PRUNERS.get(self.pruner_name) + if factory is None: + raise AutoMLException( + f"Unknown pruner '{self.pruner_name}'. 
" + f"Choose from: {', '.join(_PRUNERS)}, or None" + ) + n_startup = max(1, self.n_initial // 2) + return factory(n_startup) diff --git a/rapidfireai/evals/scheduling/controller.py b/rapidfireai/evals/scheduling/controller.py index c37c46fc..3ba573a2 100644 --- a/rapidfireai/evals/scheduling/controller.py +++ b/rapidfireai/evals/scheduling/controller.py @@ -541,6 +541,11 @@ def _register_pipelines( rag_spec = getattr(pipeline, "rag", None) if has_rag_attr else None prompt_manager = getattr(pipeline, "prompt_manager", None) + if rag_spec and not rag_spec.experiment_name: + rag_spec.experiment_name = self.experiment_name + if prompt_manager and not getattr(prompt_manager, "experiment_name", None): + prompt_manager.experiment_name = self.experiment_name + # Check if pipeline has RAG or prompt_manager to look up context if rag_spec or prompt_manager: # Get RAG hash if present @@ -814,6 +819,84 @@ def _compute_final_metrics_for_pipelines( progress_display.stop() return final_results + @staticmethod + def _extract_pipeline_info(pipeline_id, pipeline_config): + """Extract display metadata from a pipeline config dict. + + Used both for initial pipelines and dynamically added replacements + (e.g. Optuna pruning replacements) so that final results always + carry full config metadata. + """ + model_name = "Unknown" + text_splitter_cfg = None + embedding_cfg = None + vector_store_cfg = None + search_cfg = None + reranker_cfg = None + sampling_params = None + prompt_manager_k = None + model_config = None + + pipeline = pipeline_config["pipeline"] + if hasattr(pipeline, "model_config") and pipeline.model_config is not None: + if "model" in pipeline.model_config: + model_name = pipeline.model_config["model"] + model_config_copy = pipeline.model_config.copy() + model_config_copy.pop("model", None) + if model_config_copy: + model_config = model_config_copy + + if hasattr(pipeline, "rag") and pipeline.rag is not None: + rag = pipeline.rag + + if hasattr(rag, "text_splitter") and rag.text_splitter is not None: + text_splitter_cfg = rag.get_text_splitter_cfg() + if getattr(rag, "embedding_cls", None) is not None: + cls_name = rag.embedding_cls.__name__ if isinstance(rag.embedding_cls, type) else str(rag.embedding_cls) + embedding_cfg = {"class": cls_name} + if getattr(rag, "embedding_kwargs", None): + embedding_cfg.update(rag.embedding_kwargs) + if getattr(rag, "vector_store_cfg", None) is not None: + vector_store_cfg = dict(rag.vector_store_cfg) + + if getattr(rag, "search_type", None) is not None: + search_cfg = {"type": rag.search_type} + if hasattr(rag, "search_kwargs") and rag.search_kwargs: + allowed_keys = SEARCH_TYPE_KEYS.get(rag.search_type, set(rag.search_kwargs.keys())) + search_cfg.update({k: v for k, v in rag.search_kwargs.items() if k in allowed_keys and v is not None}) + if getattr(rag, "reranker_cls", None) is not None: + reranker_cfg = {"class": rag.reranker_cls.__qualname__ if isinstance(rag.reranker_cls, type) else str(rag.reranker_cls)} + if hasattr(rag, "reranker_kwargs") and rag.reranker_kwargs: + reranker_cfg.update({k: v for k, v in rag.reranker_kwargs.items() if v is not None}) + + if hasattr(pipeline, "sampling_params") and pipeline.sampling_params is not None: + sampling_params = pipeline._user_params.get("sampling_params", None) + + if hasattr(pipeline, "prompt_manager") and pipeline.prompt_manager is not None: + prompt_manager_k = getattr(pipeline.prompt_manager, "k", None) + + info_dict = { + "pipeline_id": pipeline_id, + "pipeline_config": pipeline_config, + "model_name": 
model_name, + } + + optional_fields = { + "text_splitter_cfg": text_splitter_cfg, + "embedding_cfg": embedding_cfg, + "vector_store_cfg": vector_store_cfg, + "search_cfg": search_cfg, + "reranker_cfg": reranker_cfg, + "sampling_params": sampling_params, + "prompt_manager_k": prompt_manager_k, + "model_config": model_config, + } + for key, value in optional_fields.items(): + if value is not None: + info_dict[key] = value + + return info_dict + def run_multi_pipeline_inference( self, experiment_id: int, @@ -824,6 +907,7 @@ def run_multi_pipeline_inference( num_actors: int = None, num_gpus_per_actor: float = None, num_cpus_per_actor: float = None, + shard_callback=None, ) -> dict[int, tuple[dict, dict]]: """ Run multi-pipeline inference with fair round-robin scheduling. @@ -840,6 +924,7 @@ def run_multi_pipeline_inference( num_actors: Number of query processing actors to spawn num_gpus_per_actor: GPUs per actor (float, e.g. 1.0 or 0.0) num_cpus_per_actor: CPUs per actor (float, e.g. 3.75) + shard_callback: Optional ShardCallback for Optuna integration Returns: Dict mapping pipeline_id to (aggregated_results, cumulative_metrics) tuple @@ -882,9 +967,18 @@ def run_multi_pipeline_inference( self.logger.info(f"Created {num_actors} query processing actors (generic pool)") + # Extract Optuna shard callback from config_group if available + if shard_callback is None and hasattr(config_group, "get_callback"): + shard_callback = config_group.get_callback() + # PHASE 5: Register pipelines in database pipeline_ids, pipeline_id_to_config = self._register_pipelines(config_leaves, db) + # Bind Optuna trials to the newly created DB pipeline IDs + if shard_callback is not None and hasattr(config_group, "bind_initial_trials"): + config_group.bind_initial_trials(pipeline_ids) + self.logger.info(f"Optuna shard callback bound to {len(pipeline_ids)} initial pipelines") + # PHASE 6: Initialize PipelineScheduler scheduler = PipelineScheduler( pipeline_ids=pipeline_ids, @@ -911,93 +1005,10 @@ def run_multi_pipeline_inference( # Initialize progress display table pipeline_info = [] - pipeline_configs = [pipeline_id_to_config[pipeline_id] for pipeline_id in pipeline_ids] - for pipeline_id, pipeline_config in zip(pipeline_ids, pipeline_configs, strict=False): - model_name = "Unknown" - # Indexing-stage fields (read-only in progress display) - text_splitter_cfg = None - embedding_cfg = None - vector_store_cfg = None - # Retrieval-stage fields - search_cfg = None - reranker_cfg = None - # Generation-stage fields - sampling_params = None - prompt_manager_k = None - model_config = None - - # Extract model name from config - pipeline = pipeline_config["pipeline"] - if hasattr(pipeline, "model_config") and pipeline.model_config is not None: - if "model" in pipeline.model_config: - model_name = pipeline.model_config["model"] - model_config_copy = pipeline.model_config.copy() - model_config_copy.pop("model", None) - if model_config_copy: - model_config = model_config_copy - - # Extract ALL RAG fields for display - if hasattr(pipeline, "rag") and pipeline.rag is not None: - rag = pipeline.rag - - # Indexing stage - if hasattr(rag, "text_splitter") and rag.text_splitter is not None: - text_splitter_cfg = rag.get_text_splitter_cfg() - if getattr(rag, "embedding_cls", None) is not None: - cls_name = rag.embedding_cls.__name__ if isinstance(rag.embedding_cls, type) else str(rag.embedding_cls) - embedding_cfg = {"class": cls_name} - if getattr(rag, "embedding_kwargs", None): - embedding_cfg.update(rag.embedding_kwargs) - if 
getattr(rag, "vector_store_cfg", None) is not None: - vector_store_cfg = dict(rag.vector_store_cfg) - - # Retrieval stage - if getattr(rag, "search_type", None) is not None: - search_cfg = {"type": rag.search_type} - if hasattr(rag, "search_kwargs") and rag.search_kwargs: - allowed_keys = SEARCH_TYPE_KEYS.get(rag.search_type, set(rag.search_kwargs.keys())) - search_cfg.update({k: v for k, v in rag.search_kwargs.items() if k in allowed_keys and v is not None}) - if getattr(rag, "reranker_cls", None) is not None: - reranker_cfg = {"class": rag.reranker_cls.__qualname__ if isinstance(rag.reranker_cls, type) else str(rag.reranker_cls)} - if hasattr(rag, "reranker_kwargs") and rag.reranker_kwargs: - reranker_cfg.update({k: v for k, v in rag.reranker_kwargs.items() if v is not None}) - - # Extract sampling params - if hasattr(pipeline, "sampling_params") and pipeline.sampling_params is not None: - sampling_params = pipeline._user_params.get("sampling_params", None) - - # Extract prompt_manager fields - if hasattr(pipeline, "prompt_manager") and pipeline.prompt_manager is not None: - prompt_manager_k = getattr(pipeline.prompt_manager, "k", None) - - pipeline_info_dict = { - "pipeline_id": pipeline_id, - "pipeline_config": pipeline_config, - "model_name": model_name, - } - - # Add optional fields only if they're not None - # Indexing stage (read-only display) - if text_splitter_cfg is not None: - pipeline_info_dict["text_splitter_cfg"] = text_splitter_cfg - if embedding_cfg is not None: - pipeline_info_dict["embedding_cfg"] = embedding_cfg - if vector_store_cfg is not None: - pipeline_info_dict["vector_store_cfg"] = vector_store_cfg - # Retrieval stage - if search_cfg is not None: - pipeline_info_dict["search_cfg"] = search_cfg - if reranker_cfg is not None: - pipeline_info_dict["reranker_cfg"] = reranker_cfg - # Generation stage - if sampling_params is not None: - pipeline_info_dict["sampling_params"] = sampling_params - if prompt_manager_k is not None: - pipeline_info_dict["prompt_manager_k"] = prompt_manager_k - if model_config is not None: - pipeline_info_dict["model_config"] = model_config - - pipeline_info.append(pipeline_info_dict) + for pipeline_id in pipeline_ids: + pipeline_info.append( + self._extract_pipeline_info(pipeline_id, pipeline_id_to_config[pipeline_id]) + ) progress_display = PipelineProgressDisplay(pipeline_info, num_shards) @@ -1306,6 +1317,79 @@ def run_multi_pipeline_inference( f"({task_info['batch_count']} batches, {duration:.2f}s)" ) + # Optuna shard callback: evaluate pipeline and potentially prune + if ( + shard_callback is not None + and shards_completed < num_shards + ): + try: + cb_metrics = display_metrics if display_metrics else {} + shard_decision = shard_callback.on_shard_complete( + pipeline_id, shard_id, cb_metrics + ) + if shard_decision.action == "prune": + db.set_pipeline_status(pipeline_id, PipelineStatus.STOPPED) + progress_display.update_pipeline(pipeline_id, status="STOPPED") + scheduler.remove_pipeline(pipeline_id) + self.logger.info( + f"Optuna pruned pipeline {pipeline_id} after shard {shard_id}" + ) + if shard_decision.replacement_config: + # Inject experiment_name so the replacement's + # RAG hash matches the cached context built + # during _setup_context_generators. 
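+ # This mirrors the same experiment_name injection that _register_pipelines
+ # applies to the initial pipelines, so the replacement's cached-context
+ # lookup runs under the same experiment name as the prebuilt context.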
+ repl_pipeline = shard_decision.replacement_config.get("pipeline") + if repl_pipeline is not None: + repl_rag = getattr(repl_pipeline, "rag", None) + repl_pm = getattr(repl_pipeline, "prompt_manager", None) + if repl_rag: + repl_rag.experiment_name = self.experiment_name + if repl_pm: + repl_pm.experiment_name = self.experiment_name + + new_ids, new_map = self._register_pipelines( + [shard_decision.replacement_config], db + ) + for new_pid in new_ids: + pipeline_id_to_config[new_pid] = new_map[new_pid] + pipeline_ids.append(new_pid) + agg = Aggregator() + pc = new_map[new_pid] + p = pc["pipeline"] + if hasattr(p, "online_strategy"): + agg.set_online_strategy(**p.online_strategy) + agg.set_total_population_size(total_dataset_size) + pipeline_aggregators[new_pid] = agg + pipeline_results[new_pid] = { + "results": {}, + "metrics": {}, + "start_time": None, + } + scheduler.add_pipeline(new_pid, shards_completed=0) + if hasattr(shard_callback, "_remap_pending_trial"): + shard_callback._remap_pending_trial(new_pid) + + # Add replacement to live progress display + if progress_display: + info = self._extract_pipeline_info(new_pid, pc) + metadata = { + k: v for k, v in info.items() + if k not in ["pipeline_id", "pipeline_config", "model_name"] + } + progress_display.add_pipeline( + pipeline_id=new_pid, + pipeline_config=pc, + model_name=info.get("model_name", "Unknown"), + **metadata, + ) + self.logger.info( + f"Optuna suggested replacement pipeline(s): {new_ids}" + ) + except Exception as e: + self.logger.warning( + f"Optuna shard callback error for pipeline {pipeline_id}: {e}" + ) + # Mark for cleanup completed_actor_ids.append(actor_id) @@ -1564,6 +1648,26 @@ def run_multi_pipeline_inference( db.set_actor_task_status(task_id, TaskStatus.IN_PROGRESS) db.set_pipeline_current_shard(pipeline_id, shard_id) + # Finalize Optuna shard callback with final metrics from all pipelines. + # We must accumulate the raw per-shard metric lists into flat dicts + # (e.g. {"NDCG@5": {"value": 0.15}}) before passing to finalize(), + # because the callback's _resolve_metric expects scalar-valued dicts, + # not the raw aggregated lists stored in pipeline_results. + if shard_callback is not None: + try: + final_cb_metrics: dict[int, dict] = {} + for pid in pipeline_id_to_config: + if pid in pipeline_results and pipeline_results[pid]["metrics"]: + accumulate_fn = pipeline_id_to_config[pid].get("accumulate_metrics_fn") + if accumulate_fn: + final_cb_metrics[pid] = accumulate_fn(pipeline_results[pid]["metrics"]) + else: + final_cb_metrics[pid] = pipeline_results[pid]["metrics"] + shard_callback.finalize(final_cb_metrics) + self.logger.info("Optuna shard callback finalized") + except Exception as e: + self.logger.warning(f"Optuna shard callback finalize error: {e}") + # PHASE 8: Compute final metrics for each pipeline (including dynamically cloned ones). # pipeline_id_to_config contains all pipelines (originals + clones added via _handle_clone). # pipeline_ids only has the originals registered at startup, so we use the config dict keys. @@ -1576,15 +1680,23 @@ def run_multi_pipeline_inference( info_copy = {k: v for k, v in info_dict.items() if k not in ["pipeline_config", "pipeline_id"]} pipeline_id_to_info[pid] = info_copy - # For cloned pipelines (not in pipeline_info), pull their display metadata directly - # from the progress_display which already has it from add_pipeline(). 
+ # For dynamically added pipelines (Optuna replacements, interactive clones) + # not in pipeline_info, extract metadata from the pipeline config itself. + # Fall back to the progress_display for interactive clones that populated it. for pid in all_pipeline_ids: - if pid not in pipeline_id_to_info and progress_display: - clone_metadata = progress_display.pipeline_metadata.get(pid, {}) - clone_data = progress_display.pipeline_data.get(pid, {}) - clone_info = {"model_name": clone_data.get("model", "Unknown")} - clone_info.update(clone_metadata) - pipeline_id_to_info[pid] = clone_info + if pid not in pipeline_id_to_info: + if pid in pipeline_id_to_config: + info_dict = self._extract_pipeline_info(pid, pipeline_id_to_config[pid]) + pipeline_id_to_info[pid] = { + k: v for k, v in info_dict.items() + if k not in ["pipeline_config", "pipeline_id"] + } + elif progress_display: + clone_metadata = progress_display.pipeline_metadata.get(pid, {}) + clone_data = progress_display.pipeline_data.get(pid, {}) + clone_info = {"model_name": clone_data.get("model", "Unknown")} + clone_info.update(clone_metadata) + pipeline_id_to_info[pid] = clone_info final_results = self._compute_final_metrics_for_pipelines( all_pipeline_ids, diff --git a/rapidfireai/experiment.py b/rapidfireai/experiment.py index 05c0d8b2..af4d3a66 100644 --- a/rapidfireai/experiment.py +++ b/rapidfireai/experiment.py @@ -389,6 +389,11 @@ def run_fit( print("⚠️ Training is already running in background. Please wait for it to complete.") return + # Extract chunk callback from param_config if it supports Optuna-style callbacks + chunk_callback = None + if hasattr(param_config, "get_callback"): + chunk_callback = param_config.get_callback() + if ColabConfig.ON_COLAB: # Run Controller in background thread to keep kernel responsive import sys @@ -405,7 +410,7 @@ def _run_controller_background(): try: controller = Controller(self.experiment_id, self.experiment_name) - controller.run_fit(param_config, create_model_fn, train_dataset, eval_dataset, num_chunks, seed, num_gpus, monte_carlo_simulations) + controller.run_fit(param_config, create_model_fn, train_dataset, eval_dataset, num_chunks, seed, num_gpus, monte_carlo_simulations, chunk_callback=chunk_callback) except Exception as e: # Restore stdout for error logging sys.stdout = old_stdout @@ -444,7 +449,7 @@ def _run_controller_background(): # Original blocking behavior for non-Colab environments try: controller = Controller(self.experiment_id, self.experiment_name) - controller.run_fit(param_config, create_model_fn, train_dataset, eval_dataset, num_chunks, seed, num_gpus, monte_carlo_simulations) + controller.run_fit(param_config, create_model_fn, train_dataset, eval_dataset, num_chunks, seed, num_gpus, monte_carlo_simulations, chunk_callback=chunk_callback) except Exception as e: if hasattr(self, "logger"): self.logger.opt(exception=True).error(f"Error running fit: {e}") @@ -534,6 +539,11 @@ def run_evals( # Update experiment with num_shards self.db.set_experiment_num_shards(self.experiment_id, num_shards) + # Extract shard callback from config_group if it supports Optuna-style callbacks + shard_callback = None + if hasattr(config_group, "get_callback"): + shard_callback = config_group.get_callback() + # Delegate all complexity to Controller try: results = self.controller.run_multi_pipeline_inference( @@ -545,6 +555,7 @@ def run_evals( num_actors=int(num_actors), num_gpus_per_actor=float(gpus_per_actor), num_cpus_per_actor=float(cpus_per_actor), + shard_callback=shard_callback, ) except 
Exception as e: self.logger.exception("Error running multi-config experiment") diff --git a/rapidfireai/fit/backend/controller.py b/rapidfireai/fit/backend/controller.py index 26df2669..59a8f3d4 100644 --- a/rapidfireai/fit/backend/controller.py +++ b/rapidfireai/fit/backend/controller.py @@ -583,6 +583,7 @@ def run_fit( seed: int = 42, num_gpus: int = 1, monte_carlo_simulations: int = 1000, + chunk_callback=None, ) -> None: """Run the fit.""" @@ -615,7 +616,7 @@ def run_fit( # create models try: len_train_dataset = len(train_dataset) - self._create_models( + run_ids = self._create_models( param_config, RunSource.INITIAL, seed, @@ -626,6 +627,16 @@ def run_fit( except Exception as e: raise ControllerException(f"Error creating models: {e}") from e + # RFOptuna: ``get_runs()`` (inside _create_models) creates the Optuna study. + # ``get_callback()`` must run after that or it returns None and trials never finalize. + if chunk_callback is None and hasattr(param_config, "get_callback"): + chunk_callback = param_config.get_callback(num_chunks=num_chunks) + + # Bind Optuna trials to the newly created DB run IDs + if chunk_callback is not None and hasattr(param_config, "bind_initial_trials"): + param_config.bind_initial_trials(run_ids) + self.logger.info(f"Optuna callback bound to {len(run_ids)} initial runs") + # set experiment task to create models self.db.set_experiment_current_task(ExperimentTask.RUN_FIT) self.logger.debug(f"Set experiment task to {ExperimentTask.RUN_FIT.value}.") @@ -748,6 +759,46 @@ def run_fit( ) self.db.set_controller_progress(run_id, progress_percentage) + # Optuna callback: evaluate run after chunk and potentially prune + if ( + chunk_callback is not None + and run_details["completed_steps"] < run_details["total_steps"] + ): + try: + metric_run_id = run_details.get("metric_run_id") + run_metrics = ( + self.metric_logger.get_run_metrics(metric_run_id) + if metric_run_id + else {} + ) + decision = chunk_callback.on_chunk_complete(run_id, chunk_id, run_metrics) + if decision.action == "prune": + scheduler.remove_run(run_id) + self.db.set_run_details( + run_id=run_id, + status=RunStatus.STOPPED, + ended_by=RunEndedBy.OPTUNA_PRUNED, + ) + self._clear_run_from_shm(run_id) + self.logger.info(f"Optuna pruned run {run_id} after chunk {chunk_id}") + if decision.replacement_config: + new_run_ids = self._create_models( + decision.replacement_config, + RunSource.OPTUNA, + seed, + len_train_dataset, + num_chunks=num_chunks, + ) + if hasattr(chunk_callback, "_remap_pending_trial"): + for new_id in new_run_ids: + chunk_callback._remap_pending_trial(new_id) + self.logger.info( + f"Optuna suggested replacement run(s): {new_run_ids}" + ) + continue + except Exception as e: + self.logger.warning(f"Optuna callback error for run {run_id}: {e}") + # Check if run has completed all epochs # completed_steps can go beyond total_steps since we stop only at a chunk boundary if run_details["completed_steps"] >= run_details["total_steps"]: @@ -899,6 +950,19 @@ def run_fit( # Small delay time.sleep(1) + # Finalize Optuna callback with final metrics from all runs + if chunk_callback is not None: + try: + final_metrics = {} + for rid, rdetails in self.db.get_all_runs().items(): + metric_run_id = rdetails.get("metric_run_id") + if metric_run_id: + final_metrics[rid] = self.metric_logger.get_run_metrics(metric_run_id) + chunk_callback.finalize(final_metrics) + self.logger.info("Optuna callback finalized") + except Exception as e: + self.logger.warning(f"Optuna finalize error: {e}") + # set experiment task 
to idle self.db.set_experiment_current_task(ExperimentTask.IDLE) self.logger.debug(f"Set experiment task to {ExperimentTask.IDLE.value}.") diff --git a/rapidfireai/fit/utils/constants.py b/rapidfireai/fit/utils/constants.py index 70c8b6f5..bed00673 100644 --- a/rapidfireai/fit/utils/constants.py +++ b/rapidfireai/fit/utils/constants.py @@ -108,6 +108,7 @@ class RunSource(Enum): SHA = "Successive Halving Algorithm" INITIAL = "Initial" INTERACTIVE_CONTROL = "Interactive Control" + OPTUNA = "Optuna Replacement" class RunEndedBy(Enum): @@ -117,6 +118,7 @@ class RunEndedBy(Enum): EPOCH_COMPLETED = "Epoch Completed" INTERACTIVE_CONTROL = "Interactive Control" TOLERENCE = "Tolerence Threshold Met" + OPTUNA_PRUNED = "Optuna Pruned" # SHM Model Type Constants diff --git a/tests/test_optuna.py b/tests/test_optuna.py new file mode 100644 index 00000000..fe8101d4 --- /dev/null +++ b/tests/test_optuna.py @@ -0,0 +1,806 @@ +"""Tests for Optuna integration: search-space extraction, callbacks, RFOptuna.get_runs().""" + +import copy +import types +from dataclasses import dataclass + +import pytest +import optuna + +from rapidfireai.automl.datatypes import List, Range +from rapidfireai.automl.optuna_search import ( + OptunaChunkCallback, + OptunaShardCallback, + RFOptuna, + _extract_search_space, + _resolve_metric_history, + _resolve_scalar_for_objective, + _sample_from_trial, + _sample_from_trial_multi, + _set_nested, + _suggest_value, + _template_to_leaf_evals, + _trial_state_from_storage, +) +from rapidfireai.automl.callbacks import RunDecision, PipelineDecision + + +# --------------------------------------------------------------------------- +# Search-space extraction +# --------------------------------------------------------------------------- + + +class TestExtractSearchSpace: + def test_flat_dict(self): + template = { + "learning_rate": Range(1e-6, 1e-3), + "batch_size": List([4, 8, 16]), + "epochs": 3, + } + space = _extract_search_space(template) + assert len(space) == 2 + paths = {p for p, _ in space} + assert paths == {"learning_rate", "batch_size"} + + def test_nested_dict(self): + template = { + "training_args": { + "lr": Range(1e-5, 1e-3), + "warmup": List([0, 100, 500]), + }, + "model_name": "bert-base", + } + space = _extract_search_space(template) + assert len(space) == 2 + paths = {p for p, _ in space} + assert paths == {"training_args.lr", "training_args.warmup"} + + def test_object_with_user_params(self): + class FakeConfig: + def __init__(self, **kwargs): + self._user_params = kwargs + + config = FakeConfig(lr=Range(1e-5, 1e-3), dropout=0.1, hidden=List([128, 256])) + space = _extract_search_space(config) + assert len(space) == 2 + paths = {p for p, _ in space} + assert paths == {"lr", "hidden"} + + def test_empty_template(self): + assert _extract_search_space({"a": 1, "b": "hello"}) == [] + + def test_dataclass_wraps_nested_user_params(self): + """RFModelConfig is a dataclass; Range/List under peft_config._user_params must be found.""" + + class FakePeft: + def __init__(self): + self._user_params = {"lora_alpha": List([16, 32]), "r": 8} + + @dataclass + class FakeModelConfig: + model_name: str + peft_config: object + + template = FakeModelConfig(model_name="gpt2", peft_config=FakePeft()) + space = _extract_search_space(template) + assert len(space) == 1 + path, param = space[0] + assert path == "peft_config.lora_alpha" + assert isinstance(param, List) + + def test_range_log_and_step(self): + r = Range(1e-6, 1e-3, log=True) + assert r.log is True + assert r.step is None + r2 = 
Range(8, 64, step=8) + assert r2.step == 8 + assert r2.log is False + + +def test_resolve_scalar_prefers_primary_key(): + assert _resolve_scalar_for_objective({"eval_loss": 1.0, "train_loss": 9.0}, "eval_loss") == 1.0 + + +class TestResolveMetricHistory: + def test_mlflow_style_history(self): + metrics = {"eval_loss": [(0, 0.9), (10, 0.7), (20, 0.5)]} + assert _resolve_metric_history(metrics, "eval_loss") == [(0, 0.9), (10, 0.7), (20, 0.5)] + + def test_plain_scalar(self): + assert _resolve_metric_history({"eval_loss": 0.42}, "eval_loss") == [(0, 0.42)] + + def test_alias_fallback(self): + metrics = {"train_loss": [(5, 1.0), (15, 0.8)]} + assert _resolve_metric_history(metrics, "eval_loss") == [(5, 1.0), (15, 0.8)] + + def test_no_match(self): + assert _resolve_metric_history({"other": 1.0}, "eval_loss") == [] + + def test_unsorted_input_gets_sorted(self): + metrics = {"eval_loss": [(20, 0.5), (0, 0.9), (10, 0.7)]} + assert _resolve_metric_history(metrics, "eval_loss") == [(0, 0.9), (10, 0.7), (20, 0.5)] + + def test_bare_number_list(self): + metrics = {"eval_loss": [0.9, 0.7, 0.5]} + result = _resolve_metric_history(metrics, "eval_loss") + assert result == [(0, 0.9), (1, 0.7), (2, 0.5)] + + +# --------------------------------------------------------------------------- +# Sampling from trial +# --------------------------------------------------------------------------- + + +class TestSuggestAndSample: + def test_suggest_float_range(self): + study = optuna.create_study() + trial = study.ask() + val = _suggest_value(trial, "lr", Range(0.001, 0.1)) + assert 0.001 <= val <= 0.1 + + def test_suggest_int_range(self): + study = optuna.create_study() + trial = study.ask() + val = _suggest_value(trial, "bs", Range(4, 32)) + assert 4 <= val <= 32 + assert isinstance(val, int) + + def test_suggest_categorical(self): + study = optuna.create_study() + trial = study.ask() + val = _suggest_value(trial, "opt", List(["adam", "sgd", "adamw"])) + assert val in ["adam", "sgd", "adamw"] + + def test_sample_from_trial_flat(self): + template = { + "lr": Range(0.0, 1.0), + "name": "test", + "bs": List([8, 16]), + } + space = _extract_search_space(template) + study = optuna.create_study() + trial = study.ask() + result = _sample_from_trial(trial, space, template) + + assert isinstance(result["lr"], float) + assert result["bs"] in [8, 16] + assert result["name"] == "test" + # Original template not mutated + assert isinstance(template["lr"], Range) + + def test_sample_from_trial_nested(self): + template = { + "outer": { + "inner": Range(0, 10), + "fixed": "hello", + } + } + space = _extract_search_space(template) + study = optuna.create_study() + trial = study.ask() + result = _sample_from_trial(trial, space, template) + assert isinstance(result["outer"]["inner"], int) + assert result["outer"]["fixed"] == "hello" + + +class TestSetNested: + def test_flat_dict(self): + d = {"a": 1, "b": 2} + _set_nested(d, "a", 99) + assert d["a"] == 99 + + def test_nested_dict(self): + d = {"outer": {"inner": 1}} + _set_nested(d, "outer.inner", 42) + assert d["outer"]["inner"] == 42 + + +# --------------------------------------------------------------------------- +# OptunaChunkCallback +# --------------------------------------------------------------------------- + + +def _fit_template_for_chunk_callback_tests() -> types.SimpleNamespace: + """Minimal RFModelConfig-like object for tests that call ``_template_to_leaf_fit``.""" + return types.SimpleNamespace( + model_name="m", + tokenizer=None, + tokenizer_kwargs=None, + 
model_type="causal_lm", + peft_config=None, + training_args=None, + model_kwargs=None, + ref_model_kwargs=None, + reward_funcs=None, + ref_model_name=None, + ref_model_type=None, + num_gpus=None, + formatting_func=None, + compute_metrics=None, + generation_config=None, + lr=Range(0.0, 1.0), + ) + + +class TestOptunaChunkCallback: + def _make_callback(self, direction="minimize", pruner=None): + study = optuna.create_study( + direction=direction, + pruner=pruner or optuna.pruners.NopPruner(), + ) + space = [("lr", Range(0.0, 1.0))] + template = _fit_template_for_chunk_callback_tests() + cb = OptunaChunkCallback( + study=study, + search_spaces=[space], + config_templates=[template], + trainer_type="SFT", + budget=5, + objective_metric="eval_loss", + ) + return cb, study + + def test_continue_when_no_prune(self): + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({1: trial}, spawned=1) + + decision = cb.on_chunk_complete(1, 0, {"eval_loss": 0.5}) + assert decision.action == "continue" + assert decision.replacement_config is None + + def test_continue_when_metric_missing(self): + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({1: trial}, spawned=1) + + decision = cb.on_chunk_complete(1, 0, {"other_metric": 0.5}) + assert decision.action == "continue" + + def test_continue_when_run_unknown(self): + cb, _ = self._make_callback() + decision = cb.on_chunk_complete(999, 0, {"eval_loss": 0.5}) + assert decision.action == "continue" + + def test_resolve_metric_flat(self): + cb, _ = self._make_callback() + assert cb._resolve_metric({"eval_loss": 0.5}) == 0.5 + + def test_resolve_metric_mlflow_history(self): + cb, _ = self._make_callback() + assert cb._resolve_metric({"eval_loss": [(0, 0.8), (1, 0.5)]}) == 0.5 + + def test_resolve_metric_falls_back_when_eval_missing(self): + """Tiny SFT jobs may log train_loss but never eval_loss.""" + cb, _ = self._make_callback() + assert cb._resolve_metric({"train_loss": 2.5}) == 2.5 + assert cb._resolve_metric({"train_loss": [(0, 3.0), (4, 2.1)]}) == 2.1 + + def test_finalize_tells_study(self): + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({1: trial}, spawned=1) + + cb.finalize({1: {"eval_loss": 0.3}}) + assert _trial_state_from_storage(study, trial) == optuna.trial.TrialState.COMPLETE + + def test_finalize_fails_missing_metric(self): + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({1: trial}, spawned=1) + + cb.finalize({1: {}}) + assert _trial_state_from_storage(study, trial) == optuna.trial.TrialState.FAIL + + def test_replacement_within_budget(self): + cb, study = self._make_callback() + cb._spawned = 3 + cb._budget = 5 + replacement = cb._maybe_suggest_replacement() + assert replacement is not None + assert isinstance(replacement, dict) + assert cb._spawned == 4 + + def test_no_replacement_over_budget(self): + cb, study = self._make_callback() + cb._spawned = 5 + cb._budget = 5 + replacement = cb._maybe_suggest_replacement() + assert replacement is None + + @staticmethod + def _get_intermediate_values(study, trial): + """Retrieve intermediate_values from the frozen trial in storage.""" + for ft in study.get_trials(deepcopy=False): + if ft.number == trial.number: + return ft.intermediate_values + return {} + + def test_reports_all_training_steps(self): + """on_chunk_complete should report every training step, not just one per chunk.""" + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({1: 
trial}, spawned=1) + + metrics = {"eval_loss": [(0, 0.9), (5, 0.8), (10, 0.7)]} + decision = cb.on_chunk_complete(1, 0, metrics) + assert decision.action == "continue" + + reported = self._get_intermediate_values(study, trial) + assert reported == {0: 0.9, 5: 0.8, 10: 0.7} + assert cb._last_reported_step[1] == 10 + + def test_cumulative_across_chunks(self): + """Second chunk should only report new steps, not re-report old ones.""" + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({1: trial}, spawned=1) + + cb.on_chunk_complete(1, 0, {"eval_loss": [(0, 0.9), (5, 0.8)]}) + assert cb._last_reported_step[1] == 5 + + cb.on_chunk_complete(1, 1, {"eval_loss": [(0, 0.9), (5, 0.8), (10, 0.6), (15, 0.5)]}) + assert cb._last_reported_step[1] == 15 + + reported = self._get_intermediate_values(study, trial) + assert reported == {0: 0.9, 5: 0.8, 10: 0.6, 15: 0.5} + + def test_flat_scalar_reports_at_step_zero(self): + """A flat scalar metric gets reported at step 0.""" + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({1: trial}, spawned=1) + + cb.on_chunk_complete(1, 0, {"eval_loss": 0.5}) + reported = self._get_intermediate_values(study, trial) + assert reported == {0: 0.5} + + def test_remap_pending_trial(self): + cb, study = self._make_callback() + trial = study.ask() + cb._trials["_optuna_pending_abc12345"] = trial + cb._remap_pending_trial(42) + assert 42 in cb._trials + assert "_optuna_pending_abc12345" not in cb._trials + + +# --------------------------------------------------------------------------- +# OptunaChunkCallback — epoch granularity +# --------------------------------------------------------------------------- + + +class TestOptunaChunkCallbackEpochGranularity: + """Tests for granularity='epoch': decisions only fire at epoch boundaries.""" + + NUM_CHUNKS = 4 + + def _make_callback(self, direction="minimize", pruner=None): + study = optuna.create_study( + direction=direction, + pruner=pruner or optuna.pruners.NopPruner(), + ) + space = [("lr", Range(0.0, 1.0))] + template = _fit_template_for_chunk_callback_tests() + cb = OptunaChunkCallback( + study=study, + search_spaces=[space], + config_templates=[template], + trainer_type="SFT", + budget=5, + objective_metric="eval_loss", + granularity="epoch", + num_chunks=self.NUM_CHUNKS, + ) + return cb, study + + def test_defers_decision_until_epoch_boundary(self): + """Chunks 0-2 should always continue; chunk 3 (4th) is the epoch boundary.""" + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({1: trial}, spawned=1) + + for chunk_id in range(self.NUM_CHUNKS - 1): + decision = cb.on_chunk_complete(1, chunk_id, {"eval_loss": 0.9 - chunk_id * 0.1}) + assert decision.action == "continue", f"expected continue at chunk {chunk_id}" + + decision = cb.on_chunk_complete(1, self.NUM_CHUNKS - 1, {"eval_loss": 0.5}) + assert decision.action == "continue" + + def test_prune_fires_at_epoch_boundary(self): + pruner = optuna.pruners.ThresholdPruner(upper=0.1) + cb, study = self._make_callback(pruner=pruner) + trial = study.ask() + cb._set_initial_trials({1: trial}, spawned=1) + + for chunk_id in range(self.NUM_CHUNKS - 1): + decision = cb.on_chunk_complete(1, chunk_id, {"eval_loss": 5.0}) + assert decision.action == "continue" + + decision = cb.on_chunk_complete(1, self.NUM_CHUNKS - 1, {"eval_loss": 5.0}) + assert decision.action == "prune" + + def test_counter_resets_after_epoch(self): + """After one epoch completes, the next epoch should count from 0 again.""" + 
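+ # One "epoch" here is NUM_CHUNKS chunks, so the per-run counter is expected
+ # to wrap back to 0 right after chunk NUM_CHUNKS - 1 completes.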
cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({1: trial}, spawned=1) + + for chunk_id in range(self.NUM_CHUNKS): + cb.on_chunk_complete(1, chunk_id, {"eval_loss": 0.5}) + assert cb._chunks_since_last_eval[1] == 0 + + for chunk_id in range(self.NUM_CHUNKS - 1): + decision = cb.on_chunk_complete(1, chunk_id, {"eval_loss": 0.4}) + assert decision.action == "continue" + + def test_metrics_still_reported_every_chunk(self): + """Even with epoch granularity, intermediate metric values are reported to Optuna + on every chunk so the pruner has full visibility.""" + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({1: trial}, spawned=1) + + cb.on_chunk_complete(1, 0, {"eval_loss": [(0, 0.9), (5, 0.8)]}) + cb.on_chunk_complete(1, 1, {"eval_loss": [(0, 0.9), (5, 0.8), (10, 0.7)]}) + + reported = {} + for ft in study.get_trials(deepcopy=False): + if ft.number == trial.number: + reported = ft.intermediate_values + assert reported == {0: 0.9, 5: 0.8, 10: 0.7} + + def test_independent_tracking_per_run(self): + cb, study = self._make_callback() + t1 = study.ask() + t2 = study.ask() + cb._set_initial_trials({1: t1, 2: t2}, spawned=2) + + cb.on_chunk_complete(1, 0, {"eval_loss": 0.5}) + cb.on_chunk_complete(1, 1, {"eval_loss": 0.4}) + cb.on_chunk_complete(2, 0, {"eval_loss": 0.6}) + + assert cb._chunks_since_last_eval[1] == 2 + assert cb._chunks_since_last_eval[2] == 1 + + def test_invalid_granularity_rejected(self): + study = optuna.create_study() + with pytest.raises(Exception, match="granularity"): + OptunaChunkCallback( + study=study, + search_spaces=[[("x", Range(0.0, 1.0))]], + config_templates=[{"x": Range(0.0, 1.0)}], + trainer_type="SFT", + budget=5, + objective_metric="loss", + granularity="step", + num_chunks=4, + ) + + def test_epoch_granularity_requires_num_chunks(self): + study = optuna.create_study() + with pytest.raises(Exception, match="num_chunks"): + OptunaChunkCallback( + study=study, + search_spaces=[[("x", Range(0.0, 1.0))]], + config_templates=[{"x": Range(0.0, 1.0)}], + trainer_type="SFT", + budget=5, + objective_metric="loss", + granularity="epoch", + num_chunks=None, + ) + + +# --------------------------------------------------------------------------- +# OptunaShardCallback +# --------------------------------------------------------------------------- + + +class TestOptunaShardCallback: + def _make_callback(self): + study = optuna.create_study( + direction="maximize", + pruner=optuna.pruners.NopPruner(), + ) + space = [("temperature", Range(0.0, 2.0))] + template = {"pipeline": "fake", "temperature": Range(0.0, 2.0)} + cb = OptunaShardCallback( + study=study, + search_spaces=[space], + config_templates=[template], + budget=5, + objective_metric="accuracy", + ) + return cb, study + + def test_continue_decision(self): + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({10: trial}, spawned=1) + decision = cb.on_shard_complete(10, 0, {"accuracy": 0.85}) + assert decision.action == "continue" + + def test_resolve_metric_dict_with_value(self): + cb, _ = self._make_callback() + assert cb._resolve_metric({"accuracy": {"value": 0.9, "lower_bound": 0.85}}) == 0.9 + + def test_resolve_metric_plain_float(self): + cb, _ = self._make_callback() + assert cb._resolve_metric({"accuracy": 0.75}) == 0.75 + + def test_finalize(self): + cb, study = self._make_callback() + trial = study.ask() + cb._set_initial_trials({10: trial}, spawned=1) + cb.finalize({10: {"accuracy": 0.92}}) + assert 
_trial_state_from_storage(study, trial) == optuna.trial.TrialState.COMPLETE + + +# --------------------------------------------------------------------------- +# RFOptuna class +# --------------------------------------------------------------------------- + + +class TestRFOptuna: + def test_invalid_objective_format(self): + with pytest.raises(Exception, match="objective must be"): + RFOptuna( + configs=[{"lr": Range(0.0, 1.0)}], + objective="bad_format", + ) + + def test_invalid_sampler(self): + rfopt = RFOptuna( + configs=[{"lr": Range(0.0, 1.0)}], + objective="minimize:loss", + sampler="nonexistent", + ) + with pytest.raises(Exception, match="Unknown sampler"): + rfopt.get_runs(seed=42) + + def test_invalid_pruner(self): + rfopt = RFOptuna( + configs=[{"lr": Range(0.0, 1.0)}], + objective="minimize:loss", + pruner="nonexistent", + ) + with pytest.raises(Exception, match="Unknown pruner"): + rfopt.get_runs(seed=42) + + def test_get_runs_evals_mode(self): + rfopt = RFOptuna( + configs=[{"pipeline": "fake", "temperature": Range(0.0, 2.0)}], + trainer_type=None, + n_initial=5, + budget=10, + objective="maximize:accuracy", + sampler="random", + pruner=None, + seed=42, + ) + runs = rfopt.get_runs(seed=42) + assert len(runs) == 5 + for run in runs: + assert "pipeline" in run + assert isinstance(run["temperature"], float) + assert 0.0 <= run["temperature"] <= 2.0 + + def test_get_runs_no_search_space_raises(self): + rfopt = RFOptuna( + configs=[{"fixed_param": 42}], + objective="minimize:loss", + ) + with pytest.raises(Exception, match="No Range or List"): + rfopt.get_runs(seed=42) + + def test_get_callback_returns_shard_for_evals(self): + rfopt = RFOptuna( + configs=[{"pipeline": "fake", "temp": Range(0.0, 2.0)}], + trainer_type=None, + n_initial=3, + budget=6, + objective="maximize:acc", + sampler="random", + pruner=None, + ) + rfopt.get_runs(seed=42) + cb = rfopt.get_callback() + assert isinstance(cb, OptunaShardCallback) + + def test_get_callback_returns_none_before_get_runs(self): + rfopt = RFOptuna( + configs=[{"pipeline": "fake", "temp": Range(0.0, 2.0)}], + objective="maximize:acc", + ) + assert rfopt.get_callback() is None + + def test_bind_initial_trials(self): + rfopt = RFOptuna( + configs=[{"pipeline": "fake", "temp": Range(0.0, 2.0)}], + trainer_type=None, + n_initial=3, + budget=6, + objective="maximize:acc", + sampler="random", + pruner=None, + ) + rfopt.get_runs(seed=42) + cb = rfopt.get_callback() + + rfopt.bind_initial_trials([100, 200, 300]) + assert 100 in cb._trials + assert 200 in cb._trials + assert 300 in cb._trials + + def test_budget_clamps_to_n_initial(self): + rfopt = RFOptuna( + configs=[{"x": Range(0.0, 1.0)}], + n_initial=10, + budget=5, + objective="minimize:loss", + ) + assert rfopt.budget == 10 + + def test_deterministic_with_seed(self): + def make_runs(seed): + rfopt = RFOptuna( + configs=[{"x": Range(0.0, 10.0), "y": List([1, 2, 3])}], + n_initial=5, + budget=5, + objective="minimize:loss", + sampler="tpe", + pruner=None, + seed=seed, + ) + return rfopt.get_runs(seed=seed) + + runs_a = make_runs(42) + runs_b = make_runs(42) + for a, b in zip(runs_a, runs_b, strict=True): + assert a["x"] == b["x"] + assert a["y"] == b["y"] + + def test_base_class_get_callback_returns_none(self): + from rapidfireai.automl import RFGridSearch + gs = RFGridSearch( + configs=[{"pipeline": "fake"}], + trainer_type=None, + ) + assert gs.get_callback() is None + + def test_invalid_granularity(self): + with pytest.raises(Exception, match="granularity"): + RFOptuna( + configs=[{"lr": 
Range(0.0, 1.0)}], + objective="minimize:loss", + granularity="step", + ) + + def test_granularity_epoch_stored_on_rfoptuna(self): + rfopt = RFOptuna( + configs=[{"pipeline": "fake", "temp": Range(0.0, 2.0)}], + trainer_type=None, + n_initial=2, + budget=4, + objective="minimize:eval_loss", + sampler="random", + pruner=None, + seed=42, + granularity="epoch", + ) + assert rfopt._granularity == "epoch" + + def test_granularity_defaults_to_chunk_on_rfoptuna(self): + rfopt = RFOptuna( + configs=[{"pipeline": "fake", "temp": Range(0.0, 2.0)}], + trainer_type=None, + n_initial=2, + budget=4, + objective="minimize:eval_loss", + sampler="random", + pruner=None, + seed=42, + ) + assert rfopt._granularity == "chunk" + + +# --------------------------------------------------------------------------- +# Multi-template support +# --------------------------------------------------------------------------- + + +class TestMultiTemplate: + """Verify RFOptuna correctly handles multiple config templates.""" + + def test_sample_from_trial_multi_single_template(self): + """Single template: behaves identically to _sample_from_trial.""" + template = {"lr": Range(0.0, 1.0), "fixed": "hello"} + space = _extract_search_space(template) + study = optuna.create_study() + trial = study.ask() + result = _sample_from_trial_multi(trial, [template], [space]) + assert isinstance(result["lr"], float) + assert result["fixed"] == "hello" + # No _config_template_idx categorical when single template + assert "_config_template_idx" not in trial.params + + def test_sample_from_trial_multi_two_templates(self): + """Two templates: Optuna picks one via categorical, samples its space.""" + t0 = {"lr": Range(0.0, 0.1), "model": "small"} + t1 = {"dropout": Range(0.0, 0.5), "model": "large"} + spaces = [_extract_search_space(t0), _extract_search_space(t1)] + + study = optuna.create_study() + trial = study.ask() + result = _sample_from_trial_multi(trial, [t0, t1], spaces) + + assert "_config_template_idx" in trial.params + tidx = trial.params["_config_template_idx"] + assert tidx in (0, 1) + + if tidx == 0: + assert isinstance(result["lr"], float) + assert result["model"] == "small" + else: + assert isinstance(result["dropout"], float) + assert result["model"] == "large" + + def test_get_runs_evals_multi_template(self): + t0 = {"pipeline": "pipe_a", "temperature": Range(0.0, 1.0)} + t1 = {"pipeline": "pipe_b", "top_k": Range(1, 50)} + + rfopt = RFOptuna( + configs=[t0, t1], + trainer_type=None, + n_initial=6, + budget=10, + objective="maximize:accuracy", + sampler="random", + pruner=None, + seed=42, + ) + runs = rfopt.get_runs(seed=42) + assert len(runs) == 6 + for run in runs: + assert "pipeline" in run + + def test_get_runs_evals_list_wrapper(self): + """List([t1, t2]) syntax works the same as [t1, t2].""" + t0 = {"pipeline": "a", "x": Range(0.0, 1.0)} + t1 = {"pipeline": "b", "y": Range(0.0, 1.0)} + + rfopt = RFOptuna( + configs=List([t0, t1]), + trainer_type=None, + n_initial=4, + budget=8, + objective="maximize:score", + sampler="random", + pruner=None, + seed=7, + ) + runs = rfopt.get_runs(seed=7) + assert len(runs) == 4 + + def test_callback_replacement_multi_template(self): + """Replacement configs can come from any template.""" + t0 = {"pipeline": "a", "temperature": Range(0.0, 2.0)} + t1 = {"pipeline": "b", "top_k": Range(1, 50)} + spaces = [_extract_search_space(t0), _extract_search_space(t1)] + + study = optuna.create_study( + direction="maximize", + pruner=optuna.pruners.NopPruner(), + ) + cb = OptunaShardCallback( + 
study=study, + search_spaces=spaces, + config_templates=[t0, t1], + budget=5, + objective_metric="accuracy", + ) + cb._spawned = 2 + replacement = cb._maybe_suggest_replacement() + assert replacement is not None + assert isinstance(replacement, dict) + assert cb._spawned == 3 diff --git a/tutorial_notebooks/fine-tuning/rf-tutorial-optuna-sft-chatqa-tiny.ipynb b/tutorial_notebooks/fine-tuning/rf-tutorial-optuna-sft-chatqa-tiny.ipynb new file mode 100644 index 00000000..d8040f3b --- /dev/null +++ b/tutorial_notebooks/fine-tuning/rf-tutorial-optuna-sft-chatqa-tiny.ipynb @@ -0,0 +1,1029 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "\n", + "\n", + "
\n", + "Join Discord if you need help + ⭐ Star us on GitHub ⭐\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## RapidFire AI + Optuna: Adaptive SFT Hyperparameter Search\n", + "\n", + "This tutorial shows how to use **RFOptuna** for Bayesian hyperparameter optimization\n", + "integrated into RapidFire's chunk-based training loop.\n", + "\n", + "**Key difference from RFGridSearch / RFRandomSearch:**\n", + "- Grid/Random search decide all configs upfront\n", + "- **RFOptuna** uses Optuna's TPE sampler to suggest initial configs, then **prunes underperforming runs after each chunk** and replaces them with smarter suggestions\n", + "\n", + "This notebook uses TinyLlama-1.1B-Chat with chat-formatted SFT, optimizing for **ROUGE-L** as the single objective metric." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Enable Metric Loggers" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"RF_MLFLOW_ENABLED\"] = \"true\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports\n", + "\n", + "Note: `RFOptuna` requires the `optuna` package. Install with:\n", + "```bash\n", + "pip install optuna\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from rapidfireai import Experiment\n", + "from rapidfireai.automl import (\n", + " List,\n", + " Range,\n", + " RFOptuna,\n", + " RFModelConfig,\n", + " RFLoraConfig,\n", + " RFSFTConfig,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset(\"bitext/Bitext-customer-support-llm-chatbot-training-dataset\")\n", + "\n", + "train_dataset = dataset[\"train\"].select(range(128))\n", + "eval_dataset = dataset[\"train\"].select(range(128, 152))\n", + "train_dataset = train_dataset.shuffle(seed=42)\n", + "eval_dataset = eval_dataset.shuffle(seed=42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Data Processing Function" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def sample_formatting_function(row):\n", + " \"\"\"Function to preprocess each example from dataset\"\"\"\n", + " SYSTEM_PROMPT = \"You are a helpful and friendly customer support assistant. 
Please answer the user's query to the best of your ability.\"\n", + " return {\n", + " \"prompt\": [\n", + " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": row[\"instruction\"]},\n", + " ],\n", + " \"completion\": [\n", + " {\"role\": \"assistant\", \"content\": row[\"response\"]}\n", + " ]\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Custom Eval Metrics Function" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def sample_compute_metrics(eval_preds): \n", + " \"\"\"Optional function to compute eval metrics based on predictions and labels\"\"\"\n", + " predictions, labels = eval_preds\n", + "\n", + " import evaluate\n", + " rouge = evaluate.load(\"rouge\")\n", + " bleu = evaluate.load(\"bleu\")\n", + "\n", + " rouge_output = rouge.compute(predictions=predictions, references=labels, use_stemmer=True)\n", + " rouge_l = rouge_output[\"rougeL\"]\n", + " bleu_output = bleu.compute(predictions=predictions, references=labels)\n", + " bleu_score = bleu_output[\"bleu\"]\n", + "\n", + " return {\n", + " \"rougeL\": round(rouge_l, 4),\n", + " \"bleu\": round(bleu_score, 4),\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize Experiment" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The previously running experiment exp-optuna-chatqa-tiny1_15 was forcibly ended. Created a new experiment 'exp-optuna-chatqa-tiny1_16' with Experiment ID: 18 and Metric Experiment ID: 29 at /home/ubuntu/rapidfireai/rapidfire_experiments/exp-optuna-chatqa-tiny1_16\n" + ] + } + ], + "source": [ + "experiment = Experiment(experiment_name=\"exp-optuna-chatqa-tiny1\", mode=\"fit\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Config Template with Search Space\n", + "\n", + "Instead of listing every combination (GridSearch) or sampling blindly (RandomSearch),\n", + "we define a **search space** using `Range(...)` and `List([...])` inside a single\n", + "`RFModelConfig`. Optuna's TPE sampler will intelligently explore this space.\n", + "\n", + "**What Optuna controls:**\n", + "- `lora_alpha`: Sampled from a categorical list\n", + "- `target_modules`: Sampled from a categorical list of module sets\n", + "\n", + "**What stays fixed:**\n", + "- Model (`TinyLlama-1.1B-Chat`), learning rate, LoRA rank, batch size, precision, etc." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "config_template = RFModelConfig(\n", + " model_name=\"TinyLlama/TinyLlama-1.1B-Chat-v1.0\",\n", + " peft_config=RFLoraConfig(\n", + " r=List([8, 32]),\n", + " lora_alpha=List([16, 32, 64, 128]),\n", + " lora_dropout=0.1,\n", + " target_modules=List([\n", + " [\"q_proj\", \"v_proj\"],\n", + " [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n", + " ]),\n", + " bias=\"none\",\n", + " ),\n", + " training_args=RFSFTConfig(\n", + " learning_rate=List([1e-3, 1e-4]),\n", + " lr_scheduler_type=\"linear\",\n", + " per_device_train_batch_size=4,\n", + " per_device_eval_batch_size=4,\n", + " max_steps=128,\n", + " gradient_accumulation_steps=1,\n", + " logging_steps=2,\n", + " eval_strategy=\"steps\",\n", + " eval_steps=4,\n", + " bf16=True,\n", + " ),\n", + " model_type=\"causal_lm\",\n", + " model_kwargs={\"device_map\": \"auto\", \"torch_dtype\": \"auto\", \"use_cache\": False},\n", + " formatting_func=sample_formatting_function,\n", + " compute_metrics=sample_compute_metrics,\n", + " generation_config={\n", + " \"max_new_tokens\": 256,\n", + " \"temperature\": 0.8,\n", + " \"top_p\": 0.9,\n", + " \"top_k\": 30,\n", + " \"repetition_penalty\": 1.05,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create RFOptuna Config Group\n", + "\n", + "Key parameters:\n", + "- **`n_initial=4`**: Start with 4 configs (sampled by TPE)\n", + "- **`budget=6`**: Allow up to 6 total configs (including replacements for pruned runs)\n", + "- **`objective=\"maximize:rougeL\"`**: Maximize ROUGE-L (from `sample_compute_metrics`)\n", + "- **`sampler=\"tpe\"`**: Tree-structured Parzen Estimator (Bayesian)\n", + "- **`pruner=\"median\"`**: Prune runs whose intermediate ROUGE-L falls below the median of peers" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "config_group = RFOptuna(\n", + " configs=[config_template],\n", + " trainer_type=\"SFT\",\n", + " n_initial=4,\n", + " budget=6,\n", + " objective=\"maximize:rougeL\",\n", + " sampler=\"tpe\",\n", + " pruner=\"median\",\n", + " seed=42,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Model Creation Function" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def sample_create_model(model_config):\n", + " \"\"\"Function to create model object for any given config; must return tuple of (model, tokenizer)\"\"\"\n", + " from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForMaskedLM\n", + "\n", + " model_name = model_config[\"model_name\"]\n", + " model_type = model_config[\"model_type\"]\n", + " model_kwargs = model_config[\"model_kwargs\"]\n", + "\n", + " if model_type == \"causal_lm\":\n", + " model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)\n", + " elif model_type == \"seq2seq_lm\":\n", + " model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)\n", + " elif model_type == \"masked_lm\":\n", + " model = AutoModelForMaskedLM.from_pretrained(model_name, **model_kwargs)\n", + " else:\n", + " model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)\n", + "\n", + " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "\n", + " return (model, tokenizer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run 
Optuna-Powered Training\n", + "\n", + "Behind the scenes:\n", + "1. `RFOptuna.get_runs()` asks Optuna's TPE sampler for 4 initial configs\n", + "2. RapidFire trains all 4 concurrently using chunk-based scheduling\n", + "3. After each chunk completes, the Optuna callback:\n", + " - Resolves `rougeL` for each run\n", + " - Uses median pruning: a run is pruned if its ROUGE-L falls below the median of peers\n", + " - If pruned, suggests a replacement config (up to budget of 6)\n", + "4. RapidFire stops pruned runs and starts replacements automatically" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 04-28 21:53:45 [__init__.py:216] Automatically detected platform cuda.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m[I 2026-04-28 21:53:46,154]\u001b[0m A new study created in memory with name: no-name-0eaba987-1210-472e-b9b5-a5232c9f729a\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Started 2 worker processes successfully\n", + "Created workers\n", + "INFO 04-28 21:53:55 [__init__.py:216] Automatically detected platform cuda.\n", + "INFO 04-28 21:53:55 [__init__.py:216] Automatically detected platform cuda.\n" + ] + } + ], + "source": [ + "experiment.run_fit(\n", + " config_group,\n", + " sample_create_model,\n", + " train_dataset,\n", + " eval_dataset,\n", + " num_chunks=4,\n", + " seed=42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inspect Optuna Study Results\n", + "\n", + "After training, the Optuna study object is accessible on the `RFOptuna` instance.\n", + "`study.best_trial` returns the trial with the highest ROUGE-L score." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trials: 6 total, 1 completed, 5 pruned\n", + "Best trial: #2 with rougeL = 0.5657\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
trialstatebestrlora_alphatarget_moduleslearning_raterougeL
00PRUNED3216['q_proj', 'k_proj', 'v_proj', 'o_proj']0.00010.2545
11PRUNED3216['q_proj', 'k_proj', 'v_proj', 'o_proj']0.00100.4768
22COMPLETE8128['q_proj', 'k_proj', 'v_proj', 'o_proj']0.00100.5657
33PRUNED864['q_proj', 'v_proj']0.00100.4838
44PRUNED3232['q_proj', 'k_proj', 'v_proj', 'o_proj']0.00100.5288
55PRUNED816['q_proj', 'k_proj', 'v_proj', 'o_proj']0.00010.2569
\n", + "
" + ], + "text/plain": [ + " trial state best r lora_alpha \\\n", + "0 0 PRUNED 32 16 \n", + "1 1 PRUNED 32 16 \n", + "2 2 COMPLETE ✓ 8 128 \n", + "3 3 PRUNED 8 64 \n", + "4 4 PRUNED 32 32 \n", + "5 5 PRUNED 8 16 \n", + "\n", + " target_modules learning_rate rougeL \n", + "0 ['q_proj', 'k_proj', 'v_proj', 'o_proj'] 0.0001 0.2545 \n", + "1 ['q_proj', 'k_proj', 'v_proj', 'o_proj'] 0.0010 0.4768 \n", + "2 ['q_proj', 'k_proj', 'v_proj', 'o_proj'] 0.0010 0.5657 \n", + "3 ['q_proj', 'v_proj'] 0.0010 0.4838 \n", + "4 ['q_proj', 'k_proj', 'v_proj', 'o_proj'] 0.0010 0.5288 \n", + "5 ['q_proj', 'k_proj', 'v_proj', 'o_proj'] 0.0001 0.2569 " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "from optuna.trial import TrialState\n", + "\n", + "study = config_group._study\n", + "\n", + "completed = [t for t in study.trials if t.state == TrialState.COMPLETE]\n", + "pruned = [t for t in study.trials if t.state == TrialState.PRUNED]\n", + "\n", + "print(f\"Trials: {len(study.trials)} total, {len(completed)} completed, {len(pruned)} pruned\")\n", + "\n", + "try:\n", + " best = study.best_trial\n", + " print(f\"Best trial: #{best.number} with rougeL = {best.value:.4f}\")\n", + "except ValueError as exc:\n", + " print(f\"No best trial available yet ({exc}).\")\n", + "\n", + "trials_df = pd.DataFrame([\n", + " {\n", + " \"trial\": t.number,\n", + " \"state\": t.state.name,\n", + " \"best\": \"✓\" if t.number == study.best_trial.number else \"\",\n", + " **{k.split(\".\")[-1]: v for k, v in t.params.items()},\n", + " \"rougeL\": f\"{t.value:.4f}\" if t.value is not None else \"—\",\n", + " }\n", + " for t in study.trials\n", + "])\n", + "trials_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get Results via RapidFire" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
run_idsteplossgrad_normlearning_ratenum_tokensmean_token_accuracychunk numbernum_epochs_completedeval_loss...eval_steps_per_secondeval_num_tokenseval_mean_token_accuracyrougeLbleutrain_runtimetrain_samples_per_secondtrain_steps_per_secondtotal_flostrain_loss
0121.46010.3213780.0000992777.00.6477570.00.0NaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
1141.42390.3435280.0000985156.00.6468820.00.01.362228...9.3825156.00.6604360.24350.0801NaNNaNNaNNaNNaN
2161.37630.3396700.0000967839.00.6600240.00.0NaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
3181.33580.3440420.00009510420.00.6725560.00.01.270336...9.25610420.00.6781660.25450.0864132.35580.2420.0607.358093e+131.399039
4221.39590.2663550.0009922777.00.6573890.00.0NaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
..................................................................
875160.62420.4702280.00088310770.00.8238531.00.00.511950...9.54510770.00.8479430.52880.4742144.37330.2220.0557.445749e+130.588614
88621.45850.6645690.0000992777.00.6502700.00.0NaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
89641.41880.7135180.0000985156.00.6468820.00.01.359361...9.3995156.00.6600330.24320.0777NaNNaNNaNNaNNaN
90661.37450.6890090.0000967839.00.6617810.00.0NaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
91681.33570.6988270.00009510420.00.6743660.00.01.270993...9.42710420.00.6767000.25690.0883134.75700.2370.0597.310439e+131.396876
\n", + "

92 rows × 22 columns

\n", + "
" + ], + "text/plain": [ + " run_id step loss grad_norm learning_rate num_tokens \\\n", + "0 1 2 1.4601 0.321378 0.000099 2777.0 \n", + "1 1 4 1.4239 0.343528 0.000098 5156.0 \n", + "2 1 6 1.3763 0.339670 0.000096 7839.0 \n", + "3 1 8 1.3358 0.344042 0.000095 10420.0 \n", + "4 2 2 1.3959 0.266355 0.000992 2777.0 \n", + ".. ... ... ... ... ... ... \n", + "87 5 16 0.6242 0.470228 0.000883 10770.0 \n", + "88 6 2 1.4585 0.664569 0.000099 2777.0 \n", + "89 6 4 1.4188 0.713518 0.000098 5156.0 \n", + "90 6 6 1.3745 0.689009 0.000096 7839.0 \n", + "91 6 8 1.3357 0.698827 0.000095 10420.0 \n", + "\n", + " mean_token_accuracy chunk number num_epochs_completed eval_loss ... \\\n", + "0 0.647757 0.0 0.0 NaN ... \n", + "1 0.646882 0.0 0.0 1.362228 ... \n", + "2 0.660024 0.0 0.0 NaN ... \n", + "3 0.672556 0.0 0.0 1.270336 ... \n", + "4 0.657389 0.0 0.0 NaN ... \n", + ".. ... ... ... ... ... \n", + "87 0.823853 1.0 0.0 0.511950 ... \n", + "88 0.650270 0.0 0.0 NaN ... \n", + "89 0.646882 0.0 0.0 1.359361 ... \n", + "90 0.661781 0.0 0.0 NaN ... \n", + "91 0.674366 0.0 0.0 1.270993 ... \n", + "\n", + " eval_steps_per_second eval_num_tokens eval_mean_token_accuracy rougeL \\\n", + "0 NaN NaN NaN NaN \n", + "1 9.382 5156.0 0.660436 0.2435 \n", + "2 NaN NaN NaN NaN \n", + "3 9.256 10420.0 0.678166 0.2545 \n", + "4 NaN NaN NaN NaN \n", + ".. ... ... ... ... \n", + "87 9.545 10770.0 0.847943 0.5288 \n", + "88 NaN NaN NaN NaN \n", + "89 9.399 5156.0 0.660033 0.2432 \n", + "90 NaN NaN NaN NaN \n", + "91 9.427 10420.0 0.676700 0.2569 \n", + "\n", + " bleu train_runtime train_samples_per_second train_steps_per_second \\\n", + "0 NaN NaN NaN NaN \n", + "1 0.0801 NaN NaN NaN \n", + "2 NaN NaN NaN NaN \n", + "3 0.0864 132.3558 0.242 0.060 \n", + "4 NaN NaN NaN NaN \n", + ".. ... ... ... ... \n", + "87 0.4742 144.3733 0.222 0.055 \n", + "88 NaN NaN NaN NaN \n", + "89 0.0777 NaN NaN NaN \n", + "90 NaN NaN NaN NaN \n", + "91 0.0883 134.7570 0.237 0.059 \n", + "\n", + " total_flos train_loss \n", + "0 NaN NaN \n", + "1 NaN NaN \n", + "2 NaN NaN \n", + "3 7.358093e+13 1.399039 \n", + "4 NaN NaN \n", + ".. ... ... 
\n", + "87 7.445749e+13 0.588614 \n", + "88 NaN NaN \n", + "89 NaN NaN \n", + "90 NaN NaN \n", + "91 7.310439e+13 1.396876 \n", + "\n", + "[92 rows x 22 columns]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = experiment.get_results()\n", + "results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### End Experiment" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Experiment exp-optuna-chatqa-tiny1_16 ended\n", + "Workers stopped\n" + ] + } + ], + "source": [ + "experiment.end()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "### Comparison: RFOptuna vs RFGridSearch vs RFRandomSearch\n", + "\n", + "| Feature | RFGridSearch | RFRandomSearch | RFOptuna |\n", + "|---|---|---|---|\n", + "| Config selection | All combos upfront | Random sample upfront | Bayesian (TPE) — learns from results |\n", + "| Pruning | Manual via IC Ops | Manual via IC Ops | Automatic (Median / Hyperband pruner) |\n", + "| Replacement | Manual clone-modify | Manual clone-modify | Automatic — new suggestions within budget |\n", + "| Search space | `List([...])` | `List([...])`, `Range(...)` | `List([...])`, `Range(...)` |\n", + "| Best for | Small discrete spaces | Large spaces, no adaptation | Large spaces, adaptive exploration |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "\n", + "\n", + "
\n", + "Thanks for trying RapidFire AI! ⭐ Star us on GitHub ⭐\n", + "
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tutorial_notebooks/rag-contexteng/rf-tutorial-optuna-rag-fiqa.ipynb b/tutorial_notebooks/rag-contexteng/rf-tutorial-optuna-rag-fiqa.ipynb new file mode 100644 index 00000000..8b2ef332 --- /dev/null +++ b/tutorial_notebooks/rag-contexteng/rf-tutorial-optuna-rag-fiqa.ipynb @@ -0,0 +1,616 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "\n", + "\n", + "
\n", + "Join Discord if you need help + ⭐ Star us on GitHub ⭐\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## RapidFire AI + Optuna: Adaptive RAG Hyperparameter Search on FiQA\n", + "\n", + "This tutorial shows how to use **RFOptuna** for Bayesian hyperparameter optimization\n", + "of a RAG (Retrieval-Augmented Generation) pipeline on the **FiQA** financial Q&A dataset.\n", + "\n", + "**Key difference from RFGridSearch / RFRandomSearch:**\n", + "- Grid/Random search decide all configs upfront\n", + "- **RFOptuna** uses Optuna's TPE sampler to suggest initial configs, then **prunes underperforming pipelines after each shard** and replaces them with smarter suggestions\n", + "\n", + "**What this notebook demonstrates:**\n", + "- Defining a RAG search space with `List(...)` over retrieval parameters (chunk size, reranker top-n)\n", + "- Using `RFOptuna` in **evals mode** to adaptively search the space\n", + "- Optuna-driven pruning of weak retrieval configs mid-evaluation\n", + "- Inspecting the Optuna study to find the best RAG configuration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Enable Metric Loggers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"RF_MLFLOW_ENABLED\"] = \"true\"\n", + "os.environ[\"RF_TENSORBOARD_ENABLED\"] = \"true\"\n", + "os.environ[\"RF_TRACKIO_ENABLED\"] = \"true\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports\n", + "\n", + "Note: `RFOptuna` requires the `optuna` package. Install with:\n", + "```bash\n", + "pip install optuna\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from rapidfireai import Experiment\n", + "from rapidfireai.automl import (\n", + " List,\n", + " Range,\n", + " RFOptuna,\n", + " RFLangChainRagSpec,\n", + " RFvLLMModelConfig,\n", + " RFPromptManager,\n", + ")\n", + "import math\n", + "import pandas as pd\n", + "from pathlib import Path\n", + "from typing import List as listtype, Dict, Any" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load FiQA Dataset and Relevance Labels" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset_dir = Path(\"datasets\")\n", + "\n", + "fiqa_dataset = load_dataset(\n", + " \"json\", data_files=str(dataset_dir / \"fiqa\" / \"queries.jsonl\"), split=\"train\"\n", + ").select(range(256))\n", + "fiqa_dataset = fiqa_dataset.rename_columns({\"text\": \"query\", \"_id\": \"query_id\"})\n", + "\n", + "qrels = pd.read_csv(str(dataset_dir / \"fiqa\" / \"qrels.tsv\"), sep=\"\\t\")\n", + "qrels = qrels.rename(\n", + " columns={\"query-id\": \"query_id\", \"corpus-id\": \"corpus_id\", \"score\": \"relevance\"}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create Experiment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "experiment = Experiment(experiment_name=\"exp-optuna-rag-fiqa\", mode=\"evals\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define RAG Pipeline with Search Space\n", + "\n", + "Instead of listing every combination (GridSearch) or sampling blindly (RandomSearch),\n", + "we define a **search space** using `List(...)` and `Range(...)` inside the RAG spec.\n", + "Optuna's TPE sampler will 
intelligently explore this space.\n", + "\n", + "**What Optuna controls:**\n", + "- `text_splitter`: Chunk size (128 vs 256 tokens) — `List(...)` (discrete objects)\n", + "- `search_cfg.k`: Initial retrieval candidates (5–20, step 5) — `Range(...)` (stepped int)\n", + "- `reranker_cfg.top_n`: Documents after reranking (2–10, step 2) — `Range(...)` (stepped int)\n", + "\n", + "**What stays fixed:**\n", + "- Embedding model (`all-MiniLM-L6-v2`), FAISS vector store, similarity search type" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.document_loaders import DirectoryLoader, JSONLoader\n", + "from langchain_text_splitters import RecursiveCharacterTextSplitter\n", + "from langchain_huggingface import HuggingFaceEmbeddings\n", + "from langchain_classic.retrievers.document_compressors import CrossEncoderReranker\n", + "from langchain_community.cross_encoders import HuggingFaceCrossEncoder\n", + "\n", + "batch_size = 128\n", + "\n", + "rag_spec = RFLangChainRagSpec(\n", + " document_loader=DirectoryLoader(\n", + " path=str(dataset_dir / \"fiqa\"),\n", + " glob=\"corpus.jsonl\",\n", + " loader_cls=JSONLoader,\n", + " loader_kwargs={\n", + " \"jq_schema\": \".\",\n", + " \"content_key\": \"text\",\n", + " \"metadata_func\": lambda record, metadata: {\n", + " \"corpus_id\": int(record.get(\"_id\"))\n", + " },\n", + " \"json_lines\": True,\n", + " \"text_content\": False,\n", + " },\n", + " sample_seed=42,\n", + " ),\n", + " text_splitter=List([\n", + " RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n", + " encoding_name=\"gpt2\", chunk_size=256, chunk_overlap=32\n", + " ),\n", + " RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n", + " encoding_name=\"gpt2\", chunk_size=128, chunk_overlap=32\n", + " ),\n", + " ]),\n", + " embedding_cfg={\n", + " \"class\": HuggingFaceEmbeddings,\n", + " \"model_name\": \"sentence-transformers/all-MiniLM-L6-v2\",\n", + " \"model_kwargs\": {\"device\": \"cuda:0\"},\n", + " \"encode_kwargs\": {\"normalize_embeddings\": True, \"batch_size\": batch_size},\n", + " },\n", + " vector_store_cfg={\n", + " \"type\": \"faiss\",\n", + " },\n", + " search_cfg={\n", + " \"type\": \"similarity\",\n", + " \"k\": Range(5, 20, step=5),\n", + " },\n", + " reranker_cfg={\n", + " \"class\": CrossEncoderReranker,\n", + " \"model_name\": \"cross-encoder/ms-marco-MiniLM-L6-v2\",\n", + " \"model_kwargs\": {\"device\": \"cuda:0\"},\n", + " \"top_n\": List([2,5])\n", + " },\n", + " enable_gpu_search=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Data Processing Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def sample_preprocess_fn(\n", + " batch: Dict[str, listtype], rag: RFLangChainRagSpec, prompt_manager: RFPromptManager\n", + ") -> Dict[str, listtype]:\n", + " \"\"\"Prepare the final inputs given to the generator model.\"\"\"\n", + " INSTRUCTIONS = (\n", + " \"Utilize your financial knowledge, give your answer or opinion \"\n", + " \"to the input question or subject matter.\"\n", + " )\n", + "\n", + " all_context = rag.get_context(batch_queries=batch[\"query\"], serialize=False)\n", + "\n", + " retrieved_documents = [\n", + " [doc.metadata[\"corpus_id\"] for doc in docs] for docs in all_context\n", + " ]\n", + "\n", + " serialized_context = rag.serialize_documents(all_context)\n", + " batch[\"query_id\"] = [int(query_id) for query_id in 
batch[\"query_id\"]]\n", + "\n", + " return {\n", + " \"prompts\": [\n", + " [\n", + " {\"role\": \"system\", \"content\": INSTRUCTIONS},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": (\n", + " f\"Here is some relevant context:\\n{context}. \"\n", + " f\"\\nNow answer the following question using the \"\n", + " f\"context provided earlier:\\n{question}\"\n", + " ),\n", + " },\n", + " ]\n", + " for question, context in zip(batch[\"query\"], serialized_context)\n", + " ],\n", + " \"retrieved_documents\": retrieved_documents,\n", + " **batch,\n", + " }\n", + "\n", + "\n", + "def sample_postprocess_fn(batch: Dict[str, listtype]) -> Dict[str, listtype]:\n", + " \"\"\"Attach ground truth documents for metric computation.\"\"\"\n", + " batch[\"ground_truth_documents\"] = [\n", + " qrels[qrels[\"query_id\"] == query_id][\"corpus_id\"].tolist()\n", + " for query_id in batch[\"query_id\"]\n", + " ]\n", + " return batch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Custom RAG Eval Metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_ndcg_at_k(retrieved_docs: set, expected_docs: set, k=5):\n", + " \"\"\"Compute NDCG@k for a single query.\"\"\"\n", + " relevance = [1 if doc in expected_docs else 0 for doc in list(retrieved_docs)[:k]]\n", + " dcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(relevance))\n", + " ideal_length = min(k, len(expected_docs))\n", + " ideal_relevance = [3] * ideal_length + [0] * (k - ideal_length)\n", + " idcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(ideal_relevance))\n", + " return dcg / idcg if idcg > 0 else 0.0\n", + "\n", + "\n", + "def compute_rr(retrieved_docs: set, expected_docs: set):\n", + " \"\"\"Compute Reciprocal Rank (RR) for a single query.\"\"\"\n", + " for i, retrieved_doc in enumerate(retrieved_docs):\n", + " if retrieved_doc in expected_docs:\n", + " return 1 / (i + 1)\n", + " return 0\n", + "\n", + "\n", + "def sample_compute_metrics_fn(batch: Dict[str, listtype]) -> Dict[str, Dict[str, Any]]:\n", + " \"\"\"Compute retrieval metrics per batch.\"\"\"\n", + " precisions, recalls, f1_scores, ndcgs, rrs = [], [], [], [], []\n", + " total_queries = len(batch[\"query\"])\n", + "\n", + " for pred, gt in zip(batch[\"retrieved_documents\"], batch[\"ground_truth_documents\"]):\n", + " expected_set = set(gt)\n", + " retrieved_set = set(pred)\n", + "\n", + " true_positives = len(expected_set.intersection(retrieved_set))\n", + " precision = true_positives / len(retrieved_set) if len(retrieved_set) > 0 else 0\n", + " recall = true_positives / len(expected_set) if len(expected_set) > 0 else 0\n", + " f1 = (\n", + " 2 * precision * recall / (precision + recall)\n", + " if (precision + recall) > 0\n", + " else 0\n", + " )\n", + "\n", + " precisions.append(precision)\n", + " recalls.append(recall)\n", + " f1_scores.append(f1)\n", + " ndcgs.append(compute_ndcg_at_k(retrieved_set, expected_set, k=5))\n", + " rrs.append(compute_rr(retrieved_set, expected_set))\n", + "\n", + " return {\n", + " \"Total\": {\"value\": total_queries},\n", + " \"Precision\": {\"value\": sum(precisions) / total_queries},\n", + " \"Recall\": {\"value\": sum(recalls) / total_queries},\n", + " \"F1 Score\": {\"value\": sum(f1_scores) / total_queries},\n", + " \"NDCG@5\": {\"value\": sum(ndcgs) / total_queries},\n", + " \"MRR\": {\"value\": sum(rrs) / total_queries},\n", + " }\n", + "\n", + "\n", + "def sample_accumulate_metrics_fn(\n", + " 
aggregated_metrics: Dict[str, listtype],\n", + ") -> Dict[str, Dict[str, Any]]:\n", + " \"\"\"Accumulate metrics across all batches.\"\"\"\n", + " num_queries_per_batch = [m.get(\"value\", 0) for m in aggregated_metrics.get(\"Total\", [])]\n", + " total_queries = sum(num_queries_per_batch)\n", + " algebraic_metrics = [\"Precision\", \"Recall\", \"F1 Score\", \"NDCG@5\", \"MRR\"]\n", + "\n", + " return {\n", + " \"Total\": {\"value\": total_queries},\n", + " **{\n", + " metric: {\n", + " \"value\": sum(\n", + " m[\"value\"] * queries\n", + " for m, queries in zip(\n", + " aggregated_metrics[metric], num_queries_per_batch\n", + " )\n", + " )\n", + " / total_queries,\n", + " \"is_algebraic\": True,\n", + " \"value_range\": (0, 1),\n", + " }\n", + " for metric in algebraic_metrics\n", + " },\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define vLLM Generator Config with RAG\n", + "\n", + "We use a single generator model (Qwen2.5-0.5B-Instruct) and let Optuna\n", + "search over the **retrieval parameters** defined in the RAG spec above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vllm_config = RFvLLMModelConfig(\n", + " model_config={\n", + " \"model\": \"Qwen/Qwen2.5-0.5B-Instruct\",\n", + " \"dtype\": \"half\",\n", + " \"gpu_memory_utilization\": 0.7,\n", + " \"tensor_parallel_size\": 1,\n", + " \"distributed_executor_backend\": \"mp\",\n", + " \"enable_chunked_prefill\": False,\n", + " \"enable_prefix_caching\": True,\n", + " \"max_model_len\": 4096,\n", + " \"disable_log_stats\": True,\n", + " },\n", + " sampling_params={\n", + " \"temperature\": 0.8,\n", + " \"top_p\": 0.95,\n", + " \"max_tokens\": 512,\n", + " },\n", + " rag=rag_spec,\n", + " prompt_manager=None,\n", + ")\n", + "\n", + "config_template = {\n", + " \"vllm_config\": vllm_config,\n", + " \"batch_size\": batch_size,\n", + " \"preprocess_fn\": sample_preprocess_fn,\n", + " \"postprocess_fn\": sample_postprocess_fn,\n", + " \"compute_metrics_fn\": sample_compute_metrics_fn,\n", + " \"accumulate_metrics_fn\": sample_accumulate_metrics_fn,\n", + " \"online_strategy_kwargs\": {\n", + " \"strategy_name\": \"normal\",\n", + " \"confidence_level\": 0.95,\n", + " \"use_fpc\": True,\n", + " },\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create RFOptuna Config Group\n", + "\n", + "Key parameters:\n", + "- **`n_initial=4`**: Start with 4 pipeline configs (sampled by TPE)\n", + "- **`budget=6`**: Allow up to 6 total configs (including replacements for pruned pipelines)\n", + "- **`objective=\"maximize:NDCG@5\"`**: Optuna maximizes NDCG@5 to decide pruning\n", + "- **`sampler=\"tpe\"`**: Tree-structured Parzen Estimator (Bayesian)\n", + "- **`pruner=\"median\"`**: Prune pipelines performing worse than the median at each shard\n", + "\n", + "Note: `trainer_type` is `None` for evals mode (RAG). The search space comes from\n", + "`List(...)` and `Range(...)` values embedded in the `rag_spec` and `vllm_config`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config_group = RFOptuna(\n", + " configs=[config_template],\n", + " trainer_type=None,\n", + " n_initial=4,\n", + " budget=6,\n", + " objective=\"maximize:NDCG@5\",\n", + " sampler=\"tpe\",\n", + " pruner=\"median\",\n", + " seed=42,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run Optuna-Powered RAG Evaluation\n", + "\n", + "Behind the scenes:\n", + "1. `RFOptuna.get_runs()` asks Optuna's TPE sampler for 4 initial RAG configs\n", + "2. RapidFire evaluates all 4 concurrently using shard-based scheduling\n", + "3. After each shard completes, the Optuna callback:\n", + " - Reports the pipeline's `NDCG@5` to Optuna\n", + " - Checks if Optuna's median pruner wants to prune the pipeline\n", + " - If pruned, suggests a replacement config (up to budget of 6)\n", + "4. RapidFire stops pruned pipelines and starts replacements automatically" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results = experiment.run_evals(\n", + " config_group=config_group,\n", + " dataset=fiqa_dataset,\n", + " num_shards=4,\n", + " seed=42,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Inspect Optuna Study Results\n", + "\n", + "After evaluation, the Optuna study object is accessible on the `RFOptuna` instance.\n", + "You can inspect which trials were completed, pruned, or failed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from optuna.trial import TrialState\n", + "\n", + "study = config_group._study\n", + "\n", + "completed = [t for t in study.trials if t.state == TrialState.COMPLETE]\n", + "pruned = [t for t in study.trials if t.state == TrialState.PRUNED]\n", + "\n", + "print(f\"Trials: {len(study.trials)} total, {len(completed)} completed, {len(pruned)} pruned\")\n", + "try:\n", + " bt = study.best_trial\n", + " print(f\"Best trial: #{bt.number} — NDCG@5 = {bt.value:.4f}\")\n", + "except (RuntimeError, ValueError) as exc:\n", + " print(f\"No best trial available yet ({exc}).\")\n", + "\n", + "trials_df = pd.DataFrame([\n", + " {\n", + " \"trial\": t.number,\n", + " \"state\": t.state.name,\n", + " **{k.split(\".\")[-1]: v for k, v in t.params.items()},\n", + " \"NDCG@5\": f\"{t.value:.4f}\" if t.value is not None else \"—\",\n", + " }\n", + " for t in study.trials\n", + "])\n", + "trials_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### View Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results_df = pd.DataFrame([\n", + " {\n", + " k: v[\"value\"] if isinstance(v, dict) and \"value\" in v else v\n", + " for k, v in {**metrics_dict, \"run_id\": run_id}.items()\n", + " }\n", + " for run_id, (_, metrics_dict) in results.items()\n", + "])\n", + "\n", + "results_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### End Experiment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "experiment.end()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "### Comparison: RFOptuna vs RFGridSearch vs RFRandomSearch for RAG\n", + "\n", + "| Feature | RFGridSearch | RFRandomSearch | RFOptuna |\n", + "|---|---|---|---|\n", + "| Config selection | All 
combos upfront | Random sample upfront | Bayesian (TPE) — learns from results |\n", + "| Pruning | Manual via IC Ops | Manual via IC Ops | Automatic (Median / Hyperband pruner) |\n", + "| Replacement | Manual clone-modify | Manual clone-modify | Automatic — new suggestions within budget |\n", + "| Search space | `List([...])` | `List([...])`, `Range(...)` | `List([...])`, `Range(...)` |\n", + "| Best for | Small discrete spaces | Large spaces, no adaptation | Large spaces, adaptive exploration |\n", + "| RAG use case | Few chunk/reranker combos | Many retrieval variants | Optimizing retrieval quality adaptively |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "\n", + "\n", + "
\n", + "Thanks for trying RapidFire AI! ⭐ Star us on GitHub ⭐\n", + "
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}