Skip to content

Commit 698e37c

Browse files
authored
fix: token usage calculation (#88)
1 parent 5bd7bbd commit 698e37c

22 files changed

+621
-360
lines changed

python/AGENTS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ This repo uses a unified, deterministic testing infrastructure to keep tests fas
1212
- Unit client uses `mock_agent_factory` and `mock_vector_db`.
1313
- Integration client injects a real `RagPipeline` wired to `mock_query_processor` + `mock_vector_db` (via the same `mock_agent_factory`).
1414
- Replace ad‑hoc stubs with shared fixtures: `sample_processed_query`, `mock_query_processor`, `sample_documents`, and `mock_returned_documents` (built from `sample_documents`).
15+
- Respect declared types. When a signature declares an argument as type `T`, never guard it with `is None` or `hasattr` checks for `T`'s own surface area—just call the method and let the type checker surface bugs. (Example: if something is typed `dspy.Prediction`, call `get_lm_usage()` directly and set usage via `set_lm_usage()`. Don't write code that assumes these methods might be missing.)
1516

1617
## DSPy/LLM Behavior
1718

python/src/cairo_coder/agents/registry.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
agent system with a simple, in-memory registry of available agents.
66
"""
77

8-
from dataclasses import dataclass
8+
from collections.abc import Callable
9+
from dataclasses import dataclass, field
910
from enum import Enum
11+
from typing import Any
1012

1113
from cairo_coder.core.config import VectorStoreConfig
1214
from cairo_coder.core.rag_pipeline import RagPipeline, RagPipelineFactory
@@ -33,7 +35,8 @@ class AgentSpec:
3335
name: str
3436
description: str
3537
sources: list[DocumentSource]
36-
generation_program_type: AgentId
38+
pipeline_builder: Callable[..., RagPipeline]
39+
builder_kwargs: dict[str, Any] = field(default_factory=dict)
3740
max_source_count: int = 5
3841
similarity_threshold: float = 0.4
3942

@@ -48,31 +51,15 @@ def build(self, vector_db: SourceFilteredPgVectorRM, vector_store_config: Vector
4851
Returns:
4952
Configured RagPipeline instance
5053
"""
51-
match self.generation_program_type:
52-
case AgentId.STARKNET:
53-
return RagPipelineFactory.create_pipeline(
54-
name=self.name,
55-
vector_store_config=vector_store_config,
56-
sources=self.sources,
57-
query_processor=create_query_processor(),
58-
generation_program=create_generation_program(AgentId.STARKNET),
59-
mcp_generation_program=create_mcp_generation_program(),
60-
max_source_count=self.max_source_count,
61-
similarity_threshold=self.similarity_threshold,
62-
vector_db=vector_db,
63-
)
64-
case AgentId.CAIRO_CODER:
65-
return RagPipelineFactory.create_pipeline(
66-
name=self.name,
67-
vector_store_config=vector_store_config,
68-
sources=self.sources,
69-
query_processor=create_query_processor(),
70-
generation_program=create_generation_program(AgentId.CAIRO_CODER),
71-
mcp_generation_program=create_mcp_generation_program(),
72-
max_source_count=self.max_source_count,
73-
similarity_threshold=self.similarity_threshold,
74-
vector_db=vector_db,
75-
)
54+
return self.pipeline_builder(
55+
name=self.name,
56+
vector_store_config=vector_store_config,
57+
vector_db=vector_db,
58+
sources=self.sources,
59+
max_source_count=self.max_source_count,
60+
similarity_threshold=self.similarity_threshold,
61+
**self.builder_kwargs,
62+
)
7663

7764

7865
# The global registry of available agents
@@ -81,15 +68,25 @@ def build(self, vector_db: SourceFilteredPgVectorRM, vector_store_config: Vector
8168
name="Cairo Coder",
8269
description="General Cairo programming assistant",
8370
sources=list(DocumentSource), # All sources
84-
generation_program_type=AgentId.CAIRO_CODER,
71+
pipeline_builder=RagPipelineFactory.create_pipeline,
72+
builder_kwargs={
73+
"query_processor": create_query_processor(),
74+
"generation_program": create_generation_program(AgentId.CAIRO_CODER),
75+
"mcp_generation_program": create_mcp_generation_program(),
76+
},
8577
max_source_count=5,
8678
similarity_threshold=0.4,
8779
),
8880
AgentId.STARKNET: AgentSpec(
8981
name="Starknet Agent",
9082
description="Assistant for the Starknet ecosystem (contracts, tools, docs).",
9183
sources=list(DocumentSource),
92-
generation_program_type=AgentId.STARKNET,
84+
pipeline_builder=RagPipelineFactory.create_pipeline,
85+
builder_kwargs={
86+
"query_processor": create_query_processor(),
87+
"generation_program": create_generation_program(AgentId.STARKNET),
88+
"mcp_generation_program": create_mcp_generation_program(),
89+
},
9390
max_source_count=5,
9491
similarity_threshold=0.4,
9592
),

python/src/cairo_coder/core/rag_pipeline.py

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
from cairo_coder.core.types import (
2020
Document,
2121
DocumentSource,
22+
FormattedSource,
2223
Message,
2324
ProcessedQuery,
2425
StreamEvent,
2526
StreamEventType,
27+
combine_usage,
2628
title_from_url,
2729
)
2830
from cairo_coder.dspy.document_retriever import DocumentRetrieverProgram
@@ -82,33 +84,57 @@ def __init__(self, config: RagPipelineConfig):
8284
self._current_processed_query: ProcessedQuery | None = None
8385
self._current_documents: list[Document] = []
8486

87+
# Token usage accumulator
88+
self._accumulated_usage: dict[str, dict[str, int]] = {}
89+
8590
@property
8691
def last_retrieved_documents(self) -> list[Document]:
8792
"""Documents retrieved during the most recent pipeline execution."""
8893
return self._current_documents
8994

95+
def _accumulate_usage(self, prediction: dspy.Prediction) -> None:
    """
    Fold the token usage reported by a prediction into the running total.

    The usage is read with ``prediction.get_lm_usage()`` and merged into
    ``self._accumulated_usage`` through ``combine_usage``; the merged
    mapping replaces the previous accumulator.

    Args:
        prediction: DSPy prediction whose LM usage should be counted
    """
    self._accumulated_usage = combine_usage(
        self._accumulated_usage,
        prediction.get_lm_usage(),
    )
104+
105+
def _reset_usage(self) -> None:
106+
"""Reset accumulated usage for a new request."""
107+
self._accumulated_usage = {}
108+
90109
async def _aprocess_query_and_retrieve_docs(
91110
self,
92111
query: str,
93112
chat_history_str: str,
94113
sources: list[DocumentSource] | None = None,
95114
) -> tuple[ProcessedQuery, list[Document]]:
96115
"""Process query and retrieve documents - shared async logic."""
97-
processed_query = await self.query_processor.aforward(
116+
qp_prediction = await self.query_processor.aforward(
98117
query=query, chat_history=chat_history_str
99118
)
119+
self._accumulate_usage(qp_prediction)
120+
processed_query = qp_prediction.processed_query
100121
self._current_processed_query = processed_query
101122

102123
# Use provided sources or fall back to processed query sources
103124
retrieval_sources = sources or processed_query.resources
104-
documents = await self.document_retriever.aforward(
125+
dr_prediction = await self.document_retriever.aforward(
105126
processed_query=processed_query, sources=retrieval_sources
106127
)
128+
self._accumulate_usage(dr_prediction)
129+
documents = dr_prediction.documents
107130

108131
# Optional Grok web/X augmentation: activate when STARKNET_BLOG is among sources.
109132
try:
110133
if DocumentSource.STARKNET_BLOG in retrieval_sources:
111-
grok_docs = await self.grok_search.aforward(processed_query, chat_history_str)
134+
grok_pred = await self.grok_search.aforward(processed_query, chat_history_str)
135+
self._accumulate_usage(grok_pred)
136+
grok_docs = grok_pred.documents
137+
112138
self._grok_citations = list(self.grok_search.last_citations)
113139
if grok_docs:
114140
documents.extend(grok_docs)
@@ -126,7 +152,9 @@ async def _aprocess_query_and_retrieve_docs(
126152
lm=dspy.LM("gemini/gemini-flash-lite-latest", max_tokens=10000, temperature=0.5),
127153
adapter=XMLAdapter(),
128154
):
129-
documents = await self.retrieval_judge.aforward(query=query, documents=documents)
155+
judge_pred = await self.retrieval_judge.aforward(query=query, documents=documents)
156+
self._accumulate_usage(judge_pred)
157+
documents = judge_pred.documents
130158
except Exception as e:
131159
logger.warning(
132160
"Retrieval judge failed (async), using all documents",
@@ -158,6 +186,9 @@ async def aforward(
158186
mcp_mode: bool = False,
159187
sources: list[DocumentSource] | None = None,
160188
) -> dspy.Prediction:
189+
# Reset usage for this request
190+
self._reset_usage()
191+
161192
chat_history_str = self._format_chat_history(chat_history or [])
162193
processed_query, documents = await self._aprocess_query_and_retrieve_docs(
163194
query, chat_history_str, sources
@@ -167,13 +198,21 @@ async def aforward(
167198
)
168199

169200
if mcp_mode:
170-
return await self.mcp_generation_program.aforward(documents)
201+
result = await self.mcp_generation_program.aforward(documents)
202+
self._accumulate_usage(result)
203+
result.set_lm_usage(self._accumulated_usage)
204+
return result
171205

172206
context = self._prepare_context(documents)
173207

174-
return await self.generation_program.aforward(
208+
result = await self.generation_program.aforward(
175209
query=query, context=context, chat_history=chat_history_str
176210
)
211+
if result:
212+
self._accumulate_usage(result)
213+
# Update the result's usage to include accumulated usage from previous steps
214+
result.set_lm_usage(self._accumulated_usage)
215+
return result
177216

178217

179218
async def aforward_streaming(
@@ -251,6 +290,7 @@ async def aforward_streaming(
251290
logger.warning(f"Unknown signature field name: {chunk.signature_field_name}")
252291
elif isinstance(chunk, dspy.Prediction):
253292
# Final complete answer
293+
self._accumulate_usage(chunk)
254294
final_text = getattr(chunk, "answer", None) or chunk_accumulator
255295
yield StreamEvent(type=StreamEventType.FINAL_RESPONSE, data=final_text)
256296
rt.end(outputs={"output": final_text})
@@ -268,28 +308,12 @@ async def aforward_streaming(
268308

269309
def get_lm_usage(self) -> dict[str, dict[str, int]]:
270310
"""
271-
Get the total number of tokens used by the LLM.
272-
"""
273-
generation_usage = self.generation_program.get_lm_usage()
274-
query_usage = self.query_processor.get_lm_usage()
275-
judge_usage = self.retrieval_judge.get_lm_usage()
276-
277-
# Additive merge strategy
278-
merged_usage = {}
279-
280-
# Helper function to merge usage dictionaries
281-
def merge_usage_dict(target: dict, source: dict) -> None:
282-
for model_name, metrics in source.items():
283-
if model_name not in target:
284-
target[model_name] = {}
285-
for metric_name, value in metrics.items():
286-
target[model_name][metric_name] = target[model_name].get(metric_name, 0) + value
311+
Get accumulated token usage from all predictions in the pipeline.
287312
288-
merge_usage_dict(merged_usage, generation_usage)
289-
merge_usage_dict(merged_usage, query_usage)
290-
merge_usage_dict(merged_usage, judge_usage)
291-
292-
return merged_usage
313+
Returns:
314+
Dictionary mapping model names to usage metrics
315+
"""
316+
return self._accumulated_usage
293317

294318
def _format_chat_history(self, chat_history: list[Message]) -> str:
295319
"""
@@ -311,7 +335,7 @@ def _format_chat_history(self, chat_history: list[Message]) -> str:
311335

312336
return "\n".join(formatted_messages)
313337

314-
def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
338+
def _format_sources(self, documents: list[Document]) -> list[FormattedSource]:
315339
"""
316340
Format documents for the frontend-friendly sources event.
317341
@@ -322,9 +346,9 @@ def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
322346
documents: List of retrieved documents
323347
324348
Returns:
325-
List of dicts: [{"title": str, "url": str}, ...]
349+
List of formatted sources with metadata
326350
"""
327-
sources: list[dict[str, str]] = []
351+
sources: list[FormattedSource] = []
328352
seen_urls: set[str] = set()
329353

330354

python/src/cairo_coder/core/types.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,29 @@ class ProcessedQuery:
7474
is_test_related: bool = False
7575
resources: list[DocumentSource] = field(default_factory=list)
7676

77+
LMUsageEntry = dict[str, Any]
78+
LMUsage = dict[str, LMUsageEntry]
79+
80+
81+
class RetrievedSourceData(TypedDict):
    """Shape of a single retrieved-source record as persisted in the database."""

    # Content of the retrieved page/document, as a string.
    page_content: str
    # Metadata accompanying the content, typed as ``DocumentMetadata``.
    metadata: DocumentMetadata
86+
87+
88+
class FormattedSourceMetadata(TypedDict):
    """Metadata carried by each formatted source that is sent to the frontend."""

    # All three entries are plain strings.
    title: str
    url: str
    source_type: str
94+
95+
96+
class FormattedSource(TypedDict):
    """One formatted source entry as sent to the frontend."""

    # The entry's only key; its value follows ``FormattedSourceMetadata``.
    metadata: FormattedSourceMetadata
77100

78101
# Helper to extract domain title
79102
def title_from_url(url: str) -> str:
@@ -174,6 +197,33 @@ def to_dict(self) -> dict[str, Any]:
174197
"details": self.details,
175198
"timestamp": self.timestamp.isoformat(),
176199
}
200+
201+
202+
def combine_usage(usage1: LMUsage, usage2: LMUsage) -> LMUsage:
203+
"""Combine two LM usage dictionaries, tolerating missing inputs."""
204+
result: LMUsage = {model: (metrics or {}).copy() for model, metrics in usage1.items()}
205+
206+
for model, metrics in usage2.items():
207+
if model not in result:
208+
result[model] = metrics.copy()
209+
else:
210+
# Merge metrics
211+
for key, value in metrics.items():
212+
if isinstance(value, int | float):
213+
result[model][key] = result[model].get(key, 0) + value
214+
elif isinstance(value, dict):
215+
if key not in result[model] or result[model][key] is None:
216+
result[model][key] = value.copy()
217+
else:
218+
# Recursive merge for nested dicts
219+
for detail_key, detail_value in value.items():
220+
if isinstance(detail_value, int | float):
221+
result[model][key][detail_key] = (
222+
result[model][key].get(detail_key, 0) + detail_value
223+
)
224+
return result
225+
226+
177227
class AgentResponse(BaseModel):
178228
"""Response from agent processing."""
179229

python/src/cairo_coder/db/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from pydantic import BaseModel, Field
1212

13+
from cairo_coder.core.types import RetrievedSourceData
14+
1315

1416
class UserInteraction(BaseModel):
1517
"""Represents a record in the user_interactions table."""
@@ -21,5 +23,5 @@ class UserInteraction(BaseModel):
2123
chat_history: Optional[list[dict[str, Any]]] = None
2224
query: str
2325
generated_answer: Optional[str] = None
24-
retrieved_sources: Optional[list[dict[str, Any]]] = None
26+
retrieved_sources: Optional[list[RetrievedSourceData]] = None
2527
llm_usage: Optional[dict[str, Any]] = None

0 commit comments

Comments
 (0)