1919from cairo_coder .core .types import (
2020 Document ,
2121 DocumentSource ,
22+ FormattedSource ,
2223 Message ,
2324 ProcessedQuery ,
2425 StreamEvent ,
2526 StreamEventType ,
27+ combine_usage ,
2628 title_from_url ,
2729)
2830from cairo_coder .dspy .document_retriever import DocumentRetrieverProgram
@@ -82,33 +84,57 @@ def __init__(self, config: RagPipelineConfig):
8284 self ._current_processed_query : ProcessedQuery | None = None
8385 self ._current_documents : list [Document ] = []
8486
87+ # Token usage accumulator
88+ self ._accumulated_usage : dict [str , dict [str , int ]] = {}
89+
8590 @property
8691 def last_retrieved_documents (self ) -> list [Document ]:
8792 """Documents retrieved during the most recent pipeline execution."""
8893 return self ._current_documents
8994
95+ def _accumulate_usage (self , prediction : dspy .Prediction ) -> None :
96+ """
97+ Accumulate token usage from a prediction.
98+
99+ Args:
100+ prediction: DSPy prediction object with usage information
101+ """
102+ usage = prediction .get_lm_usage ()
103+ self ._accumulated_usage = combine_usage (self ._accumulated_usage , usage )
104+
105+ def _reset_usage (self ) -> None :
106+ """Reset accumulated usage for a new request."""
107+ self ._accumulated_usage = {}
108+
90109 async def _aprocess_query_and_retrieve_docs (
91110 self ,
92111 query : str ,
93112 chat_history_str : str ,
94113 sources : list [DocumentSource ] | None = None ,
95114 ) -> tuple [ProcessedQuery , list [Document ]]:
96115 """Process query and retrieve documents - shared async logic."""
97- processed_query = await self .query_processor .aforward (
116+ qp_prediction = await self .query_processor .aforward (
98117 query = query , chat_history = chat_history_str
99118 )
119+ self ._accumulate_usage (qp_prediction )
120+ processed_query = qp_prediction .processed_query
100121 self ._current_processed_query = processed_query
101122
102123 # Use provided sources or fall back to processed query sources
103124 retrieval_sources = sources or processed_query .resources
104- documents = await self .document_retriever .aforward (
125+ dr_prediction = await self .document_retriever .aforward (
105126 processed_query = processed_query , sources = retrieval_sources
106127 )
128+ self ._accumulate_usage (dr_prediction )
129+ documents = dr_prediction .documents
107130
108131 # Optional Grok web/X augmentation: activate when STARKNET_BLOG is among sources.
109132 try :
110133 if DocumentSource .STARKNET_BLOG in retrieval_sources :
111- grok_docs = await self .grok_search .aforward (processed_query , chat_history_str )
134+ grok_pred = await self .grok_search .aforward (processed_query , chat_history_str )
135+ self ._accumulate_usage (grok_pred )
136+ grok_docs = grok_pred .documents
137+
112138 self ._grok_citations = list (self .grok_search .last_citations )
113139 if grok_docs :
114140 documents .extend (grok_docs )
@@ -126,7 +152,9 @@ async def _aprocess_query_and_retrieve_docs(
126152 lm = dspy .LM ("gemini/gemini-flash-lite-latest" , max_tokens = 10000 , temperature = 0.5 ),
127153 adapter = XMLAdapter (),
128154 ):
129- documents = await self .retrieval_judge .aforward (query = query , documents = documents )
155+ judge_pred = await self .retrieval_judge .aforward (query = query , documents = documents )
156+ self ._accumulate_usage (judge_pred )
157+ documents = judge_pred .documents
130158 except Exception as e :
131159 logger .warning (
132160 "Retrieval judge failed (async), using all documents" ,
@@ -158,6 +186,9 @@ async def aforward(
158186 mcp_mode : bool = False ,
159187 sources : list [DocumentSource ] | None = None ,
160188 ) -> dspy .Prediction :
189+ # Reset usage for this request
190+ self ._reset_usage ()
191+
161192 chat_history_str = self ._format_chat_history (chat_history or [])
162193 processed_query , documents = await self ._aprocess_query_and_retrieve_docs (
163194 query , chat_history_str , sources
@@ -167,13 +198,21 @@ async def aforward(
167198 )
168199
169200 if mcp_mode :
170- return await self .mcp_generation_program .aforward (documents )
201+ result = await self .mcp_generation_program .aforward (documents )
202+ self ._accumulate_usage (result )
203+ result .set_lm_usage (self ._accumulated_usage )
204+ return result
171205
172206 context = self ._prepare_context (documents )
173207
174- return await self .generation_program .aforward (
208+ result = await self .generation_program .aforward (
175209 query = query , context = context , chat_history = chat_history_str
176210 )
211+ if result :
212+ self ._accumulate_usage (result )
213+ # Update the result's usage to include accumulated usage from previous steps
214+ result .set_lm_usage (self ._accumulated_usage )
215+ return result
177216
178217
179218 async def aforward_streaming (
@@ -251,6 +290,7 @@ async def aforward_streaming(
251290 logger .warning (f"Unknown signature field name: { chunk .signature_field_name } " )
252291 elif isinstance (chunk , dspy .Prediction ):
253292 # Final complete answer
293+ self ._accumulate_usage (chunk )
254294 final_text = getattr (chunk , "answer" , None ) or chunk_accumulator
255295 yield StreamEvent (type = StreamEventType .FINAL_RESPONSE , data = final_text )
256296 rt .end (outputs = {"output" : final_text })
@@ -268,28 +308,12 @@ async def aforward_streaming(
268308
269309 def get_lm_usage (self ) -> dict [str , dict [str , int ]]:
270310 """
271- Get the total number of tokens used by the LLM.
272- """
273- generation_usage = self .generation_program .get_lm_usage ()
274- query_usage = self .query_processor .get_lm_usage ()
275- judge_usage = self .retrieval_judge .get_lm_usage ()
276-
277- # Additive merge strategy
278- merged_usage = {}
279-
280- # Helper function to merge usage dictionaries
281- def merge_usage_dict (target : dict , source : dict ) -> None :
282- for model_name , metrics in source .items ():
283- if model_name not in target :
284- target [model_name ] = {}
285- for metric_name , value in metrics .items ():
286- target [model_name ][metric_name ] = target [model_name ].get (metric_name , 0 ) + value
311+ Get accumulated token usage from all predictions in the pipeline.
287312
288- merge_usage_dict (merged_usage , generation_usage )
289- merge_usage_dict (merged_usage , query_usage )
290- merge_usage_dict (merged_usage , judge_usage )
291-
292- return merged_usage
313+ Returns:
314+ Dictionary mapping model names to usage metrics
315+ """
316+ return self ._accumulated_usage
293317
294318 def _format_chat_history (self , chat_history : list [Message ]) -> str :
295319 """
@@ -311,7 +335,7 @@ def _format_chat_history(self, chat_history: list[Message]) -> str:
311335
312336 return "\n " .join (formatted_messages )
313337
314- def _format_sources (self , documents : list [Document ]) -> list [dict [ str , Any ] ]:
338+ def _format_sources (self , documents : list [Document ]) -> list [FormattedSource ]:
315339 """
316340 Format documents for the frontend-friendly sources event.
317341
@@ -322,9 +346,9 @@ def _format_sources(self, documents: list[Document]) -> list[dict[str, Any]]:
322346 documents: List of retrieved documents
323347
324348 Returns:
325- List of dicts: [{"title": str, "url": str}, ...]
349+ List of formatted sources with metadata
326350 """
327- sources : list [dict [ str , str ] ] = []
351+ sources : list [FormattedSource ] = []
328352 seen_urls : set [str ] = set ()
329353
330354
0 commit comments