
Optuna Integration for RapidFire AI#234

Open
humaira-rf wants to merge 8 commits into main from feature/optuna-integration

Conversation

@humaira-rf
Collaborator

Summary

Adds Optuna-powered hyperparameter optimization as a drop-in replacement for RFGridSearch / RFRandomSearch. Optuna's Bayesian samplers (TPE, CMA-ES) and pruners (Median, Hyperband) now work natively with RapidFire's chunk-based/epoch-based fit loop and shard-based evals loop, enabling smarter search and early stopping of underperforming runs.

What's included

Core Optuna engine (rapidfireai/automl/optuna_search.py — new, ~815 lines)

  • RFOptuna class — user-facing AutoMLAlgorithm subclass that creates an Optuna study, samples initial configs via get_runs()
  • Supports TPE, CMA-ES, and Random samplers; Median and Hyperband pruners
  • OptunaChunkCallback — fit-mode callback that reports training metrics to Optuna after each chunk, evaluates trial.should_prune(), and suggests replacement configs within a budget
  • OptunaShardCallback — evals-mode callback that does the same for pipeline shards

Extended Range / List datatypes (rapidfireai/automl/datatypes.py)

  • Range now supports log=True (log-uniform sampling) and step=... (discrete stepped sampling), matching Optuna's full FloatDistribution / IntDistribution variants
  • Backward compatible — existing Range(start, end) calls work unchanged
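
The sampling semantics of the extended `Range` can be sketched in a few lines. This is an illustrative re-implementation, not the PR's `datatypes.py`: `log=True` draws uniformly in log-space (as in Optuna's `FloatDistribution(log=True)`), and `step` snaps draws to a discrete grid.

```python
import math
import random

# Minimal sketch of the extended Range semantics (illustrative only).
class Range:
    def __init__(self, start, end, log=False, step=None):
        self.start, self.end, self.log, self.step = start, end, log, step

    def sample(self, rng=random):
        if self.log:
            # log-uniform: uniform in log-space, then exponentiate
            return math.exp(rng.uniform(math.log(self.start), math.log(self.end)))
        x = rng.uniform(self.start, self.end)
        if self.step is not None:
            # snap to the nearest grid point start + k * step, clipped to end
            k = round((x - self.start) / self.step)
            return min(self.start + k * self.step, self.end)
        return x
```

A plain `Range(start, end)` takes the final `rng.uniform` path, which is how backward compatibility falls out for free.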

Packaging (pyproject.toml)

  • Optuna added as optional dependency: pip install rapidfireai[optuna]
  • Conditional import with a helpful stub error message when Optuna is not installed

Tutorial notebooks

  • tutorial_notebooks/fine-tuning/rf-tutorial-optuna-sft-chatqa-tiny.ipynb — SFT fine-tuning with Optuna
  • tutorial_notebooks/rag-contexteng/rf-tutorial-optuna-rag-fiqa.ipynb — RAG evals with Optuna

Usage example

from rapidfireai import Experiment
from rapidfireai.automl import RFOptuna, RFModelConfig, RFSFTConfig, Range

config = RFOptuna(
    configs=[RFModelConfig(
        model_name="model",
        training_args=RFSFTConfig(learning_rate=Range(1e-6, 1e-3, log=True)),
    )],
    trainer_type="SFT",
    n_initial=8,
    budget=20,
    objective="minimize:eval_loss",
    sampler="tpe",
    pruner="median",
    granularity="epoch",   # or "chunk" (default)
)

exp = Experiment("my_optuna_experiment")
exp.run_fit(config, create_model_fn, train_data, eval_data, num_chunks=8)
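
The `objective="minimize:eval_loss"` spec pairs an Optuna direction with a metric name. A hypothetical parser (not the PR's code) shows the intended shape:

```python
# Hypothetical helper: split "direction:metric" into parts RFOptuna could
# pass to optuna.create_study(direction=...) and the metric lookup.
def parse_objective(spec: str) -> tuple[str, str]:
    direction, _, metric = spec.partition(":")
    if direction not in ("minimize", "maximize") or not metric:
        raise ValueError(f"bad objective spec: {spec!r}")
    return direction, metric
```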

Tests

  • Tested both Optuna notebooks end-to-end with multiple search-space configurations and pruner choices, covering single-objective and multi-objective optimization
  • Tested the SFT-lite notebook and the RAG FiQA notebook
  • All scenarios additionally validated with ICOps

@humaira-rf humaira-rf requested a review from arun-rfai April 28, 2026 18:02
