Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/langfuse/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.15.1", "langfuse>=2.9.0, <3.0.0"]
dependencies = ["haystack-ai>=2.15.1", "langfuse>=3.0.0, <4.0.0"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/langfuse#readme"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,21 @@
import os
from abc import ABC, abstractmethod
from collections import Counter
from contextlib import AbstractContextManager
from contextvars import ContextVar
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional

from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage
from haystack.tracing import Span, Tracer
from haystack.tracing import tracer as proxy_tracer
from haystack.tracing import utils as tracing_utils
from typing_extensions import TypeAlias

import langfuse
from langfuse.client import StatefulGenerationClient, StatefulSpanClient, StatefulTraceClient

# Type alias for Langfuse stateful clients
LangfuseStatefulClient: TypeAlias = Union[StatefulTraceClient, StatefulSpanClient, StatefulGenerationClient]

from langfuse import LangfuseSpan as LangfuseClientSpan
from langfuse.types import TraceMetadata

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,15 +72,17 @@ class LangfuseSpan(Span):
Internal class representing a bridge between the Haystack span tracing API and Langfuse.
"""

def __init__(self, span: LangfuseStatefulClient) -> None:
def __init__(self, context_manager: AbstractContextManager) -> None:
"""
Initialize a LangfuseSpan instance.

:param span: The span instance managed by Langfuse.
:param context_manager: The context manager from Langfuse created with
`langfuse.get_client().start_as_current_span` or
`langfuse.get_client().start_as_current_observation`.
"""
self._span = span
# locally cache tags
self._span = context_manager.__enter__()
self._data: Dict[str, Any] = {}
self._context_manager = context_manager

def set_tag(self, key: str, value: Any) -> None:
"""
Expand Down Expand Up @@ -125,7 +124,7 @@ def set_content_tag(self, key: str, value: Any) -> None:

self._data[key] = value

def raw_span(self) -> LangfuseStatefulClient:
def raw_span(self) -> LangfuseClientSpan:
"""
Return the underlying span instance.

Expand Down Expand Up @@ -273,30 +272,50 @@ def create_span(self, context: SpanContext) -> LangfuseSpan:
tracing_ctx = tracing_context_var.get({})
if not context.parent_span:
# Create a new trace when there's no parent span
return LangfuseSpan(
self.tracer.trace(
name=context.trace_name,
public=context.public,
id=tracing_ctx.get("trace_id"),
user_id=tracing_ctx.get("user_id"),
session_id=tracing_ctx.get("session_id"),
tags=tracing_ctx.get("tags"),
version=tracing_ctx.get("version"),
)
span_context_manager = self.tracer.start_as_current_span(
name=context.trace_name,
version=tracing_ctx.get("version"),
)

# Create LangfuseSpan which will handle entering the context manager
span = LangfuseSpan(span_context_manager)

# Build trace metadata from context
trace_metadata: TraceMetadata = {
"name": context.trace_name,
"user_id": tracing_ctx.get("user_id"),
"session_id": tracing_ctx.get("session_id"),
"version": tracing_ctx.get("version"),
"metadata": None,
"tags": tracing_ctx.get("tags"),
"public": context.public,
}

# Filter out None values and apply trace attributes
trace_attrs = {k: v for k, v in trace_metadata.items() if v is not None}
if trace_attrs:
span._span.update_trace(**trace_attrs)

return span
elif context.component_type in _ALL_SUPPORTED_GENERATORS:
return LangfuseSpan(context.parent_span.raw_span().generation(name=context.name))
return LangfuseSpan(self.tracer.start_as_current_observation(name=context.name, as_type="generation"))
else:
return LangfuseSpan(context.parent_span.raw_span().span(name=context.name))
return LangfuseSpan(self.tracer.start_as_current_span(name=context.name))

def handle(self, span: LangfuseSpan, component_type: Optional[str]) -> None:
# Apply trace attributes if they were stored during span creation
trace_attrs = span.get_data().get("_trace_attrs")
if trace_attrs:
# We need to get the actual span from the context manager
# For now, we'll skip this as the context manager needs to be entered
pass

# If the span is at the pipeline level, we add input and output keys to the span
at_pipeline_level = span.get_data().get(_PIPELINE_INPUT_KEY) is not None
if at_pipeline_level:
coerced_input = tracing_utils.coerce_tag_value(span.get_data().get(_PIPELINE_INPUT_KEY))
coerced_output = tracing_utils.coerce_tag_value(span.get_data().get(_PIPELINE_OUTPUT_KEY))
span.raw_span().update(input=coerced_input, output=coerced_output)

span.raw_span().update_trace(input=coerced_input, output=coerced_output)
# special case for ToolInvoker (to update the span name to be: `original_component_name - [tool_names]`)
if component_type == "ToolInvoker":
tool_names: List[str] = []
Expand Down Expand Up @@ -415,7 +434,11 @@ def trace(

# End span (may fail if span data is corrupted)
raw_span = span.raw_span()
if isinstance(raw_span, (StatefulSpanClient, StatefulGenerationClient)):
# In v3, we need to properly exit context managers
if span._context_manager is not None:
span._context_manager.__exit__(None, None, None)
elif hasattr(raw_span, "end"):
# Only call end() if it's not a context manager
raw_span.end()
except Exception as cleanup_error:
# Log cleanup errors but don't let them corrupt context
Expand Down Expand Up @@ -456,4 +479,4 @@ def get_trace_id(self) -> str:
Return the trace ID.
:return: The trace ID.
"""
return self._tracer.get_trace_id()
return self._tracer.get_current_observation_id()
Loading
Loading