Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions mellea/stdlib/genslot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic import BaseModel, Field, create_model

from mellea.stdlib.base import Component, TemplateRepresentation
from mellea.stdlib.session import get_session
from mellea.stdlib.session import MelleaSession, get_session

P = ParamSpec("P")
R = TypeVar("R")
Expand Down Expand Up @@ -154,7 +154,7 @@ def __init__(self, func: Callable[P, R]):

def __call__(
self,
m=None,
m: MelleaSession | None = None,
model_options: dict | None = None,
*args: P.args,
**kwargs: P.kwargs,
Expand All @@ -180,13 +180,11 @@ def __call__(

response_model = create_response_format(self._function._func)

response = m.genslot(
slot_copy, model_options=model_options, format=response_model
)
response = m.act(slot_copy, format=response_model, model_options=model_options)

function_response: FunctionResponse[R] = response_model.model_validate_json(
response.value
) # type: ignore
response.value # type: ignore
)

return function_response.result

Expand Down
9 changes: 5 additions & 4 deletions mellea/stdlib/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
self.success = success
self.sample_generations = sample_generations
self.sample_validations = sample_validations
self.sample_actions = sample_actions


class SamplingStrategy(abc.ABC):
Expand Down Expand Up @@ -153,7 +154,7 @@ def select_from_failure(
sampled_actions: list[Component],
sampled_results: list[ModelOutputThunk],
sampled_val: list[list[tuple[Requirement, ValidationResult]]],
):
) -> int:
"""This function returns the index of the result that should be selected as `.value` iff the loop budget is exhausted and no success.

Args:
Expand Down Expand Up @@ -356,17 +357,17 @@ def select_from_failure(

@staticmethod
def repair(
context: Context,
ctx: Context,
past_actions: list[Component],
past_results: list[ModelOutputThunk],
past_val: list[list[tuple[Requirement, ValidationResult]]],
) -> Component:
assert isinstance(context, LinearContext), (
assert isinstance(ctx, LinearContext), (
" Need linear context to run agentic sampling."
)

# add failed execution to chat history
context.insert_turn(ContextTurn(past_actions[-1], past_results[-1]))
ctx.insert_turn(ContextTurn(past_actions[-1], past_results[-1]))

last_failed_reqs: list[Requirement] = [s[0] for s in past_val[-1] if not s[1]]
last_failed_reqs_str = "* " + "\n* ".join(
Expand Down
Loading