Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions dspy/clients/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,29 @@


class TrainingJobDatabricks(TrainingJob):
"""Tracks the state of a Databricks finetuning job and its deployment."""

def __init__(self, finetuning_run=None, *args, **kwargs):
    """Create a Databricks training-job tracker.

    Args:
        finetuning_run: Optional handle to the Databricks finetuning run
            object; kept so the job's status can be polled later.
        *args: Positional arguments forwarded to ``TrainingJob``.
        **kwargs: Keyword arguments forwarded to ``TrainingJob``.
    """
    super().__init__(*args, **kwargs)
    # Deployment bookkeeping: nothing has been launched yet and no
    # serving endpoint exists for this job.
    self.endpoint_name = None
    self.launch_started = False
    self.launch_completed = False
    self.finetuning_run = finetuning_run

def status(self):
"""Retrieve the current status of the finetuning run from Databricks.

Queries the Databricks foundation model API to get the current status of the finetuning run.
Returns None if no finetuning run is associated with this job.

Raises:
ImportError: If the databricks_genai package is not installed.
"""
if not self.finetuning_run:
return None
try:
Expand All @@ -39,11 +54,18 @@ def status(self):


class DatabricksProvider(Provider):
"""Provider implementation for Databricks model finetuning and deployment."""

finetunable = True
TrainingJob = TrainingJobDatabricks

@staticmethod
def is_provider_model(model: str) -> bool:
    """Report whether ``model`` should be routed to Databricks.

    Databricks serves open models rather than a proprietary model family,
    so no model-name pattern implies this provider — callers must select
    it explicitly. Consequently this always returns ``False``.
    """
    return False

Expand All @@ -55,6 +77,16 @@ def deploy_finetuned_model(
databricks_token: str | None = None,
deploy_timeout: int = 900,
):
"""Deploy a finetuned model to Databricks serving endpoint with provisioned throughput.

Creates or updates a serving endpoint for the model, configures provisioned throughput
based on the model's optimization info, and waits for the endpoint to become ready.
The endpoint is tested with a sample request to verify it's operational.

Raises:
ValueError: If the model is not eligible for provisioned throughput or if the
serving endpoint creation/update fails or doesn't become ready within the timeout.
"""
workspace_client = _get_workspace_client()
model_version = next(workspace_client.model_versions.list(model)).version

Expand Down Expand Up @@ -172,6 +204,20 @@ def finetune(
train_data_format: TrainDataFormat | str | None = "chat",
train_kwargs: dict[str, Any] | None = None,
) -> str:
"""Execute finetuning on Databricks and optionally deploy the resulting model.

Uploads training data to Unity Catalog, starts a finetuning run, waits for completion,
and deploys the model to a serving endpoint unless skip_deploy is set. The method polls
the finetuning run status until it completes or fails.

Returns the deployed model name in the format "databricks/<endpoint_name>" or None if
deployment is skipped.

Raises:
ValueError: If train_data_format is invalid, required kwargs are missing, or if the
finetuning run fails.
ImportError: If the databricks_genai package is not installed.
"""
if isinstance(train_data_format, str):
if train_data_format == "chat":
train_data_format = TrainDataFormat.CHAT
Expand Down Expand Up @@ -244,6 +290,15 @@ def finetune(

@staticmethod
def upload_data(train_data: list[dict[str, Any]], databricks_unity_catalog_path: str, data_format: TrainDataFormat):
"""Save training data locally, validate it, and upload to Databricks Unity Catalog.

Creates a local JSONL file with the training data, validates the Unity Catalog path format,
ensures the target volume exists, creates the directory if needed, and uploads the file.
Returns the full path to the uploaded file in Unity Catalog.

Raises:
ValueError: If the upload fails or if the Unity Catalog path or volume is invalid.
"""
logger.info("Uploading finetuning data to Databricks Unity Catalog...")
file_path = _save_data_to_local_file(train_data, data_format)

Expand All @@ -261,6 +316,14 @@ def upload_data(train_data: list[dict[str, Any]], databricks_unity_catalog_path:


def _get_workspace_client() -> "WorkspaceClient":
"""Create and return a Databricks WorkspaceClient instance.

Initializes a WorkspaceClient using the default authentication configuration.
The client is used to interact with Databricks workspace APIs.

Raises:
ImportError: If the databricks-sdk package is not installed.
"""
try:
from databricks.sdk import WorkspaceClient
except ImportError:
Expand All @@ -271,6 +334,14 @@ def _get_workspace_client() -> "WorkspaceClient":


def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databricks_unity_catalog_path: str):
"""Validate the Unity Catalog path format, verify the volume exists, and create the directory if needed.

Checks that the path follows the required format '/Volumes/<catalog>/<schema>/<volume>/...',
verifies the volume exists in Unity Catalog, and creates the directory if it doesn't already exist.

Raises:
ValueError: If the path format is invalid or if the specified volume does not exist.
"""
pattern = r"^/Volumes/(?P<catalog>[^/]+)/(?P<schema>[^/]+)/(?P<volume>[^/]+)(/[^/]+)+$"
match = re.match(pattern, databricks_unity_catalog_path)
if not match:
Expand Down Expand Up @@ -303,6 +374,12 @@ def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databric


def _save_data_to_local_file(train_data: list[dict[str, Any]], data_format: TrainDataFormat):
"""Write training data to a local JSONL file after validating each item based on the data format.

Creates a uniquely named JSONL file in the finetune directory, validates each training data item
according to the specified format (chat or completion), and writes the data line by line.
Returns the absolute path to the created file.
"""
import uuid

file_name = f"finetuning_{uuid.uuid4()}.jsonl"
Expand All @@ -322,6 +399,14 @@ def _save_data_to_local_file(train_data: list[dict[str, Any]], data_format: Trai


def _validate_chat_data(data: dict[str, Any]):
"""Verify that chat data contains a 'messages' list with properly formatted message dictionaries.

Checks that the data has a 'messages' key containing a list, and that each message in the list
has both 'role' and 'content' keys.

Raises:
ValueError: If the data structure doesn't match the expected chat format.
"""
if "messages" not in data:
raise ValueError(
"Each finetuning data must be a dict with a 'messages' key when `task=CHAT_COMPLETION`, but "
Expand All @@ -344,6 +429,14 @@ def _validate_chat_data(data: dict[str, Any]):


def _validate_completion_data(data: dict[str, Any]):
"""Verify that completion data contains both 'prompt' and either 'response' or 'completion' keys.

Checks that the data has a 'prompt' key and at least one of 'response' or 'completion' keys,
which are required for instruction finetuning format.

Raises:
ValueError: If the data structure doesn't match the expected completion format.
"""
if "prompt" not in data:
raise ValueError(
"Each finetuning data must be a dict with a 'prompt' key when `task=INSTRUCTION_FINETUNE`, but "
Expand Down
31 changes: 31 additions & 0 deletions dspy/predict/avatar/avatar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


def get_number_with_suffix(number: int) -> str:
"""Converts a number to its ordinal string representation (1st, 2nd, 3rd, etc.)."""
if number == 1:
return "1st"
elif number == 2:
Expand All @@ -20,13 +21,21 @@ def get_number_with_suffix(number: int) -> str:


class Avatar(dspy.Module):
"""An agent-like DSPy module that uses tools iteratively to accomplish a task defined by a signature."""

def __init__(
self,
signature,
tools,
max_iters=3,
verbose=False,
):
"""Initializes the Avatar agent with a task signature and available tools.

Sets up the actor signature by appending input fields from the provided signature,
adds a special "Finish" tool to the tool list, and creates a TypedPredictor for the actor.
The actor signature is built dynamically by prepending input fields in reverse order.
"""
self.signature = ensure_signature(signature)
self.input_fields = self.signature.input_fields
self.output_fields = self.signature.output_fields
Expand Down Expand Up @@ -54,6 +63,14 @@ def __init__(
self.actor_clone = deepcopy(self.actor)

def _get_field(self, field_info: FieldInfo):
"""Reconstructs a DSPy field (InputField or OutputField) from a FieldInfo object.

Extracts the field type, prefix, description, and optional format from the json_schema_extra
metadata and creates the appropriate DSPy field instance.

Raises:
ValueError: If the field type in json_schema_extra is not 'input' or 'output'.
"""
if field_info.json_schema_extra["__dspy_field_type"] == "input":
return dspy.InputField(
prefix=field_info.json_schema_extra["prefix"],
Expand All @@ -70,6 +87,12 @@ def _get_field(self, field_info: FieldInfo):
raise ValueError(f"Unknown field type: {field_info.json_schema_extra['__dspy_field_type']}")

def _update_signature(self, idx: int, omit_action: bool = False):
"""Dynamically extends the actor signature with action and result fields for the current iteration.

Converts the previous action field from output to input, adds a result field for that action,
and either appends the next action field (if omit_action is False) or appends the final output
fields (if omit_action is True, indicating task completion).
"""
self.actor.signature = self.actor.signature.with_updated_fields(
f"action_{idx}", Action, __dspy_field_type="input"
)
Expand Down Expand Up @@ -104,11 +127,19 @@ def _update_signature(self, idx: int, omit_action: bool = False):
)

def _call_tool(self, tool_name: str, tool_input_query: str) -> str:
    """Execute the named tool with the given input and return its output.

    Args:
        tool_name: Name to look up, matched against ``tool.name`` for each
            entry in ``self.tools``.
        tool_input_query: Input string passed to the matched tool's
            ``run`` method.

    Returns:
        The output produced by the tool's ``run`` method.

    Raises:
        ValueError: If no registered tool matches ``tool_name``.
    """
    for tool in self.tools:
        if tool.name == tool_name:
            return tool.tool.run(tool_input_query)
    # Previously this fell through and implicitly returned None, which
    # violated the declared ``-> str`` return type and silently hid actor
    # typos in the tool name. Fail loudly instead.
    raise ValueError(
        f"Tool '{tool_name}' not found. Available tools: "
        f"{[tool.name for tool in self.tools]}"
    )

def forward(self, **kwargs):
"""Executes the agent's main loop, iteratively selecting and running tools until task completion.

The agent repeatedly calls the actor to select an action, executes the corresponding tool,
and updates the signature with the action and result. This continues until the "Finish" tool
is selected or max_iters is reached. The signature is dynamically extended with each iteration's
action and result fields, building up the context for subsequent decisions.
"""
if self.verbose:
print("Starting the task...")

Expand Down
91 changes: 79 additions & 12 deletions dspy/teleprompt/simba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
logger = logging.getLogger(__name__)

def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: dict | None = None):
"""Prepares a list of language models for resampling by assigning unique rollout IDs.

Creates n models with sequential rollout IDs. If teacher_settings is provided, the first
model uses the teacher's language model configuration. Remaining models are copies of the
base model with temperature set to 1.0.

Returns:
A list of language models configured for resampling with unique rollout IDs.
"""
lm = program.get_lm() or dspy.settings.lm

start_rollout_id = lm.kwargs.get("rollout_id", 0)
Expand All @@ -32,7 +41,26 @@ def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings
return models

def wrap_program(program: dspy.Module, metric: Callable):
"""Wraps a program to capture its execution trace and evaluate it with a metric.

Returns a function that executes the program on an example, captures the trace,
evaluates the prediction using the metric, and returns a dictionary containing
the prediction, trace, score, example, and any additional metadata from the metric.
The metric can return a numeric score or a dspy.Prediction with a score field.

Returns:
A function that takes an example and returns a dictionary with prediction results,
trace, score, and metadata.
"""
def wrapped_program(example):
"""Executes the program on an example and captures its trace.

Runs the program with the given example, captures the execution trace, evaluates
the result using the metric, and packages everything into a result dictionary.

Returns:
A dictionary containing prediction, trace, score, example, and output_metadata.
"""
with dspy.context(trace=[]):
prediction, trace, score = None, None, 0.0
try:
Expand Down Expand Up @@ -71,7 +99,25 @@ def wrapped_program(example):
return wrapped_program

def append_a_demo(demo_input_field_maxlen):
"""Returns a function that appends demonstrations from a successful trajectory to predictors.

The returned function extracts demonstrations from the best trajectory in a bucket and
appends them to the corresponding predictors. Input fields longer than demo_input_field_maxlen
are truncated. Skips appending if the best score is at or below the 10th percentile.

Returns:
A function that processes a bucket and appends demonstrations to predictors.
"""
def append_a_demo_(bucket, system, **kwargs):
"""Extracts and appends demonstrations from the best trajectory to predictors.

Processes the highest-scoring trajectory in the bucket, creates demonstrations from
each step, and appends them to the corresponding predictors. Truncates long input
fields and skips if the score is too low.

Returns:
True if demonstrations were appended, False if skipped due to low score.
"""
predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"]
batch_10p_score = kwargs["batch_10p_score"]

Expand Down Expand Up @@ -104,6 +150,16 @@ def append_a_demo_(bucket, system, **kwargs):


def append_a_rule(bucket, system, **kwargs):
"""Generates and appends advice to predictor instructions by comparing good and bad trajectories.

Uses a language model to analyze the difference between a high-scoring and low-scoring
trajectory, generating module-specific advice. The advice is appended to each predictor's
instructions. Skips rule generation if the good score is too low or the bad score is too high
relative to batch percentiles.

Returns:
True if advice was generated and appended, False if skipped due to score thresholds.
"""
predictor2name = kwargs["predictor2name"]
batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"]
prompt_model = kwargs["prompt_model"] or dspy.settings.lm
Expand Down Expand Up @@ -168,18 +224,11 @@ def append_a_rule(bucket, system, **kwargs):
return True

class OfferFeedback(dspy.Signature):
"""
You will be given two trajectories of an LLM-driven program's execution. Your goal is to help the program's modules
build up experience on how to maximize the reward value assigned to the program's outputs if it were to receive
similar inputs in the future.

The module won't see its own history. It will rely on your advice balancing being concrete and being generalizable.

In your advice:
- Avoid boilerplate. Offer advice that would change the module's behavior for the better in the future.
- Ensure that advice offered to a module M is specific to that M's specific sub-task, not the overall program.
- Rely on contrasting the behavior of the worse trajectory against the better trajectory in making recommendations.
- Ensure each unique module name appears exactly once as a key in the advice dictionary.
"""Signature for generating module-specific advice by comparing successful and unsuccessful trajectories.

Analyzes two program execution trajectories with different reward values to generate
concrete, actionable advice for each module. The advice helps modules improve their
behavior by learning from the contrast between better and worse trajectories.
"""

program_code: str = InputField(desc="The code of the program that we are analyzing")
Expand Down Expand Up @@ -208,6 +257,15 @@ class OfferFeedback(dspy.Signature):
)

def inspect_modules(program):
"""Formats module information into a human-readable string representation.

Extracts and formats each predictor's name, input fields, output fields, and instructions
into a structured text format with separators. The output is suitable for inclusion in
prompts or logs.

Returns:
A formatted string containing module definitions with their fields and instructions.
"""
separator = "-" * 80
output = [separator]

Expand All @@ -228,6 +286,15 @@ def inspect_modules(program):


def recursive_mask(o):
"""Recursively masks non-serializable objects with placeholder strings.

Traverses the object structure and replaces any non-JSON-serializable values with
a placeholder string indicating the type. Handles dictionaries, lists, and tuples
recursively while preserving already-serializable values.

Returns:
The object with non-serializable values replaced by placeholder strings.
"""
# If the object is already serializable, return it.
try:
orjson.dumps(o)
Expand Down