Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ dependencies = [
]

[project.optional-dependencies]
optuna = [
"optuna>=4.0.0",
]
# TODO: update these dependencies to the correct versions
colab_evals = [
# Core ML/AI Framework
Expand Down
23 changes: 23 additions & 0 deletions rapidfireai/automl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@
from .random_search import RFRandomSearch
from .automl_utils import get_flattened_config_leaf, get_runs

# Optuna integration (conditionally available)
try:
    from .optuna_search import RFOptuna

    _OPTUNA_AVAILABLE = True
except ImportError as _optuna_import_error:
    # BUG FIX: Python implicitly deletes the `except ... as` target when the
    # except block ends (PEP 3110), so referencing `_optuna_import_error`
    # later, inside __new__, raised NameError at instantiation time instead
    # of the intended ImportError.  Capture the exception in a module-level
    # name that survives the except block.
    _OPTUNA_IMPORT_ERROR = _optuna_import_error

    class RFOptuna:  # type: ignore[misc]
        """Stub so imports succeed; instantiation explains how to enable Optuna."""

        def __new__(cls, *args, **kwargs):  # noqa: ARG004
            # Raised lazily so that `from rapidfireai.automl import RFOptuna`
            # works even without Optuna installed; only *using* it fails.
            raise ImportError(
                "RFOptuna requires Optuna importable from this Python environment. "
                "Install into the **same interpreter as your Jupyter kernel**, then restart the kernel:\n"
                " python -m pip install optuna\n"
                "Check in a notebook cell: import sys; print(sys.executable)\n"
                "Original error: "
                + str(_OPTUNA_IMPORT_ERROR)
            ) from _OPTUNA_IMPORT_ERROR

    _OPTUNA_AVAILABLE = False

# Import fit mode configs (conditionally available)
try:
from .model_config import (
Expand Down Expand Up @@ -60,6 +81,8 @@
"get_runs",
]

__all__.append("RFOptuna")

# Conditionally add fit mode configs to __all__
if _FIT_CONFIGS_AVAILABLE:
__all__.extend([
Expand Down
73 changes: 57 additions & 16 deletions rapidfireai/automl/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
"""Base classes and configurations for AutoML algorithms."""
"""Base class for AutoML search algorithms.

Classes
-------
AutoMLAlgorithm
Abstract base subclassed by ``RFGridSearch``, ``RFRandomSearch``,
and ``RFOptuna``.
"""

from abc import ABC, abstractmethod
from typing import Any
Expand All @@ -8,7 +15,35 @@


class AutoMLAlgorithm(ABC):
"""Base class for AutoML algorithms."""
"""Abstract base class for AutoML search strategies.

Parameters
----------
configs :
Config templates (``RFModelConfig`` for fit, dicts for evals).
Accepts a list, a ``List([...])`` wrapper, or a single object.
create_model_fn :
Legacy parameter (unused).
trainer_type : str or None
``"SFT"`` / ``"DPO"`` / ``"GRPO"`` for fit mode, ``None`` for evals.
num_runs : int
Number of samples (used by ``RFRandomSearch``).

Attributes
----------
configs : list
mode : str
``"fit"`` or ``"evals"``.
trainer_type : str or None
num_runs : int

Methods
-------
get_runs(seed) -> list[dict]
Return concrete config-leaf dicts.
get_callback(**kwargs) -> ChunkCallback | ShardCallback | None
Return an optional inter-step pruning callback.
"""

VALID_TRAINER_TYPES = {"SFT", "DPO", "GRPO"}

Expand All @@ -19,19 +54,6 @@ def __init__(
trainer_type: str | None = None,
num_runs: int = 1,
):
"""
Initialize AutoML algorithm with configurations and trainer type.

Args:
configs: List of configurations (RFModelConfig for fit mode, dict for evals mode)
create_model_fn: Optional function to create models (legacy parameter)
trainer_type: Trainer type ("SFT", "DPO", "GRPO") for fit mode, None for evals mode
num_runs: Number of runs for random search

Mode detection:
- If trainer_type is provided: fit mode (requires RFModelConfig instances)
- If trainer_type is None: evals mode (requires dict instances)
"""
try:
self.configs = self._normalize_configs(configs)
self.num_runs = num_runs
Expand Down Expand Up @@ -87,8 +109,27 @@ def _validate_configs(self):
f"If you want fit mode, provide a trainer_type."
)

def get_callback(self, **kwargs):
    """Return an optional callback for inter-chunk/shard pruning decisions.

    The base implementation opts out of pruning by returning ``None``;
    search strategies that prune between chunks/shards are expected to
    override this (presumably ``RFOptuna`` — confirm in subclass).

    Parameters
    ----------
    **kwargs :
        Strategy-specific options; ignored by the base implementation.

    Returns
    -------
    ChunkCallback or ShardCallback or None
        Always ``None`` here.
    """
    return None

@abstractmethod
def get_runs(self, seed: int) -> list[dict[str, Any]]:
    """Produce the concrete config-leaf dicts handed to the controller.

    Subclasses implement the actual sampling strategy; this base body
    only validates ``seed``.

    Parameters
    ----------
    seed : int
        Non-negative random seed.

    Returns
    -------
    list[dict[str, Any]]
    """
    seed_is_valid = isinstance(seed, int) and seed >= 0
    if not seed_is_valid:
        raise AutoMLException("seed must be a non-negative integer")
118 changes: 118 additions & 0 deletions rapidfireai/automl/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Callback protocols for inter-chunk/shard decision-making during experiments.

Classes
-------
RunDecision
Dataclass returned by ``ChunkCallback.on_chunk_complete`` (fit mode).
PipelineDecision
Dataclass returned by ``ShardCallback.on_shard_complete`` (evals mode).
ChunkCallback
Protocol for fit-mode inter-chunk pruning callbacks.
ShardCallback
Protocol for evals-mode inter-shard pruning callbacks.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal, Protocol


@dataclass
class RunDecision:
    """Decision returned by a ``ChunkCallback`` after a fit-mode chunk completes.

    Attributes
    ----------
    action : ``"continue"`` or ``"prune"``
        Whether the evaluated run keeps training or is stopped.
    replacement_config : dict or None
        Config-leaf dict for a replacement run, or ``None``.
    """

    # "continue" keeps the run alive; "prune" stops it after this chunk.
    action: Literal["continue", "prune"]
    # NOTE(review): presumably only consulted when action == "prune" —
    # confirm against the controller's handling.
    replacement_config: dict[str, Any] | None = None


@dataclass
class PipelineDecision:
    """Decision returned by a ``ShardCallback`` after an evals-mode shard completes.

    Attributes
    ----------
    action : ``"continue"`` or ``"prune"``
        Whether the evaluated pipeline keeps running or is stopped.
    replacement_config : dict or None
        Config-leaf dict for a replacement pipeline, or ``None``.
    """

    # "continue" keeps the pipeline alive; "prune" stops it after this shard.
    action: Literal["continue", "prune"]
    # NOTE(review): presumably only consulted when action == "prune" —
    # confirm against the controller's handling.
    replacement_config: dict[str, Any] | None = None


class ChunkCallback(Protocol):
    """Protocol for callbacks invoked after each chunk in fit mode.

    Call order: ``register_runs`` → ``on_chunk_complete`` (repeated) → ``finalize``.

    This is a structural interface (``typing.Protocol``): any object
    providing these three methods satisfies it — no subclassing required.
    """

    def register_runs(self, run_id_to_config: dict[int, dict[str, Any]]) -> None:
        """Map newly created DB run IDs to their config dicts."""
        ...

    def on_chunk_complete(
        self,
        run_id: int,
        chunk_id: int,
        metrics: dict[str, Any],
    ) -> RunDecision:
        """Evaluate a run after it finishes a chunk.

        Parameters
        ----------
        run_id : int
            ID previously supplied via ``register_runs``.
        chunk_id : int
            Index of the chunk just completed.
        metrics : dict[str, Any]
            Metrics reported for this chunk (schema not visible here —
            confirm with the trainer side).

        Returns
        -------
        RunDecision
            ``action="continue"`` or ``action="prune"``, the latter
            optionally paired with a ``replacement_config``.
        """
        ...

    def finalize(self, final_metrics: dict[int, dict[str, Any]]) -> None:
        """Called after the experiment loop ends."""
        ...


class ShardCallback(Protocol):
    """Protocol for callbacks invoked after each shard in evals mode.

    Call order: ``register_pipelines`` → ``on_shard_complete`` (repeated) → ``finalize``.

    This is a structural interface (``typing.Protocol``): any object
    providing these three methods satisfies it — no subclassing required.
    """

    def register_pipelines(self, pipeline_id_to_config: dict[int, dict[str, Any]]) -> None:
        """Map newly created DB pipeline IDs to their config dicts."""
        ...

    def on_shard_complete(
        self,
        pipeline_id: int,
        shard_id: int,
        metrics: dict[str, Any],
    ) -> PipelineDecision:
        """Evaluate a pipeline after it finishes a shard.

        Parameters
        ----------
        pipeline_id : int
            ID previously supplied via ``register_pipelines``.
        shard_id : int
            Index of the shard just completed.
        metrics : dict[str, Any]
            Metrics reported for this shard (schema not visible here —
            confirm with the evals side).

        Returns
        -------
        PipelineDecision
            ``action="continue"`` or ``action="prune"``, the latter
            optionally paired with a ``replacement_config``.
        """
        ...

    def finalize(self, final_metrics: dict[int, dict[str, Any]]) -> None:
        """Called after the experiment loop ends."""
        ...
77 changes: 69 additions & 8 deletions rapidfireai/automl/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,46 @@
"""Contains classes for representing hyperparameter data types."""
"""Contains classes for representing hyperparameter data types.

import random
Covers all Optuna distribution types:

- ``Range(start, end, dtype="float")`` → ``FloatDistribution`` / ``suggest_float``
- ``Range(start, end, dtype="float", log=True)`` → log-uniform float
- ``Range(start, end, dtype="float", step=0.1)`` → discrete float
- ``Range(start, end, dtype="int")`` → ``IntDistribution`` / ``suggest_int``
- ``Range(start, end, dtype="int", log=True)`` → log-uniform int
- ``Range(start, end, dtype="int", step=2)`` → stepped int
- ``List([...])`` → ``CategoricalDistribution`` / ``suggest_categorical``
"""

# TODO: need to set seed for random module.
# TODO: List.sample() will not work for nested lists.
# TODO: add support for sampling methods like 'uniform' and 'loguniform'.
import math
import random


class Range:
"""Represents a range of values for a hyperparameter."""
"""Represents a range of values for a hyperparameter.

Supports uniform, log-uniform, and discrete (stepped) sampling for both
int and float dtypes — matching all variants of Optuna's
``IntDistribution`` and ``FloatDistribution``.

Args:
start: Lower bound (inclusive).
end: Upper bound (inclusive).
dtype: ``"int"`` or ``"float"``. Inferred from *start*/*end* types
when not provided.
log: If ``True``, sample in log-space (start and end must be > 0).
Mutually exclusive with *step*.
step: Discretisation step. When set, sampled values are multiples of
*step* starting from *start*. Mutually exclusive with *log*.
"""

def __init__(self, start, end, dtype: str | None = None):
def __init__(
self,
start,
end,
dtype: str | None = None,
log: bool = False,
step: int | float | None = None,
):
if dtype is None:
self.dtype = (
"int" if isinstance(start, int) and isinstance(end, int) else "float"
Expand All @@ -21,13 +51,44 @@ def __init__(self, start, end, dtype: str | None = None):
self.dtype = dtype
if not (isinstance(start, int | float) and isinstance(end, int | float)):
raise ValueError("start and end must be either int or float.")
if log and step is not None:
raise ValueError(
"log=True and step are mutually exclusive "
"(Optuna does not support this combination either)."
)
if log and (start <= 0 or end <= 0):
raise ValueError(
"log=True requires both start and end to be > 0."
)
self.start = start
self.end = end
self.log = log
self.step = step

def sample(self):
"""Sample a value from the range [self.start, self.end]."""
"""Sample a value from the range [self.start, self.end].

Respects *log* (log-uniform) and *step* (discrete) settings so that
``RFRandomSearch`` produces the same family of distributions as
Optuna's ``suggest_int`` / ``suggest_float``.
"""
if self.dtype == "int":
if self.log:
log_low, log_high = math.log(self.start), math.log(self.end)
return int(round(math.exp(random.uniform(log_low, log_high))))
if self.step is not None:
step = int(self.step)
n_steps = (self.end - self.start) // step
return self.start + random.randint(0, n_steps) * step
return random.randint(self.start, self.end)

# dtype == "float"
if self.log:
log_low, log_high = math.log(self.start), math.log(self.end)
return math.exp(random.uniform(log_low, log_high))
if self.step is not None:
n_steps = int((self.end - self.start) / self.step)
return self.start + random.randint(0, n_steps) * self.step
return random.uniform(self.start, self.end)


Expand Down
Loading