Skip to content

Commit b9ed01d

Browse files
committed
Refactor
1 parent 1b82155 commit b9ed01d

File tree

1 file changed

+29
-27
lines changed

1 file changed

+29
-27
lines changed

backend/app/crud/langfusexp.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,31 @@
1010

1111
warnings.filterwarnings("ignore")
1212

13+
1314
class LangfuseExperiment:
1415
def __init__(self, db: Session, org_id: str, project_id: str = None):
    """Set up the experiment runner.

    Stores the tenant context, loads the sentence-embedding model, and
    connects to Langfuse with credentials pulled from the database.

    Args:
        db: Active SQLAlchemy session used for credential lookups.
        org_id: Organization whose Langfuse credentials to use.
        project_id: Optional project scope for the credential lookup.
    """
    # Tenant context must be set before _initialize_langfuse reads it.
    self.db, self.org_id, self.project_id = db, org_id, project_id
    # all-mpnet-base-v2: general-purpose sentence-embedding model.
    self.embedding_model = SentenceTransformer("all-mpnet-base-v2")
    self.langfuse = self._initialize_langfuse()
2122
def _initialize_langfuse(self) -> Langfuse:
    """Build a Langfuse client from credentials stored in the database.

    Returns:
        A configured ``Langfuse`` client.

    Raises:
        ValueError: If no Langfuse credentials are stored for this
            organization/project.
    """
    creds = get_provider_credential(
        session=self.db,
        org_id=self.org_id,
        provider="langfuse",
        project_id=self.project_id,
    )

    # Guard clause: fail loudly rather than passing None keys downstream.
    if not creds:
        raise ValueError("Langfuse credentials not found in database")

    client_kwargs = {
        "public_key": creds["public_key"],
        "secret_key": creds["secret_key"],
        "host": creds["host"],
    }
    return Langfuse(**client_kwargs)
3839

3940
def fetch_traces(self, pages_to_fetch: int = 2) -> List[Dict[str, Any]]:
@@ -49,55 +50,58 @@ def prepare_traces_dataframe(self, traces: List[Dict[str, Any]]) -> pd.DataFrame
4950
traces_list = [[trace.id, trace.input] for trace in traces]
5051
df = pd.DataFrame(traces_list, columns=["trace_id", "message"])
5152
df.dropna(inplace=True)
52-
53+
5354
# Filter out system messages
5455
df = df[~df["message"].isin(["setup_thread", "validate_thread", "RunID"])]
5556
df = df.reset_index(drop=True)
5657
return df
5758

58-
def generate_embeddings(
    self, df: pd.DataFrame, batch_size: int = 512
) -> pd.DataFrame:
    """Attach sentence embeddings to ``df`` under an ``embeddings`` column.

    Messages are encoded in chunks of ``batch_size`` to bound memory use,
    with a tqdm progress bar reporting one tick per batch.

    Args:
        df: DataFrame with a ``message`` column of raw text.
        batch_size: Number of messages encoded per model call.

    Returns:
        The same DataFrame, mutated in place, with ``embeddings`` added.
    """
    texts = df["message"].tolist()
    vectors = []

    for start in tqdm(range(0, len(texts), batch_size), desc="Encoding batches"):
        chunk = texts[start : start + batch_size]
        vectors.extend(self.embedding_model.encode(chunk))

    df["embeddings"] = vectors
    return df
7073

7174
def cluster_traces(self, df: pd.DataFrame) -> pd.DataFrame:
    """Assign an HDBSCAN cluster id to every row (``-1`` marks noise).

    Args:
        df: DataFrame with an ``embeddings`` column of vectors.

    Returns:
        The same DataFrame with a ``cluster`` column added.
    """
    # Lazy import keeps the heavy dependency off the module import path.
    import hdbscan

    model = hdbscan.HDBSCAN(min_cluster_size=4)
    df["cluster"] = model.fit_predict(df["embeddings"].to_list())
    return df
7781

7882
def generate_cluster_labels(self, df: pd.DataFrame) -> pd.DataFrame:
    """Add a ``cluster_label`` column naming each cluster's intent.

    For every non-noise cluster, up to 50 member messages are sampled
    and passed to ``self._generate_label`` to produce a label; the label
    is then written onto every row of that cluster. Rows in the noise
    cluster (-1) keep a NaN label.

    Args:
        df: DataFrame with ``message`` and ``cluster`` columns.

    Returns:
        The same DataFrame with ``cluster_label`` filled in.
    """
    # Fix: removed an unused local ``import openai`` — the OpenAI call
    # happens inside ``_generate_label``, which does its own import.
    for cluster in df["cluster"].unique():
        # -1 is the HDBSCAN noise bucket, not a real cluster — skip it.
        if cluster == -1:
            continue

        messages_in_cluster = df[df["cluster"] == cluster]["message"]

        # Cap prompt size: at most 50 sampled messages per cluster.
        if len(messages_in_cluster) > 50:
            messages_in_cluster = messages_in_cluster.sample(50)

        label = self._generate_label(messages_in_cluster)
        df.loc[df["cluster"] == cluster, "cluster_label"] = label

    return df
96100

97101
def _generate_label(self, message_group: pd.Series) -> str:
98102
"""Generate a label for a group of messages using OpenAI."""
99103
import openai
100-
104+
101105
prompt = f"""
102106
# Task
103107
Your goal is to assign an intent label that most accurately fits the given group of utterances.
@@ -114,7 +118,7 @@ def _generate_label(self, message_group: pd.Series) -> str:
114118
Utterances: {message_group}
115119
Label:
116120
"""
117-
121+
118122
response = openai.chat.completions.create(
119123
model="gpt-4",
120124
messages=[{"role": "user", "content": prompt}],
@@ -126,33 +130,31 @@ def update_langfuse_traces(self, df: pd.DataFrame) -> None:
126130
"""Update traces in Langfuse with cluster labels."""
127131
for _, row in df.iterrows():
128132
if row["cluster"] != -1:
129-
self.langfuse.trace(
130-
id=row["trace_id"],
131-
tags=[row["cluster_label"]]
132-
)
133+
self.langfuse.trace(id=row["trace_id"], tags=[row["cluster_label"]])
133134

134135
def run_experiment(self, pages_to_fetch: int = 2) -> pd.DataFrame:
    """Run the full pipeline: fetch, prepare, embed, cluster, label,
    then push the labels back to Langfuse.

    Args:
        pages_to_fetch: Number of trace pages to pull from Langfuse.

    Returns:
        The final labelled DataFrame.
    """
    raw_traces = self.fetch_traces(pages_to_fetch)
    df = self.prepare_traces_dataframe(raw_traces)

    # Each stage consumes and returns the working DataFrame.
    for stage in (
        self.generate_embeddings,
        self.cluster_traces,
        self.generate_cluster_labels,
    ):
        df = stage(df)

    # Side effect: write the cluster labels back to Langfuse as tags.
    self.update_langfuse_traces(df)

    return df
155156

157+
156158
# Example usage:
157159
# experiment = LangfuseExperiment(db=session, org_id="org_123", project_id="proj_456")
158160
# results_df = experiment.run_experiment()

0 commit comments

Comments (0)