Skip to content

Commit 11e9965

Browse files
authored
Change DummyLM to take an adapter at init (#8802)
1 parent 12bfb37 commit 11e9965

File tree

2 files changed

+30
-14
lines changed

2 files changed

+30
-14
lines changed

dspy/utils/dummies.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName, field_header_pattern
7+
from dspy.adapters.chat_adapter import FieldInfoWithName, field_header_pattern
88
from dspy.clients.lm import LM
99
from dspy.dsp.utils.utils import dotdict
1010
from dspy.signatures.field import OutputField
@@ -67,13 +67,19 @@ class DummyLM(LM):
6767
6868
"""
6969

70-
def __init__(self, answers: list[dict[str, str]] | dict[str, dict[str, str]], follow_examples: bool = False):
70+
def __init__(self, answers: list[dict[str, str]] | dict[str, dict[str, str]], follow_examples: bool = False, adapter=None):
7171
super().__init__("dummy", "chat", 0.0, 1000, True)
7272
self.answers = answers
7373
if isinstance(answers, list):
7474
self.answers = iter(answers)
7575
self.follow_examples = follow_examples
7676

77+
# Set adapter, defaulting to ChatAdapter
78+
if adapter is None:
79+
from dspy.adapters.chat_adapter import ChatAdapter
80+
adapter = ChatAdapter()
81+
self.adapter = adapter
82+
7783
def _use_example(self, messages):
7884
# find all field names
7985
fields = defaultdict(int)
@@ -94,12 +100,20 @@ def _use_example(self, messages):
94100
@with_callbacks
95101
def __call__(self, prompt=None, messages=None, **kwargs):
96102
def format_answer_fields(field_names_and_values: dict[str, Any]):
97-
return ChatAdapter().format_field_with_value(
98-
fields_with_values={
99-
FieldInfoWithName(name=field_name, info=OutputField()): value
100-
for field_name, value in field_names_and_values.items()
101-
}
102-
)
103+
fields_with_values = {
104+
FieldInfoWithName(name=field_name, info=OutputField()): value
105+
for field_name, value in field_names_and_values.items()
106+
}
107+
# The reason why DummyLM needs an adapter is because it needs to know which output format to mimic.
108+
# Normally LMs should not have any knowledge of an adapter, because the output format is defined in the prompt.
109+
adapter = self.adapter
110+
111+
# Try to use role="assistant" if the adapter supports it (like JSONAdapter)
112+
try:
113+
return adapter.format_field_with_value(fields_with_values, role="assistant")
114+
except TypeError:
115+
# Fallback for adapters that don't support role parameter (like ChatAdapter)
116+
return adapter.format_field_with_value(fields_with_values)
103117

104118
# Build the request.
105119
outputs = []

tests/adapters/test_json_adapter.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,20 @@ class TestSignature(dspy.Signature):
106106

107107
def test_json_adapter_sync_call():
108108
signature = dspy.make_signature("question->answer")
109-
adapter = dspy.ChatAdapter()
110-
lm = dspy.utils.DummyLM([{"answer": "Paris"}])
111-
result = adapter(lm, {}, signature, [], {"question": "What is the capital of France?"})
109+
adapter = dspy.JSONAdapter()
110+
lm = dspy.utils.DummyLM([{"answer": "Paris"}], adapter=adapter)
111+
with dspy.context(adapter=adapter):
112+
result = adapter(lm, {}, signature, [], {"question": "What is the capital of France?"})
112113
assert result == [{"answer": "Paris"}]
113114

114115

115116
@pytest.mark.asyncio
116117
async def test_json_adapter_async_call():
117118
signature = dspy.make_signature("question->answer")
118-
adapter = dspy.ChatAdapter()
119-
lm = dspy.utils.DummyLM([{"answer": "Paris"}])
120-
result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
119+
adapter = dspy.JSONAdapter()
120+
lm = dspy.utils.DummyLM([{"answer": "Paris"}], adapter=adapter)
121+
with dspy.context(adapter=adapter):
122+
result = await adapter.acall(lm, {}, signature, [], {"question": "What is the capital of France?"})
121123
assert result == [{"answer": "Paris"}]
122124

123125

0 commit comments

Comments
 (0)