Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 14 additions & 20 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import json_repair
import litellm

from dspy.adapters.types import History
from dspy.adapters.types import History, Type
from dspy.adapters.types.base_type import split_message_content_for_custom_types
from dspy.adapters.types.tool import Tool, ToolCalls
from dspy.experimental import Citations
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback, with_callbacks

Expand All @@ -18,9 +17,10 @@


class Adapter:
def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False):
def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False, native_response_types: list[type[Type]] | None = None):
self.callbacks = callbacks or []
self.use_native_function_calling = use_native_function_calling
self.native_response_types = native_response_types or []

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
Expand Down Expand Up @@ -64,9 +64,10 @@ def _call_preprocess(

return signature_for_native_function_calling

citation_output_field_name = self._get_citation_output_field_name(signature)
if citation_output_field_name:
signature = signature.delete(citation_output_field_name)
# Handle custom types that use native response
for name, field in signature.output_fields.items():
if isinstance(field.annotation, type) and issubclass(field.annotation, Type) and field.annotation in self.native_response_types:
signature = signature.delete(name)

return signature

Expand All @@ -75,23 +76,21 @@ def _call_postprocess(
processed_signature: type[Signature],
original_signature: type[Signature],
outputs: list[dict[str, Any]],
lm: "LM",
) -> list[dict[str, Any]]:
values = []

tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature)
citation_output_field_name = self._get_citation_output_field_name(original_signature)

for output in outputs:
output_logprobs = None
tool_calls = None
citations = None
text = output

if isinstance(output, dict):
text = output["text"]
output_logprobs = output.get("logprobs")
tool_calls = output.get("tool_calls")
citations = output.get("citations")

if text:
value = self.parse(processed_signature, text)
Expand All @@ -114,9 +113,10 @@ def _call_postprocess(
]
value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls)

if citations and citation_output_field_name:
citations_obj = Citations.from_dict_list(citations)
value[citation_output_field_name] = citations_obj
# Parse custom types that does not rely on the adapter parsing
for name, field in original_signature.output_fields.items():
if isinstance(field.annotation, type) and issubclass(field.annotation, Type) and field.annotation in self.native_response_types:
value[name] = field.annotation.parse_lm_response(output)

if output_logprobs:
value["logprobs"] = output_logprobs
Expand All @@ -137,7 +137,7 @@ def __call__(
inputs = self.format(processed_signature, demos, inputs)

outputs = lm(messages=inputs, **lm_kwargs)
return self._call_postprocess(processed_signature, signature, outputs)
return self._call_postprocess(processed_signature, signature, outputs, lm)

async def acall(
self,
Expand All @@ -151,7 +151,7 @@ async def acall(
inputs = self.format(processed_signature, demos, inputs)

outputs = await lm.acall(messages=inputs, **lm_kwargs)
return self._call_postprocess(processed_signature, signature, outputs)
return self._call_postprocess(processed_signature, signature, outputs, lm)

def format(
self,
Expand Down Expand Up @@ -402,12 +402,6 @@ def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool:
return name
return None

def _get_citation_output_field_name(self, signature: type[Signature]) -> str | None:
"""Find the Citations output field in the signature."""
for name, field in signature.output_fields.items():
if field.annotation == Citations:
return name
return None

def format_conversation_history(
self,
Expand Down
33 changes: 32 additions & 1 deletion dspy/adapters/types/base_type.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import re
from typing import Any, get_args, get_origin
from typing import Any, Optional, get_args, get_origin

import json_repair
import pydantic
from litellm import ModelResponseStream

CUSTOM_TYPE_START_IDENTIFIER = "<<CUSTOM-TYPE-START-IDENTIFIER>>"
CUSTOM_TYPE_END_IDENTIFIER = "<<CUSTOM-TYPE-END-IDENTIFIER>>"
Expand Down Expand Up @@ -67,6 +68,36 @@ def serialize_model(self):
return f"{CUSTOM_TYPE_START_IDENTIFIER}{formatted}{CUSTOM_TYPE_END_IDENTIFIER}"
return formatted

@classmethod
def is_streamable(cls) -> bool:
"""Whether the custom type is streamable."""
return False

@classmethod
def parse_stream_chunk(cls, chunk: ModelResponseStream) -> Optional["Type"]:
"""
Parse a stream chunk into the custom type.

Args:
chunk: A stream chunk.

Returns:
A custom type object or None if the chunk is not for this custom type.
"""
return None


@classmethod
def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Type"]:
"""Parse a LM response into the custom type.

Args:
response: A LM response.

Returns:
A custom type object.
"""
return None

def split_message_content_for_custom_types(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Split user message content into a list of content blocks.
Expand Down
50 changes: 49 additions & 1 deletion dspy/adapters/types/citation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional

import pydantic

Expand Down Expand Up @@ -166,3 +166,51 @@ def __len__(self):
def __getitem__(self, index):
"""Allow indexing into citations."""
return self.citations[index]

@classmethod
def is_streamable(cls) -> bool:
"""Whether the Citations type is streamable."""
return True

@classmethod
def parse_stream_chunk(cls, chunk) -> Optional["Citations"]:
"""
Parse a stream chunk into Citations.

Args:
chunk: A stream chunk from the LM.

Returns:
A Citations object if the chunk contains citation data, None otherwise.
"""
try:
# Check if the chunk has citation data in provider_specific_fields
if hasattr(chunk, "choices") and chunk.choices:
delta = chunk.choices[0].delta
if hasattr(delta, "provider_specific_fields") and delta.provider_specific_fields:
citation_data = delta.provider_specific_fields.get("citation")
if citation_data:
return cls.from_dict_list([citation_data])
except Exception:
pass
return None


@classmethod
def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Citations"]:
"""Parse a LM response into Citations.

Args:
response: A LM response that may contain citation data.

Returns:
A Citations object if citation data is found, None otherwise.
"""
if isinstance(response, dict):
# Check if the response contains citations in the expected format
if "citations" in response:
citations_data = response["citations"]
if isinstance(citations_data, list):
return cls.from_dict_list(citations_data)

return None
49 changes: 29 additions & 20 deletions dspy/streaming/streaming_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter
from dspy.adapters.types.citation import Citations
from dspy.adapters.types import Type
from dspy.adapters.xml_adapter import XMLAdapter
from dspy.dsp.utils.settings import settings
from dspy.streaming.messages import StreamResponse
Expand Down Expand Up @@ -102,18 +102,15 @@ def receive(self, chunk: ModelResponseStream):
except Exception:
return

# Handle anthropic citations. see https://docs.litellm.ai/docs/providers/anthropic#beta-citations-api
try:
if self._signature_field_is_citation_type():
if chunk_citation := chunk.choices[0].delta.provider_specific_fields.get("citation", None):
return StreamResponse(
self.predict_name,
self.signature_field_name,
Citations.from_dict_list([chunk_citation]),
is_last_chunk=False,
)
except Exception:
pass
# Handle custom streamable types
if self.output_type and issubclass(self.output_type, Type) and self.output_type.is_streamable():
if parsed_chunk := self.output_type.parse_stream_chunk(chunk):
return StreamResponse(
self.predict_name,
self.signature_field_name,
parsed_chunk,
is_last_chunk=self.stream_end,
)

if chunk_message and start_identifier in chunk_message:
# If the cache is hit, the chunk_message could be the full response. When it happens we can
Expand Down Expand Up @@ -217,10 +214,13 @@ def flush(self) -> str:
f"{', '.join([a.__name__ for a in ADAPTER_SUPPORT_STREAMING])}"
)

def _signature_field_is_citation_type(self) -> bool:
"""Check if the signature field is a citations field."""
from dspy.predict import Predict
return isinstance(self.predict, Predict) and getattr(self.predict.signature.output_fields.get(self.signature_field_name, None), "annotation", None) == Citations
@property
def output_type(self) -> type | None:
try:
return self.predict.signature.output_fields[self.signature_field_name].annotation
except Exception:
return None



def find_predictor_for_stream_listeners(program: "Module", stream_listeners: list[StreamListener]):
Expand Down Expand Up @@ -249,10 +249,10 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis
"predictor to use for streaming. Please specify the predictor to listen to."
)

if field_info.annotation not in [str, Citations]:
if not _is_streamable(field_info.annotation):
raise ValueError(
f"Stream listener can only be applied to string or Citations output field, but your field {field_name} is of "
f"type {field_info.annotation}."
f"Stream listener can only be applied to string or subclass of `dspy.Type` that has `is_streamable() == True`, "
f"but your field {field_name} is of type {field_info.annotation}."
)

field_name_to_named_predictor[field_name] = (name, predictor)
Expand All @@ -271,3 +271,12 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis
listener.predict_name, listener.predict = field_name_to_named_predictor[listener.signature_field_name]
predict_id_to_listener[id(listener.predict)].append(listener)
return predict_id_to_listener

def _is_streamable(field_type: type | None) -> bool:
if field_type is None:
return False
if field_type is str:
return True
if issubclass(field_type, Type):
return field_type.is_streamable()
return False
3 changes: 2 additions & 1 deletion tests/adapters/test_citation.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ class CitationSignature(Signature):
result = adapter._call_postprocess(
CitationSignature.delete("citations"),
CitationSignature,
outputs
outputs,
dspy.LM(model="claude-3-5-sonnet-20241022")
)

assert len(result) == 1
Expand Down
55 changes: 54 additions & 1 deletion tests/streaming/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices

import dspy
from dspy.adapters.types import Type
from dspy.experimental import Citations, Document
from dspy.streaming import StatusMessage, StatusMessageProvider, streaming_response

Expand Down Expand Up @@ -877,6 +878,58 @@ async def send_to_stream():
assert isinstance(all_chunks[1], dspy.Prediction)


@pytest.mark.anyio
async def test_streaming_allows_custom_streamable_type():
class CustomType(Type):
message: str

@classmethod
def is_streamable(cls) -> bool:
return True

@classmethod
def parse_stream_chunk(cls, chunk):
return CustomType(message=chunk.choices[0].delta.content)

@classmethod
def parse_lm_response(cls, response: dict) -> "CustomType":
return CustomType(message=response.split("\n\n")[0])

class CustomSignature(dspy.Signature):
question: str = dspy.InputField()
answer: CustomType = dspy.OutputField()

program = dspy.streamify(
dspy.Predict(CustomSignature),
stream_listeners=[
dspy.streaming.StreamListener(signature_field_name="answer"),
],
)

async def stream(*args, **kwargs):
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="Hello"))])
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="World"))])
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="\n\n"))])
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" completed"))])
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ##"))])
yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" ]]"))])


with mock.patch("litellm.acompletion", side_effect=stream):
with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.ChatAdapter(native_response_types=[CustomType])):
output = program(question="why did a chicken cross the kitchen?")
all_chunks = []
async for value in output:
if isinstance(value, dspy.streaming.StreamResponse):
all_chunks.append(value)
elif isinstance(value, dspy.Prediction):
assert isinstance(value.answer, CustomType)
assert value.answer.message == "HelloWorld"

assert all(isinstance(chunk.chunk, CustomType) for chunk in all_chunks)


@pytest.mark.anyio
async def test_streaming_with_citations():
class AnswerWithSources(dspy.Signature):
Expand Down Expand Up @@ -936,7 +989,7 @@ async def citation_stream(*args, **kwargs):
# Create test documents
docs = [Document(data="Water boils at 100°C at standard pressure.", title="Physics Facts")]

with dspy.context(lm=dspy.LM("anthropic/claude-3-5-sonnet-20241022", cache=False)):
with dspy.context(lm=dspy.LM("anthropic/claude-3-5-sonnet-20241022", cache=False), adapter=dspy.ChatAdapter(native_response_types=[Citations])):
output = program(documents=docs, question="What temperature does water boil?")
citation_chunks = []
final_prediction = None
Expand Down
Loading