From ee655d4ebb5bc273b74b760fe0bd045322a9366a Mon Sep 17 00:00:00 2001
From: Krista Opsahl-Ong
Date: Thu, 4 Sep 2025 11:13:18 -0400
Subject: [PATCH 1/2] simba updates + handling optional fields

---
 dspy/adapters/utils.py         | 17 +++++++-
 dspy/teleprompt/simba.py       | 13 ++++--
 dspy/teleprompt/simba_utils.py | 80 ++++++++++++++++++++++++----------
 3 files changed, 82 insertions(+), 28 deletions(-)

diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py
index 2ce6de201d..71d8024b8a 100644
--- a/dspy/adapters/utils.py
+++ b/dspy/adapters/utils.py
@@ -1,9 +1,10 @@
+
 import ast
 import enum
 import inspect
 import json
 from collections.abc import Mapping
-from typing import Any, Literal, Union, get_args, get_origin
+from typing import Any, List, Literal, get_args, get_origin, Union
 
 import json_repair
 import pydantic
@@ -13,6 +14,7 @@
 from dspy.adapters.types.base_type import Type
 from dspy.signatures.utils import get_dspy_field_type
 
+NoneType = type(None)
 
 def serialize_for_json(value: Any) -> Any:
     """
@@ -132,8 +134,19 @@ def find_enum_member(enum, identifier):
     raise ValueError(f"{identifier} is not a valid name or value for the enum {enum.__name__}")
 
 
+def _strip_optional(ann):
+    """If ann is Union[..., NoneType] return the non‑None part, else ann."""
+    if get_origin(ann) is Union and NoneType in get_args(ann):
+        # keep the first non‑None member (there will be only one in Optional[T])
+        return next(a for a in get_args(ann) if a is not NoneType)
+    return ann
 
 def parse_value(value, annotation):
+    annotation = _strip_optional(annotation)
+
+    if value is None:
+        return None
+
     if annotation is str:
         return str(value)
 
@@ -277,4 +290,4 @@ def _quoted_string_for_literal_type_annotation(s: str) -> str:
         return f"'{escaped}'"
     else:
         # Neither => enclose in single quotes
-        return f"'{s}'"
+        return f"'{s}'"
\ No newline at end of file
diff --git a/dspy/teleprompt/simba.py b/dspy/teleprompt/simba.py
index 230673d8f6..bbed5f4184 100644
--- a/dspy/teleprompt/simba.py
+++ b/dspy/teleprompt/simba.py
@@ -1,12 +1,12 @@
 import logging
 import random
-from typing import Any, Callable
 
 import numpy as np
 
 import dspy
-from dspy.teleprompt.simba_utils import append_a_demo, append_a_rule, prepare_models_for_resampling, wrap_program
 from dspy.teleprompt.teleprompt import Teleprompter
+from dspy.teleprompt.simba_utils import prepare_models_for_resampling, wrap_program, append_a_demo, append_a_rule
+from typing import Optional, Any, Dict, Callable
 
 logger = logging.getLogger(__name__)
 
@@ -31,6 +31,8 @@ def __init__(
         num_candidates: int = 6,
         max_steps: int = 8,
         max_demos: int = 4,
+        prompt_model: Optional[Any] = None,
+        teacher_settings: Optional[Dict] = None,
         demo_input_field_maxlen: int = 100_000,
         num_threads: int | None = None,
         temperature_for_sampling: float = 0.2,
@@ -62,6 +64,8 @@ def __init__(
         self.num_candidates = num_candidates
         self.max_steps = max_steps
         self.max_demos = max_demos
+        self.prompt_model = prompt_model if prompt_model else dspy.settings.lm
+        self.teacher_settings = teacher_settings
         self.demo_input_field_maxlen = demo_input_field_maxlen
         self.num_threads = num_threads
 
@@ -175,7 +179,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]) -> None:
 
             # We'll generate (program, model) pairs for the trajectory sampling.
             # Prepare distinct LMs (with different temperatures, etc.) from the baseline=programs[0].
-            models = prepare_models_for_resampling(programs[0], self.num_candidates)
+            models = prepare_models_for_resampling(programs[0], self.num_candidates, self.teacher_settings)
             top_programs = top_k_plus_baseline(self.num_candidates)
 
             exec_pairs = []
@@ -278,6 +282,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]) -> None:
                         name2predictor=name2predictor,
                         batch_10p_score=batch_10th_percentile_score,
                         batch_90p_score=batch_90th_percentile_score,
+                        prompt_model=self.prompt_model,
                     )
                 except Exception as e:
                     logger.error(f"Strategy failed with error: {e}")
@@ -363,4 +368,4 @@ def register_new_program(prog: dspy.Module, score_list: list[float]) -> None:
         best_program.candidate_programs = candidate_data
         best_program.trial_logs = trial_logs
 
-        return best_program
+        return best_program
\ No newline at end of file
diff --git a/dspy/teleprompt/simba_utils.py b/dspy/teleprompt/simba_utils.py
index a88e501e7b..7ad26dda73 100644
--- a/dspy/teleprompt/simba_utils.py
+++ b/dspy/teleprompt/simba_utils.py
@@ -1,24 +1,36 @@
 import inspect
 import logging
 import textwrap
-from typing import Callable
-
+import re
 import orjson
 
 import dspy
 from dspy.adapters.utils import get_field_description_string
 from dspy.signatures import InputField, OutputField
+from typing import Callable, Optional, Dict
 
 logger = logging.getLogger(__name__)
 
-
-def prepare_models_for_resampling(program: dspy.Module, n: int):
+def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: Optional[Dict] = None):
     lm = program.get_lm() or dspy.settings.lm
-    start = lm.kwargs.get("rollout_id", 0)
-    rollout_ids = [start + i for i in range(n)]
-    return [lm.copy(rollout_id=r, temperature=1.0) for r in rollout_ids]
+
+    start_rollout_id = lm.kwargs.get("rollout_id", 0)
+    rollout_ids = [start_rollout_id + i for i in range(n)]
+    start_rollout_idx, models = 0, []
+    # If we have a teacher model, use this as the first model
+    if teacher_settings:
+        teacher_lm = teacher_settings.get("lm") or lm
+        teacher_lm.kwargs["rollout_id"] = rollout_ids[start_rollout_idx]
+        models.append(teacher_lm)
+        start_rollout_idx += 1
+
+    # The rest of the models are just copies of the base model
+    models.extend([lm.copy(rollout_id=r, temperature=1.0) for r in rollout_ids[start_rollout_idx:]])
+
+    return models
+
 def wrap_program(program: dspy.Module, metric: Callable):
     def wrapped_program(example):
         with dspy.context(trace=[]):
@@ -26,33 +38,51 @@ def wrapped_program(example):
         try:
             prediction = program(**example.inputs())
         except Exception as e:
-            print(e)
+            logger.info(e)
         trace = dspy.settings.trace.copy()
 
+        output = None
+        score = 0.0
+        output_metadata = {}
+
         try:
-            score = metric(example, prediction)
+            output = metric(example, prediction)
+            if isinstance(output, (int, float)):
+                score = output
+            elif isinstance(output, dspy.Prediction):
+                if not hasattr(output, 'score'):
+                    raise ValueError("dspy.Prediction must contain a 'score' attribute")
+                score = output.score
+                # Just extract fields from _store, excluding 'score'
+                output_metadata = {
+                    k: v for k, v in output._store.items() if k != "score"
+                }
         except Exception as e:
-            print(e)
+            logger.info(e)
 
-        # Include the `example` in the output for subsequent usage in buckets/strategies.
         return {
             "prediction": prediction,
             "trace": trace,
             "score": score,
-            "example": example
+            "example": example,
+            "output_metadata": output_metadata
         }
 
     return wrapped_program
-
-
 def append_a_demo(demo_input_field_maxlen):
     def append_a_demo_(bucket, system, **kwargs):
         predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"]
+        batch_10p_score = kwargs["batch_10p_score"]
 
-        trace = bucket[0]["trace"]
+        good = bucket[0]
+        trace = good["trace"]
         name2demo = {}
 
+        if good["score"] <= batch_10p_score:
+            logger.info(f"Skipping appending a demo as good score {good['score']} is at or below the 10th percentile.")
+            return False
+
         for step in trace:
             predictor, _inputs, _outputs = step
 
@@ -63,7 +93,6 @@ def append_a_demo_(bucket, system, **kwargs):
             demo = dspy.Example(augmented=True, **_inputs, **_outputs)
             name = predictor2name[id(predictor)]
             name2demo[name] = demo  # keep the last demo for each predictor
-
         for name, demo in name2demo.items():
             predictor = name2predictor[name]
             predictor.demos.append(demo)
@@ -77,14 +106,15 @@ def append_a_rule(bucket, system, **kwargs):
     predictor2name = kwargs["predictor2name"]
     batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"]
+    prompt_model = kwargs["prompt_model"] or dspy.settings.lm
 
     module_names = [name for name, _ in system.named_predictors()]
     good, bad = bucket[0], bucket[-1]
     example = good["example"]
 
-    if good["score"] < batch_10p_score or bad["score"] > batch_90p_score:
-        logger.info(f"Skipping rule generation as good score {good['score']} is below the 10th percentile "
-                    f"*or* bad score {bad['score']} is above the 90th percentile.")
+    if good["score"] <= batch_10p_score or bad["score"] >= batch_90p_score:
+        logger.info(f"Skipping rule generation as good score {good['score']} is at or below the 10th percentile "
+                    f"*or* bad score {bad['score']} is at or above the 90th percentile.")
         return False
 
     if good["score"] <= bad["score"]:
@@ -117,12 +147,17 @@ def append_a_rule(bucket, system, **kwargs):
         "worse_program_outputs": dict(bad["prediction"] or {}),
         "worse_reward_value": bad["score"],
         "better_reward_value": good["score"],
+        "worse_reward_info": bad["output_metadata"],
+        "better_reward_info": good["output_metadata"],
         "module_names": module_names,
     }
 
     kwargs = {k: v if isinstance(v, str) else orjson.dumps(recursive_mask(v), option=orjson.OPT_INDENT_2).decode()
              for k, v in kwargs.items()}
-    advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice
+
+    with dspy.settings.context(trace=[], lm=prompt_model):
+        advice_program = dspy.Predict(OfferFeedback)
+        advice = advice_program(**kwargs).module_advice
 
     for name, predictor in system.named_predictors():
         if name in advice:
@@ -156,11 +191,13 @@ class OfferFeedback(dspy.Signature):
     )
     worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
     worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
+    worse_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.")
     better_program_trajectory: str = InputField(
         desc="The trajectory of the program's execution, showing each module's I/O"
     )
     better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
     better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
+    better_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.")
     module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice")
     discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did")
     module_advice: dict[str, str] = OutputField(
@@ -170,7 +207,6 @@ class OfferFeedback(dspy.Signature):
         "like the successful trajectory rather than the lower-scoring trajectory."
     )
 
-
 def inspect_modules(program):
     separator = "-" * 80
     output = [separator]
@@ -210,4 +246,4 @@ def recursive_mask(o):
         return tuple(recursive_mask(v) for v in o)
     # Otherwise, replace it with a placeholder string (or use repr(o)).
     else:
-        return f"<non-serializable: {type(o).__name__}>"
+        return f"<non-serializable: {type(o).__name__}>"
\ No newline at end of file

From 09476766dfa6a1d8729514d06b90c0e596111d92 Mon Sep 17 00:00:00 2001
From: Krista Opsahl-Ong
Date: Thu, 4 Sep 2025 11:25:30 -0400
Subject: [PATCH 2/2] pre-commit check

---
 dspy/adapters/utils.py         |  6 +++---
 dspy/teleprompt/simba.py       | 10 +++++-----
 dspy/teleprompt/simba_utils.py | 12 ++++++------
 3 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py
index 71d8024b8a..bf2c51b04f 100644
--- a/dspy/adapters/utils.py
+++ b/dspy/adapters/utils.py
@@ -4,7 +4,7 @@
 import inspect
 import json
 from collections.abc import Mapping
-from typing import Any, List, Literal, get_args, get_origin, Union
+from typing import Any, Literal, Union, get_args, get_origin
 
 import json_repair
 import pydantic
@@ -146,7 +146,7 @@ def parse_value(value, annotation):
 
     if value is None:
         return None
-    
+
     if annotation is str:
         return str(value)
 
@@ -290,4 +290,4 @@ def _quoted_string_for_literal_type_annotation(s: str) -> str:
         return f"'{escaped}'"
     else:
         # Neither => enclose in single quotes
-        return f"'{s}'"
\ No newline at end of file
+        return f"'{s}'"
diff --git a/dspy/teleprompt/simba.py b/dspy/teleprompt/simba.py
index bbed5f4184..c1f109856e 100644
--- a/dspy/teleprompt/simba.py
+++ b/dspy/teleprompt/simba.py
@@ -1,12 +1,12 @@
 import logging
 import random
+from typing import Any, Callable
 
 import numpy as np
 
 import dspy
+from dspy.teleprompt.simba_utils import append_a_demo, append_a_rule, prepare_models_for_resampling, wrap_program
 from dspy.teleprompt.teleprompt import Teleprompter
-from dspy.teleprompt.simba_utils import prepare_models_for_resampling, wrap_program, append_a_demo, append_a_rule
-from typing import Optional, Any, Dict, Callable
 
 logger = logging.getLogger(__name__)
 
@@ -31,8 +31,8 @@ def __init__(
         num_candidates: int = 6,
         max_steps: int = 8,
         max_demos: int = 4,
-        prompt_model: Optional[Any] = None,
-        teacher_settings: Optional[Dict] = None,
+        prompt_model: Any | None = None,
+        teacher_settings: dict | None = None,
         demo_input_field_maxlen: int = 100_000,
         num_threads: int | None = None,
         temperature_for_sampling: float = 0.2,
@@ -368,4 +368,4 @@ def register_new_program(prog: dspy.Module, score_list: list[float]) -> None:
         best_program.candidate_programs = candidate_data
         best_program.trial_logs = trial_logs
 
-        return best_program
\ No newline at end of file
+        return best_program
diff --git a/dspy/teleprompt/simba_utils.py b/dspy/teleprompt/simba_utils.py
index 7ad26dda73..33fd5671c5 100644
--- a/dspy/teleprompt/simba_utils.py
+++ b/dspy/teleprompt/simba_utils.py
@@ -1,17 +1,17 @@
 import inspect
 import logging
 import textwrap
-import re
+from typing import Callable
+
 import orjson
 
 import dspy
 from dspy.adapters.utils import get_field_description_string
 from dspy.signatures import InputField, OutputField
-from typing import Callable, Optional, Dict
 
 logger = logging.getLogger(__name__)
 
-def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: Optional[Dict] = None):
+def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: dict | None = None):
     lm = program.get_lm() or dspy.settings.lm
 
     start_rollout_id = lm.kwargs.get("rollout_id", 0)
@@ -28,7 +28,7 @@ def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings
     # The rest of the models are just copies of the base model
     models.extend([lm.copy(rollout_id=r, temperature=1.0) for r in rollout_ids[start_rollout_idx:]])
-    
+
     return models
 
 def wrap_program(program: dspy.Module, metric: Callable):
@@ -50,7 +50,7 @@ def wrapped_program(example):
             if isinstance(output, (int, float)):
                 score = output
             elif isinstance(output, dspy.Prediction):
-                if not hasattr(output, 'score'):
+                if not hasattr(output, "score"):
                     raise ValueError("dspy.Prediction must contain a 'score' attribute")
                 score = output.score
                 # Just extract fields from _store, excluding 'score'
@@ -246,4 +246,4 @@ def recursive_mask(o):
         return tuple(recursive_mask(v) for v in o)
     # Otherwise, replace it with a placeholder string (or use repr(o)).
    else:
-        return f"<non-serializable: {type(o).__name__}>"
\ No newline at end of file
+        return f"<non-serializable: {type(o).__name__}>"
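
Usage note (illustrative sketch, not part of the patch): one way the new prompt_model and teacher_settings options, together with a metric that returns a dspy.Prediction, might be wired up after these commits land. The toy program, trainset, and model names below are assumptions for demonstration only.

    import dspy

    dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # placeholder model name

    qa_program = dspy.Predict("question -> answer")
    trainset = [
        dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
        dspy.Example(question="What is the capital of France?", answer="Paris").with_inputs("question"),
    ]

    def metric(example, prediction):
        # A plain float still works; returning a dspy.Prediction lets SIMBA forward
        # the extra fields (here: feedback) to rule generation as worse/better_reward_info.
        score = float(example.answer.lower() in prediction.answer.lower())
        return dspy.Prediction(score=score, feedback="checked whether the gold answer appears in the output")

    optimizer = dspy.SIMBA(
        metric=metric,
        bsize=2,
        max_steps=4,
        prompt_model=dspy.LM("openai/gpt-4o-mini"),         # LM that writes the advice/rules
        teacher_settings={"lm": dspy.LM("openai/gpt-4o")},   # stronger LM drives the first sampled trajectory
    )
    optimized_program = optimizer.compile(qa_program, trainset=trainset)

When the metric returns a dspy.Prediction, wrap_program reads its score attribute and passes the remaining stored fields to OfferFeedback as worse_reward_info / better_reward_info.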