@@ -376,7 +376,6 @@ def process_router_predictions(
376376 save_interval : int = 50 ,
377377 num_workers : int = 4 ,
378378 force : bool = False ,
379- robustness_predictions_path : Optional [str ] = None ,
380379) -> None :
381380 """
382381 Process router predictions by evaluating generated results with incremental saving.
@@ -392,7 +391,6 @@ def process_router_predictions(
392391 save_interval: Number of entries to process before saving (default: 50)
393392 num_workers: Number of worker threads for parallel processing (default: 4)
394393 force: If True, re-evaluate all entries even if already evaluated (default: False)
395- robustness_predictions_path: Optional path to the robustness predictions file
396394 """
397395 logger .info (f"Starting LLM evaluation for router: { router_name } (split: { split } )" )
398396 logger .info (f"Using { num_workers } worker threads for parallel processing" )
@@ -562,27 +560,8 @@ def evaluate_task(seq_idx: int, prediction: Dict[str, Any]) -> bool:
562560 )
563561 logger .info ("=" * 60 )
564562
565- # Load robustness predictions if requested
566- robustness_predictions = None
567- if robustness_predictions_path :
568- try :
569- robustness_predictions = load_predictions_from_path (
570- robustness_predictions_path
571- )
572- logger .info (
573- f"Loaded robustness predictions from { robustness_predictions_path } "
574- )
575- except FileNotFoundError :
576- logger .warning (
577- f"Robustness predictions not found at { robustness_predictions_path } "
578- )
579- except Exception as e :
580- logger .warning (
581- f"Could not load robustness predictions from { robustness_predictions_path } : { e } "
582- )
583-
584563 # Compute and display router-level metrics
585- compute_router_metrics (predictions , router_name , robustness_predictions )
564+ compute_router_metrics (predictions , router_name )
586565
587566
588567def _prepare_optimality_data (
@@ -926,16 +905,14 @@ def run_robustness_only(router_name: str, robustness_path: Optional[str]) -> Non
926905
927906 try :
928907 robustness_predictions = load_predictions_from_path (target_path )
929- except FileNotFoundError as error :
908+ except FileNotFoundError :
930909 raise FileNotFoundError (
931910 "Robustness predictions not found at "
932911 f"{ target_path } . Generate them with "
933912 "router_inference/generate_prediction_file.py <router> robustness."
934- ) from error
935- except Exception as exc :
936- raise RuntimeError (
937- f"Unable to load robustness predictions from { target_path } : { exc } "
938- ) from exc
913+ )
914+ except json .JSONDecodeError :
915+ raise RuntimeError (f"Unable to load robustness predictions from { target_path } " )
939916
940917 score = compute_robustness_score (predictions , robustness_predictions )
941918 if score is None :
@@ -952,11 +929,7 @@ def run_robustness_only(router_name: str, robustness_path: Optional[str]) -> Non
952929 logger .info ("Robustness metrics saved to %s" , metrics_path )
953930
954931
955- def compute_router_metrics (
956- predictions : List [Dict [str , Any ]],
957- router_name : str ,
958- robustness_predictions : Optional [List [Dict [str , Any ]]] = None ,
959- ) -> None :
932+ def compute_router_metrics (predictions : List [Dict [str , Any ]], router_name : str ) -> None :
960933 """
961934 Compute router-level metrics (accuracy, cost, RouterArena score, etc.) and display them.
962935
@@ -1102,23 +1075,6 @@ def compute_router_metrics(
11021075 "num_sub10_queries" : optimality_scores ["num_sub10_queries" ],
11031076 }
11041077
1105- robustness_score = None
1106- if robustness_predictions :
1107- logger .info ("\n " + "-" * 80 )
1108- logger .info ("Computing Robustness Score (model selection flip ratio)..." )
1109- logger .info ("-" * 80 )
1110- robustness_score = compute_robustness_score (predictions , robustness_predictions )
1111- if robustness_score is not None :
1112- logger .info (
1113- f"Robustness flip ratio: { robustness_score :.4f} "
1114- f"({ robustness_score * 100 :.2f} % differing selections)"
1115- )
1116- metrics_dict ["robustness_score" ] = robustness_score
1117- else :
1118- logger .warning (
1119- "Robustness score could not be computed because no overlapping entries were found."
1120- )
1121-
11221078 # Save to metrics.json
11231079 metrics_path = "./metrics.json"
11241080 with open (metrics_path , "w" ) as f :
@@ -1216,9 +1172,6 @@ def main():
12161172 save_interval ,
12171173 args .num_workers ,
12181174 args .force ,
1219- default_robustness_path
1220- if os .path .exists (default_robustness_path )
1221- else None ,
12221175 )
12231176 except KeyboardInterrupt :
12241177 logger .info ("\n Interrupted by user. Saving partial results..." )
0 commit comments