Skip to content

Commit 1b82155

Browse files
committed
Update embedding and clustering logic with batch processing and OpenAI integration for label generation.
1 parent 923f21f commit 1b82155

File tree

1 file changed

+153
-135
lines changed

1 file changed

+153
-135
lines changed

backend/app/crud/langfusexp.py

Lines changed: 153 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -1,140 +1,158 @@
11
import warnings
2-
3-
warnings.filterwarnings("ignore")
4-
5-
from langfuse import Langfuse
6-
7-
from sentence_transformers import SentenceTransformer
8-
9-
embedding_model = SentenceTransformer("all-mpnet-base-v2")
10-
11-
12-
import os
13-
14-
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-xx"
15-
os.environ["LANGFUSE_SECRET_KEY"] = "sk-xx"
16-
os.environ["LANGFUSE_HOST"] = "https://cloud.langfuse.com"
17-
os.environ["OPENAI_API_KEY"] = "sk-proj-Yxxx"
18-
os.environ["TOKENIZERS_PARALLELISM"] = "true"
19-
20-
langfuse = Langfuse()
21-
22-
23-
PAGES_TO_FETCH = 2
24-
25-
traces = []
26-
for i in range(PAGES_TO_FETCH):
27-
traces_page = langfuse.fetch_traces(page=i + 1)
28-
traces.extend(traces_page.data)
29-
30-
traces_list = []
31-
for trace in traces:
32-
trace_info = [trace.id, trace.input]
33-
traces_list.append(trace_info)
34-
2+
from typing import List, Dict, Any
353
import pandas as pd
36-
37-
cluster_traces_df = pd.DataFrame(traces_list, columns=["trace_id", "message"])
38-
cluster_traces_df.dropna(inplace=True) # drop traces with message = None
39-
40-
# keep only rows whose message is NOT in bad
41-
cluster_traces_df = cluster_traces_df[
42-
~cluster_traces_df["message"].isin(["setup_thread", "validate_thread", "RunID"])
43-
]
44-
45-
# (optional) reset the index if you don’t care about preserving the old one
46-
cluster_traces_df = cluster_traces_df.reset_index(drop=True)
47-
48-
# naive implementation (batch=1)
49-
cluster_traces_df["embeddings"] = cluster_traces_df["message"].map(
50-
embedding_model.encode
51-
)
52-
53-
# use batches to speed up embedding
544
from tqdm import tqdm
5+
from sentence_transformers import SentenceTransformer
6+
from langfuse import Langfuse
7+
from sqlmodel import Session
8+
from app.crud.credentials import get_provider_credential
9+
from app.core import settings
5510

56-
batch_size = 512 # Choose an appropriate batch size based on your model and hardware capabilities
57-
messages = cluster_traces_df["message"].tolist()
58-
embeddings = []
59-
60-
# Use tqdm to wrap your range function for the progress bar
61-
for i in tqdm(range(0, len(messages), batch_size), desc="Encoding batches"):
62-
batch = messages[i : i + batch_size]
63-
batch_embeddings = embedding_model.encode(batch)
64-
embeddings.extend(batch_embeddings)
65-
66-
cluster_traces_df["embeddings"] = embeddings
67-
68-
69-
import hdbscan
70-
71-
clusterer = hdbscan.HDBSCAN(min_cluster_size=4)
72-
cluster_traces_df["cluster"] = clusterer.fit_predict(
73-
cluster_traces_df["embeddings"].to_list()
74-
)
75-
76-
cluster_traces_df["cluster"].value_counts().head(2).to_dict()
77-
78-
79-
import openai
80-
81-
# Note: Depending on the volume of data you are running,
82-
# you may want to limit the number of utterances representing each group (ex. utterances_group[:5])
83-
84-
85-
def generate_label(message_group):
86-
prompt = f"""
87-
# Task
88-
Your goal is to assign an intent label that most accurately fits the given group of utterances.
89-
You will only provide a single label, no explanation. The label should be snake cased.
90-
91-
## Example utterances
92-
so long
93-
bye
94-
95-
## Example labels
96-
goodbye
97-
end_conversation
98-
99-
Utterances: {message_group}
100-
Label:
101-
"""
102-
response = openai.chat.completions.create(
103-
model="gpt-4o-mini",
104-
messages=[{"role": "user", "content": prompt}],
105-
max_tokens=50,
106-
)
107-
return response.choices[0].message.content.strip()
108-
109-
110-
print(cluster_traces_df)
111-
for cluster in cluster_traces_df["cluster"].unique():
112-
if cluster == -1:
113-
continue
114-
messages_in_cluster = cluster_traces_df[cluster_traces_df["cluster"] == cluster][
115-
"message"
116-
]
117-
118-
# sample if too many messages
119-
if len(messages_in_cluster) > 50:
120-
messages_in_cluster = messages_in_cluster.sample(50)
121-
122-
label = generate_label(messages_in_cluster)
123-
cluster_traces_df.loc[
124-
cluster_traces_df["cluster"] == cluster, "cluster_label"
125-
] = label
126-
127-
128-
cluster_traces_df["cluster_label"].value_counts().head(20).to_dict()
129-
130-
# explore the messages sent within a specific cluster
131-
cluster_traces_df[
132-
cluster_traces_df["cluster_label"] == "trace_in_langfuse"
133-
].message.head(20).to_dict()
11+
warnings.filterwarnings("ignore")
13412

135-
# add as labels back to langfuse
136-
for index, row in cluster_traces_df.iterrows():
137-
if row["cluster"] != -1:
138-
trace_id = row["trace_id"]
139-
label = row["cluster_label"]
140-
langfuse.trace(id=trace_id, tags=[label])
13+
class LangfuseExperiment:
14+
def __init__(self, db: Session, org_id: str, project_id: str = None):
15+
self.db = db
16+
self.org_id = org_id
17+
self.project_id = project_id
18+
self.embedding_model = SentenceTransformer("all-mpnet-base-v2")
19+
self.langfuse = self._initialize_langfuse()
20+
21+
def _initialize_langfuse(self) -> Langfuse:
22+
"""Initialize Langfuse client with credentials from database."""
23+
credentials = get_provider_credential(
24+
session=self.db,
25+
org_id=self.org_id,
26+
provider="langfuse",
27+
project_id=self.project_id
28+
)
29+
30+
if not credentials:
31+
raise ValueError("Langfuse credentials not found in database")
32+
33+
return Langfuse(
34+
public_key=credentials["public_key"],
35+
secret_key=credentials["secret_key"],
36+
host=credentials["host"]
37+
)
38+
39+
def fetch_traces(self, pages_to_fetch: int = 2) -> List[Dict[str, Any]]:
40+
"""Fetch traces from Langfuse."""
41+
traces = []
42+
for i in range(pages_to_fetch):
43+
traces_page = self.langfuse.fetch_traces(page=i + 1)
44+
traces.extend(traces_page.data)
45+
return traces
46+
47+
def prepare_traces_dataframe(self, traces: List[Dict[str, Any]]) -> pd.DataFrame:
48+
"""Convert traces to DataFrame and clean data."""
49+
traces_list = [[trace.id, trace.input] for trace in traces]
50+
df = pd.DataFrame(traces_list, columns=["trace_id", "message"])
51+
df.dropna(inplace=True)
52+
53+
# Filter out system messages
54+
df = df[~df["message"].isin(["setup_thread", "validate_thread", "RunID"])]
55+
df = df.reset_index(drop=True)
56+
return df
57+
58+
def generate_embeddings(self, df: pd.DataFrame, batch_size: int = 512) -> pd.DataFrame:
59+
"""Generate embeddings for messages in batches."""
60+
messages = df["message"].tolist()
61+
embeddings = []
62+
63+
for i in tqdm(range(0, len(messages), batch_size), desc="Encoding batches"):
64+
batch = messages[i:i + batch_size]
65+
batch_embeddings = self.embedding_model.encode(batch)
66+
embeddings.extend(batch_embeddings)
67+
68+
df["embeddings"] = embeddings
69+
return df
70+
71+
def cluster_traces(self, df: pd.DataFrame) -> pd.DataFrame:
72+
"""Cluster traces using HDBSCAN."""
73+
import hdbscan
74+
clusterer = hdbscan.HDBSCAN(min_cluster_size=4)
75+
df["cluster"] = clusterer.fit_predict(df["embeddings"].to_list())
76+
return df
77+
78+
def generate_cluster_labels(self, df: pd.DataFrame) -> pd.DataFrame:
79+
"""Generate labels for clusters using OpenAI."""
80+
import openai
81+
82+
for cluster in df["cluster"].unique():
83+
if cluster == -1:
84+
continue
85+
86+
messages_in_cluster = df[df["cluster"] == cluster]["message"]
87+
88+
# Sample if too many messages
89+
if len(messages_in_cluster) > 50:
90+
messages_in_cluster = messages_in_cluster.sample(50)
91+
92+
label = self._generate_label(messages_in_cluster)
93+
df.loc[df["cluster"] == cluster, "cluster_label"] = label
94+
95+
return df
96+
97+
def _generate_label(self, message_group: pd.Series) -> str:
98+
"""Generate a label for a group of messages using OpenAI."""
99+
import openai
100+
101+
prompt = f"""
102+
# Task
103+
Your goal is to assign an intent label that most accurately fits the given group of utterances.
104+
You will only provide a single label, no explanation. The label should be snake cased.
105+
106+
## Example utterances
107+
so long
108+
bye
109+
110+
## Example labels
111+
goodbye
112+
end_conversation
113+
114+
Utterances: {message_group}
115+
Label:
116+
"""
117+
118+
response = openai.chat.completions.create(
119+
model="gpt-4",
120+
messages=[{"role": "user", "content": prompt}],
121+
max_tokens=50,
122+
)
123+
return response.choices[0].message.content.strip()
124+
125+
def update_langfuse_traces(self, df: pd.DataFrame) -> None:
126+
"""Update traces in Langfuse with cluster labels."""
127+
for _, row in df.iterrows():
128+
if row["cluster"] != -1:
129+
self.langfuse.trace(
130+
id=row["trace_id"],
131+
tags=[row["cluster_label"]]
132+
)
133+
134+
def run_experiment(self, pages_to_fetch: int = 2) -> pd.DataFrame:
135+
"""Run the complete experiment pipeline."""
136+
# Fetch traces
137+
traces = self.fetch_traces(pages_to_fetch)
138+
139+
# Prepare DataFrame
140+
df = self.prepare_traces_dataframe(traces)
141+
142+
# Generate embeddings
143+
df = self.generate_embeddings(df)
144+
145+
# Cluster traces
146+
df = self.cluster_traces(df)
147+
148+
# Generate labels
149+
df = self.generate_cluster_labels(df)
150+
151+
# Update Langfuse
152+
self.update_langfuse_traces(df)
153+
154+
return df
155+
156+
# Example usage:
157+
# experiment = LangfuseExperiment(db=session, org_id="org_123", project_id="proj_456")
158+
# results_df = experiment.run_experiment()

0 commit comments

Comments
 (0)