diff --git a/logfire/experimental/_gepa/__init__.py b/logfire/experimental/_gepa/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/logfire/experimental/_gepa/_gepa.py b/logfire/experimental/_gepa/_gepa.py
new file mode 100644
index 000000000..7b05c4b19
--- /dev/null
+++ b/logfire/experimental/_gepa/_gepa.py
@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import Mapping, Sequence
+from contextlib import ExitStack
+from dataclasses import dataclass
+from typing import Any, Generic
+
+from gepa.core.adapter import DataInst, EvaluationBatch, GEPAAdapter, RolloutOutput, Trajectory
+from pydantic import ValidationError
+
+import logfire
+from logfire._internal.utils import JsonDict
+
+
+@dataclass
+class ReflectionItem(Generic[Trajectory, RolloutOutput]):
+    candidate: dict[str, str]
+    component: str
+    trajectory: Trajectory
+    score: float
+    output: RolloutOutput
+
+
+class SimpleReflectionAdapterMixin(GEPAAdapter[DataInst, Trajectory, RolloutOutput], ABC):
+    def make_reflective_dataset(
+        self,
+        candidate: dict[str, str],
+        eval_batch: EvaluationBatch[Trajectory, RolloutOutput],
+        components_to_update: list[str],
+    ):
+        assert len(components_to_update) == 1
+        component = components_to_update[0]
+
+        assert eval_batch.trajectories, 'Trajectories are required to build a reflective dataset.'
+
+        return {
+            component: [
+                self.reflect_item(ReflectionItem(candidate, component, *inst))
+                for inst in zip(eval_batch.trajectories, eval_batch.scores, eval_batch.outputs, strict=True)
+            ]
+        }
+
+    @abstractmethod
+    def reflect_item(self, item: ReflectionItem[Trajectory, RolloutOutput]) -> JsonDict: ...
+
+
+@dataclass
+class EvaluationInput(Generic[DataInst]):
+    data: DataInst
+    """Inputs to the task/function being evaluated."""
+
+    candidate: dict[str, str]
+    capture_traces: bool
+
+
+@dataclass
+class EvaluationResult(Generic[Trajectory, RolloutOutput]):
+    output: RolloutOutput
+    score: float
+    trajectory: Trajectory | None = None
+
+
+class SimpleEvaluateAdapterMixin(GEPAAdapter[DataInst, Trajectory, RolloutOutput], ABC):
+    def evaluate(
+        self,
+        batch: list[DataInst],
+        candidate: dict[str, str],
+        capture_traces: bool = False,
+    ):
+        outputs: list[RolloutOutput] = []
+        scores: list[float] = []
+        trajectories: list[Trajectory] | None = [] if capture_traces else None
+
+        for data in batch:
+            eval_input = EvaluationInput(data=data, candidate=candidate, capture_traces=capture_traces)
+            eval_result = self.evaluate_instance(eval_input)
+            outputs.append(eval_result.output)
+            scores.append(eval_result.score)
+
+            if trajectories is not None:
+                assert eval_result.trajectory is not None
+                trajectories.append(eval_result.trajectory)
+
+        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)
+
+    @abstractmethod
+    def evaluate_instance(
+        self, eval_input: EvaluationInput[DataInst]
+    ) -> EvaluationResult[Trajectory, RolloutOutput]: ...
+
+
+class SimpleProposeNewTextsAdapterMixin(
+    GEPAAdapter[DataInst, Trajectory, RolloutOutput],
+    ABC,
+):
+    def __init__(self, *args: Any, **kwargs: Any):
+        # GEPA treats `propose_new_texts` as an optional attribute on the adapter,
+        # so bind the implementation here rather than defining it as a method.
+        self.propose_new_texts = self.propose_new_texts_impl
+        super().__init__(*args, **kwargs)
+
+    @abstractmethod
+    def propose_new_texts_impl(
+        self,
+        candidate: dict[str, str],
+        reflective_dataset: Mapping[str, Sequence[Mapping[str, Any]]],
+        components_to_update: list[str],
+    ) -> dict[str, str]: ...
+
+
+class CombinedSimpleAdapterMixin(
+    SimpleEvaluateAdapterMixin[DataInst, Trajectory, RolloutOutput],
+    SimpleReflectionAdapterMixin[DataInst, Trajectory, RolloutOutput],
+    SimpleProposeNewTextsAdapterMixin[DataInst, Trajectory, RolloutOutput],
+    ABC,
+):
+    pass
+
+
+@dataclass
+class AdapterWrapper(GEPAAdapter[DataInst, Trajectory, RolloutOutput]):
+    wrapped: GEPAAdapter[DataInst, Trajectory, RolloutOutput]
+
+    def __post_init__(self):
+        # Forward the wrapped adapter's `propose_new_texts` (which may be None) so GEPA sees it.
+        self.propose_new_texts = self.wrapped.propose_new_texts
+
+    def evaluate(
+        self,
+        batch: list[DataInst],
+        candidate: dict[str, str],
+        capture_traces: bool = False,
+    ):
+        return self.wrapped.evaluate(batch, candidate, capture_traces)
+
+    def make_reflective_dataset(
+        self,
+        candidate: dict[str, str],
+        eval_batch: EvaluationBatch[Trajectory, RolloutOutput],
+        components_to_update: list[str],
+    ):
+        return self.wrapped.make_reflective_dataset(candidate, eval_batch, components_to_update)
+
+
+class ManagedVarsEvaluateAdapterWrapper(AdapterWrapper[DataInst, Trajectory, RolloutOutput]):
+    """Overrides logfire managed variables with the candidate's component texts while evaluating."""
+
+    def evaluate(
+        self,
+        batch: list[DataInst],
+        candidate: dict[str, str],
+        capture_traces: bool = False,
+    ):
+        with ExitStack() as stack:
+            for var in logfire.get_variables():
+                if var.name in candidate:
+                    raw_value = candidate[var.name]
+                    try:
+                        value = var.type_adapter.validate_json(raw_value)
+                    except ValidationError:
+                        # Candidate texts are plain strings, not necessarily JSON;
+                        # fall back to the raw text for str-typed variables.
+                        if var.value_type is str:
+                            value = raw_value
+                        else:
+                            raise
+                    stack.enter_context(var.override(value))
+            return super().evaluate(batch, candidate, capture_traces)
diff --git a/logfire/experimental/_gepa/test.py b/logfire/experimental/_gepa/test.py
new file mode 100644
index 000000000..7e617d33a
--- /dev/null
+++ b/logfire/experimental/_gepa/test.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+import logfire
+from logfire._internal.utils import JsonDict
+
+logfire.install_auto_tracing(modules=['gepa'], min_duration=0)
+
+from logfire.experimental._gepa._gepa import (  # noqa
+    CombinedSimpleAdapterMixin,
+    EvaluationInput,
+    EvaluationResult,
+    ManagedVarsEvaluateAdapterWrapper,
+    ReflectionItem,
+)
+
+from gepa.api import optimize  # type: ignore  # noqa
+
+logfire.configure()
+logfire.instrument_openai()
+logfire.instrument_httpx(capture_all=True)
+
+MyDataInst = str
+MyTrajectory = str
+MyRolloutOutput = int
+
+prompt = logfire.var(name='prompt', default='Say 6', type=str)
+
+
+class MyAdapter(CombinedSimpleAdapterMixin[MyDataInst, MyTrajectory, MyRolloutOutput]):
+    def propose_new_texts_impl(
+        self,
+        candidate: dict[str, str],
+        reflective_dataset: Mapping[str, Sequence[Mapping[str, Any]]],
+        components_to_update: list[str],
+    ) -> dict[str, str]:
+        assert list(candidate.keys()) == list(reflective_dataset.keys()) == components_to_update == ['prompt']
+        # In practice this should construct a prompt for an LLM that asks it
+        # to suggest a new prompt (a sketch follows the patch below).
+        # It would help for the user to provide a description of the
+        # problem/task/goal (though this becomes redundant when the agent's
+        # system instructions, which presumably already describe the task,
+        # are included) and perhaps a description of each variable.
+        return {'prompt': f'Say {reflective_dataset["prompt"][0]["expected output"]}'}
+
+    def evaluate_instance(self, eval_input: EvaluationInput[MyDataInst]):
+        # In practice this would run some task function on eval_input.data.
+        # Also, if eval_input.capture_traces is True, it should create a trajectory containing:
+        # - the inputs (eval_input.data)
+        # - traces/spans captured during execution (a sketch follows the patch below)
+
+        # The value of `prompt` is set by the outer ManagedVarsEvaluateAdapterWrapper.
+        if '42' in prompt.get().value:
+            output = 42
+            score = 1.0
+        else:
+            output = 0
+            score = 0.0
+        return EvaluationResult(output=output, score=score, trajectory=eval_input.data)
+
+    def reflect_item(self, item: ReflectionItem[MyTrajectory, MyRolloutOutput]) -> JsonDict:
+        # This is where you'd run evaluators to provide feedback on the trajectory/output.
+        # You might also filter the traces/spans to only include parts where
+        # the managed variable corresponding to item.component was used/set.
+        return {'question': item.trajectory, 'response': item.output, 'expected output': 42}
+
+
+adapter = ManagedVarsEvaluateAdapterWrapper(wrapped=MyAdapter())
+
+seed_prompt = {var.name: var.get().value for var in [prompt]}
+
+gepa_result = optimize(  # type: ignore
+    seed_candidate=seed_prompt,
+    adapter=adapter,
+    trainset=['What is the (numeric) answer to life, the universe and everything?'],
+    max_metric_calls=2,
+)
+
+print(gepa_result)  # type: ignore
+assert gepa_result.best_candidate == {'prompt': 'Say 42'}
diff --git a/pyproject.toml b/pyproject.toml
index 297bdf319..9506f2277 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -185,6 +185,7 @@ dev = [
     "litellm != 1.80.9",
     "pip >= 0",
     "surrealdb >= 0",
+    "gepa >= 0; python_version >= '3.10'",
 ]
 docs = [
     "black>=23.12.0",
diff --git a/uv.lock b/uv.lock
index 5ce24aafc..167dd5872 100644
--- a/uv.lock
+++ b/uv.lock
@@ -614,7 +614,7 @@ name = "cffi"
 version = "2.0.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "pycparser", marker = "implementation_name != 'PyPy'" },
+    { name = "pycparser", marker = "(python_full_version >= '3.10' and implementation_name != 'PyPy') or (implementation_name != 'PyPy' and platform_python_implementation != 'PyPy')" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" }
 wheels = [
@@ -1375,7 +1375,7 @@ name = "exceptiongroup"
 version = "1.3.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "typing-extensions", marker = "python_full_version < '3.13'" },
+    { name = "typing-extensions", marker = "python_full_version < '3.11'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
 wheels = [
@@ -1718,6 +1718,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/09/f3/5e5756d273c897bb5d0bfb8079bbfeb65fc6beb8bb1facb76dfda01651e9/genai_prices-0.0.50-py3-none-any.whl", hash = "sha256:ac70a5a0a532cb19591f8a465b24799d887b0241777f612ddac1d7604befa4d0", size = 61331, upload-time = "2026-01-06T15:03:15.486Z" },
 ]
 
+[[package]]
+name = "gepa"
+version = "0.0.24"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/bc/2c/8f249827bcdc56212a445f4f671ee2201b864e433d2c88428933bd4b08f3/gepa-0.0.24.tar.gz", hash = "sha256:8035d1a0877661b6a63db457dc935b4878ee76acf7da1e488d7e6209bb32c054", size = 135673, upload-time = "2026-01-05T16:45:30.553Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/7e/b1/33b035ff1aaf22d4e104c5b15ba48fe5050639764457048e967c20d6317a/gepa-0.0.24-py3-none-any.whl", hash = "sha256:6d8b16699e7b24ed01435dea7bbbc89156a88cbb4b877b14d90e7455db2b0032", size = 137539, upload-time = "2026-01-05T16:45:29.244Z" },
+]
+
 [[package]]
 name = "ghp-import"
 version = "2.1.0"
@@ -3096,6 +3105,7 @@ dev = [
     { name = "eval-type-backport" },
     { name = "fastapi" },
     { name = "flask" },
+    { name = "gepa", marker = "python_full_version >= '3.10'" },
     { name = "google-genai", version = "1.47.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
     { name = "google-genai", version = "1.56.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
     { name = "greenlet", version = "3.2.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
@@ -3253,6 +3263,7 @@ dev = [
     { name = "eval-type-backport", specifier = ">=0.2.0" },
     { name = "fastapi", specifier = "!=0.124.3,!=0.124.4" },
     { name = "flask", specifier = ">=3.0.3" },
+    { name = "gepa", marker = "python_full_version >= '3.10'", specifier = ">=0" },
     { name = "google-genai", specifier = ">=0" },
     { name = "greenlet", specifier = ">=3.1.1" },
     { name = "httpx", specifier = ">=0.27.2" },
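
The two "in practice" comments in test.py above describe what a real adapter would do. Minimal sketches follow, reusing MyAdapter's imports and types; they are illustrative, not part of the patch. First, an LLM-backed propose_new_texts_impl using the openai client (the model name and prompt wording are assumptions):

    import json

    from openai import OpenAI

    client = OpenAI()


    class MyAdapter(CombinedSimpleAdapterMixin[MyDataInst, MyTrajectory, MyRolloutOutput]):
        def propose_new_texts_impl(
            self,
            candidate: dict[str, str],
            reflective_dataset: Mapping[str, Sequence[Mapping[str, Any]]],
            components_to_update: list[str],
        ) -> dict[str, str]:
            new_texts: dict[str, str] = {}
            for component in components_to_update:
                # Show the LLM the current text plus the feedback gathered for it.
                feedback = json.dumps(list(reflective_dataset[component]), indent=2)
                response = client.chat.completions.create(
                    model='gpt-4o',  # assumption: any capable model works here
                    messages=[
                        {
                            'role': 'user',
                            'content': (
                                f'Current text for component {component!r}:\n'
                                f'{candidate[component]}\n\n'
                                f'Feedback from evaluation:\n{feedback}\n\n'
                                'Propose an improved text. Reply with the new text only.'
                            ),
                        }
                    ],
                )
                # Keep the old text if the model returns nothing.
                new_texts[component] = response.choices[0].message.content or candidate[component]
            return new_texts

Second, an evaluate_instance that runs the real task and grades the result; run_task and grade are hypothetical stand-ins for the user's task function and evaluator:

        def evaluate_instance(self, eval_input: EvaluationInput[MyDataInst]):
            answer = run_task(eval_input.data)  # hypothetical: executes the instrumented task
            score = grade(eval_input.data, answer)  # hypothetical: returns a float in [0, 1]
            # Only build a trajectory when GEPA asked for traces (the reflection step needs them).
            trajectory = eval_input.data if eval_input.capture_traces else None
            return EvaluationResult(output=answer, score=score, trajectory=trajectory)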