Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions dspy/utils/dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName, field_header_pattern
from dspy.adapters.chat_adapter import FieldInfoWithName, field_header_pattern
from dspy.clients.lm import LM
from dspy.dsp.utils.utils import dotdict
from dspy.signatures.field import OutputField
Expand Down Expand Up @@ -67,13 +67,19 @@ class DummyLM(LM):

"""

def __init__(self, answers: list[dict[str, str]] | dict[str, dict[str, str]], follow_examples: bool = False):
def __init__(self, answers: list[dict[str, str]] | dict[str, dict[str, str]], follow_examples: bool = False, adapter=None):
super().__init__("dummy", "chat", 0.0, 1000, True)
self.answers = answers
if isinstance(answers, list):
self.answers = iter(answers)
self.follow_examples = follow_examples

# Set adapter, defaulting to ChatAdapter
if adapter is None:
from dspy.adapters.chat_adapter import ChatAdapter
adapter = ChatAdapter()
self.adapter = adapter

def _use_example(self, messages):
# find all field names
fields = defaultdict(int)
Expand All @@ -94,12 +100,20 @@ def _use_example(self, messages):
@with_callbacks
def __call__(self, prompt=None, messages=None, **kwargs):
def format_answer_fields(field_names_and_values: dict[str, Any]):
return ChatAdapter().format_field_with_value(
fields_with_values={
FieldInfoWithName(name=field_name, info=OutputField()): value
for field_name, value in field_names_and_values.items()
}
)
fields_with_values = {
FieldInfoWithName(name=field_name, info=OutputField()): value
for field_name, value in field_names_and_values.items()
}
# The reason why DummyLM needs an adapter is because it needs to know which output format to mimic.
# Normally LMs should not have any knowledge of an adapter, because the output format is defined in the prompt.
adapter = self.adapter

# Try to use role="assistant" if the adapter supports it (like JSONAdapter)
try:
return adapter.format_field_with_value(fields_with_values, role="assistant")
except TypeError:
# Fallback for adapters that don't support role parameter (like ChatAdapter)
return adapter.format_field_with_value(fields_with_values)

# Build the request.
outputs = []
Expand Down
14 changes: 8 additions & 6 deletions tests/adapters/test_json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,20 @@ class TestSignature(dspy.Signature):

def test_json_adapter_sync_call():
signature = dspy.make_signature("question->answer")
adapter = dspy.ChatAdapter()
lm = dspy.utils.DummyLM([{"answer": "Paris"}])
result = adapter(lm, {}, signature, [], {"question": "What is the capital of France?"})
adapter = dspy.JSONAdapter()
lm = dspy.utils.DummyLM([{"answer": "Paris"}], adapter=adapter)
with dspy.context(adapter=adapter):
result = adapter(lm, {}, signature, [], {"question": "What is the capital of France?"})
assert result == [{"answer": "Paris"}]


@pytest.mark.asyncio
async def test_json_adapter_async_call():
signature = dspy.make_signature("question->answer")
adapter = dspy.ChatAdapter()
lm = dspy.utils.DummyLM([{"answer": "Paris"}])
result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
adapter = dspy.JSONAdapter()
lm = dspy.utils.DummyLM([{"answer": "Paris"}], adapter=adapter)
with dspy.context(adapter=adapter):
result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
assert result == [{"answer": "Paris"}]


Expand Down