164 changes: 164 additions & 0 deletions logfire/experimental/_gepa/_gepa.py
@@ -0,0 +1,164 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any, Generic

from gepa.core.adapter import DataInst, EvaluationBatch, GEPAAdapter, RolloutOutput, Trajectory
from pydantic import ValidationError

import logfire
from logfire._internal.utils import JsonDict


@dataclass
class ReflectionItem(Generic[Trajectory, RolloutOutput]):
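    """One scored rollout (trajectory, score and output) together with the candidate and component being reflected on."""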
    candidate: dict[str, str]
    component: str
    trajectory: Trajectory
    score: float
    output: RolloutOutput


class SimpleReflectionAdapterMixin(GEPAAdapter[DataInst, Trajectory, RolloutOutput], ABC):
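    """Builds the reflective dataset by passing each (trajectory, score, output) triple to `reflect_item`.

    Only a single component can be updated at a time, and trajectories must have been captured.
    """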
    def make_reflective_dataset(
        self,
        candidate: dict[str, str],
        eval_batch: EvaluationBatch[Trajectory, RolloutOutput],
        components_to_update: list[str],
    ):
        assert len(components_to_update) == 1
        component = components_to_update[0]

        assert eval_batch.trajectories, 'Trajectories are required to build a reflective dataset.'

        return {
            component: [
                self.reflect_item(ReflectionItem(candidate, component, *inst))
                for inst in zip(eval_batch.trajectories, eval_batch.scores, eval_batch.outputs, strict=True)
            ]
        }

    @abstractmethod
    def reflect_item(self, item: ReflectionItem[Trajectory, RolloutOutput]) -> JsonDict: ...


@dataclass
class EvaluationInput(Generic[DataInst]):
    data: DataInst
    """Inputs to the task/function being evaluated."""

    candidate: dict[str, str]
    capture_traces: bool


@dataclass
class EvaluationResult(Generic[Trajectory, RolloutOutput]):
    output: RolloutOutput
    score: float
    trajectory: Trajectory | None = None


class SimpleEvaluateAdapterMixin(GEPAAdapter[DataInst, Trajectory, RolloutOutput], ABC):
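    """Implements `evaluate` by calling `evaluate_instance` once per item in the batch."""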
    def evaluate(
        self,
        batch: list[DataInst],
        candidate: dict[str, str],
        capture_traces: bool = False,
    ):
        outputs: list[RolloutOutput] = []
        scores: list[float] = []
        trajectories: list[Trajectory] | None = [] if capture_traces else None

        for data in batch:
            eval_input = EvaluationInput(data=data, candidate=candidate, capture_traces=capture_traces)
            eval_result = self.evaluate_instance(eval_input)
            outputs.append(eval_result.output)
            scores.append(eval_result.score)

            if trajectories is not None:
                assert eval_result.trajectory
                trajectories.append(eval_result.trajectory)

        return EvaluationBatch(outputs=outputs, scores=scores, trajectories=trajectories)

    @abstractmethod
    def evaluate_instance(
        self, eval_input: EvaluationInput[DataInst]
    ) -> EvaluationResult[Trajectory, RolloutOutput]: ...


class SimpleProposeNewTextsAdapterMixin(
    GEPAAdapter[DataInst, Trajectory, RolloutOutput],
    ABC,
):
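    """Binds the `propose_new_texts` attribute to `propose_new_texts_impl`, which subclasses must implement."""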
    def __init__(self, *args: Any, **kwargs: Any):
        self.propose_new_texts = self.propose_new_texts_impl
        super().__init__(*args, **kwargs)

    @abstractmethod
    def propose_new_texts_impl(
        self,
        candidate: dict[str, str],
        reflective_dataset: Mapping[str, Sequence[Mapping[str, Any]]],
        components_to_update: list[str],
    ) -> dict[str, str]: ...


class CombinedSimpleAdapterMixin(
    SimpleEvaluateAdapterMixin[DataInst, Trajectory, RolloutOutput],
    SimpleReflectionAdapterMixin[DataInst, Trajectory, RolloutOutput],
    SimpleProposeNewTextsAdapterMixin[DataInst, Trajectory, RolloutOutput],
    ABC,
):
    pass


@dataclass
class AdapterWrapper(GEPAAdapter[DataInst, Trajectory, RolloutOutput]):
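    """Delegates all `GEPAAdapter` methods to the wrapped adapter."""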
    wrapped: GEPAAdapter[DataInst, Trajectory, RolloutOutput]

    def __post_init__(self):
        self.propose_new_texts = self.wrapped.propose_new_texts

    def evaluate(
        self,
        batch: list[DataInst],
        candidate: dict[str, str],
        capture_traces: bool = False,
    ):
        return self.wrapped.evaluate(batch, candidate, capture_traces)

    def make_reflective_dataset(
        self,
        candidate: dict[str, str],
        eval_batch: EvaluationBatch[Trajectory, RolloutOutput],
        components_to_update: list[str],
    ):
        return self.wrapped.make_reflective_dataset(candidate, eval_batch, components_to_update)


class ManagedVarsEvaluateAdapterWrapper(AdapterWrapper[DataInst, Trajectory, RolloutOutput]):
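    """Overrides the logfire managed variables named in the candidate for the duration of evaluation.

    Each candidate value is parsed with the variable's type adapter, falling back to the raw string for
    `str`-typed variables.
    """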
    def evaluate(
        self,
        batch: list[DataInst],
        candidate: dict[str, str],
        capture_traces: bool = False,
    ):
        stack = ExitStack()
        variables = logfire.get_variables()
        with stack:
            for var in variables:
                if var.name in candidate:
                    raw_value = candidate[var.name]
                    try:
                        value = var.type_adapter.validate_json(raw_value)
                    except ValidationError:
                        if var.value_type is str:
                            value = raw_value
                        else:
                            raise
                    stack.enter_context(var.override(value))
            return super().evaluate(batch, candidate, capture_traces)
82 changes: 82 additions & 0 deletions logfire/experimental/_gepa/test.py
@@ -0,0 +1,82 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any

import logfire
from logfire._internal.utils import JsonDict

logfire.install_auto_tracing(modules=['gepa'], min_duration=0)
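# install_auto_tracing must be called before the modules it traces are imported, hence the late gepa imports below.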

from logfire.experimental._gepa._gepa import ( # noqa
    CombinedSimpleAdapterMixin,
    EvaluationInput,
    EvaluationResult,
    ManagedVarsEvaluateAdapterWrapper,
    ReflectionItem,
)

from gepa.api import optimize # type: ignore # noqa

logfire.configure()
logfire.instrument_openai()
logfire.instrument_httpx(capture_all=True)

MyDataInst = str
MyTrajectory = str
MyRolloutOutput = int

prompt = logfire.var(name='prompt', default='Say 6', type=str)


class MyAdapter(CombinedSimpleAdapterMixin[MyDataInst, MyTrajectory, MyRolloutOutput]):
    def propose_new_texts_impl(
        self,
        candidate: dict[str, str],
        reflective_dataset: Mapping[str, Sequence[Mapping[str, Any]]],
        components_to_update: list[str],
    ) -> dict[str, str]:
        assert list(candidate.keys()) == list(reflective_dataset.keys()) == components_to_update == ['prompt']
        # In practice this should construct a prompt for an LLM asking it to suggest a new prompt
        # (a rough sketch follows this file's diff).
        # It would help for the user to provide a description of the problem/task/goal
        # (although this becomes redundant when the agent's system instructions are present,
        # since they presumably describe the task already)
        # and maybe a description of each variable.
        return {'prompt': f'Say {reflective_dataset["prompt"][0]["expected output"]}'}

    def evaluate_instance(self, eval_input: EvaluationInput[MyDataInst]):
        # In practice this would run some task function on eval_input.data.
        # Also, if eval_input.capture_traces is True, it should create a trajectory containing:
        # - the inputs (eval_input.data)
        # - traces/spans captured during execution

        # The value of `prompt` is set by the outer ManagedVarsEvaluateAdapterWrapper.
        if '42' in prompt.get().value:
            output = 42
            score = 1
        else:
            output = 0
            score = 0
        return EvaluationResult(output=output, score=score, trajectory=eval_input.data)

    def reflect_item(self, item: ReflectionItem[MyTrajectory, MyRolloutOutput]) -> JsonDict:
        # This is where you'd run evaluators to provide feedback on the trajectory/output.
        # You might also filter the traces/spans to only include parts where
        # the managed variable corresponding to item.component was used/set.
        return {'question': item.trajectory, 'response': item.output, 'expected output': 42}


adapter = ManagedVarsEvaluateAdapterWrapper(wrapped=MyAdapter())

seed_prompt = {var.name: var.get().value for var in [prompt]}

gepa_result = optimize( # type: ignore
    seed_candidate=seed_prompt,
    adapter=adapter,
    trainset=['What is the (numeric) answer to life, the universe and everything?'],
    max_metric_calls=2,
)

print(gepa_result) # type: ignore
assert gepa_result.best_candidate == {'prompt': 'Say 42'}
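An aside on propose_new_texts_impl above: as its comment says, a real adapter would ask an LLM for the new component text rather than hard-coding it. A minimal sketch of that idea (not part of this diff), assuming the openai client that logfire.instrument_openai() already targets; the helper name, model name and prompt wording are placeholders:

import json

from openai import OpenAI


def propose_prompt_with_llm(current: str, reflective_examples: list[dict[str, object]]) -> str:
    # Ask an LLM to suggest an improved prompt, given the current prompt and the reflective dataset rows.
    client = OpenAI()
    response = client.chat.completions.create(
        model='gpt-4o-mini',  # placeholder model name
        messages=[
            {
                'role': 'user',
                'content': (
                    f'Here is the current prompt:\n{current}\n\n'
                    'Here are example rollouts with their expected outputs:\n'
                    f'{json.dumps(reflective_examples, indent=2)}\n\n'
                    'Suggest an improved prompt. Reply with the new prompt text only.'
                ),
            }
        ],
    )
    return response.choices[0].message.content or current

propose_new_texts_impl could then return {'prompt': propose_prompt_with_llm(candidate['prompt'], list(reflective_dataset['prompt']))}.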
1 change: 1 addition & 0 deletions pyproject.toml
@@ -185,6 +185,7 @@ dev = [
"litellm != 1.80.9",
"pip >= 0",
"surrealdb >= 0",
"gepa >= 0; python_version >= '3.10'"
]
docs = [
"black>=23.12.0",
15 changes: 13 additions & 2 deletions uv.lock
