Skip to content

Commit fd8e765

Browse files
committed
Refine code
1 parent 407aa2e commit fd8e765

3 files changed

Lines changed: 8 additions & 55 deletions

File tree

global_utils/robustness.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ def _normalize_model_name(
2020
return None
2121
try:
2222
return name_manager.get_universal_name(model_name)
23-
except Exception:
23+
except ValueError:
2424
return model_name
2525

2626

llm_evaluation/run.py

Lines changed: 6 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -376,7 +376,6 @@ def process_router_predictions(
376376
save_interval: int = 50,
377377
num_workers: int = 4,
378378
force: bool = False,
379-
robustness_predictions_path: Optional[str] = None,
380379
) -> None:
381380
"""
382381
Process router predictions by evaluating generated results with incremental saving.
@@ -392,7 +391,6 @@ def process_router_predictions(
392391
save_interval: Number of entries to process before saving (default: 50)
393392
num_workers: Number of worker threads for parallel processing (default: 4)
394393
force: If True, re-evaluate all entries even if already evaluated (default: False)
395-
robustness_predictions_path: Optional path to the robustness predictions file
396394
"""
397395
logger.info(f"Starting LLM evaluation for router: {router_name} (split: {split})")
398396
logger.info(f"Using {num_workers} worker threads for parallel processing")
@@ -562,27 +560,8 @@ def evaluate_task(seq_idx: int, prediction: Dict[str, Any]) -> bool:
562560
)
563561
logger.info("=" * 60)
564562

565-
# Load robustness predictions if requested
566-
robustness_predictions = None
567-
if robustness_predictions_path:
568-
try:
569-
robustness_predictions = load_predictions_from_path(
570-
robustness_predictions_path
571-
)
572-
logger.info(
573-
f"Loaded robustness predictions from {robustness_predictions_path}"
574-
)
575-
except FileNotFoundError:
576-
logger.warning(
577-
f"Robustness predictions not found at {robustness_predictions_path}"
578-
)
579-
except Exception as e:
580-
logger.warning(
581-
f"Could not load robustness predictions from {robustness_predictions_path}: {e}"
582-
)
583-
584563
# Compute and display router-level metrics
585-
compute_router_metrics(predictions, router_name, robustness_predictions)
564+
compute_router_metrics(predictions, router_name)
586565

587566

588567
def _prepare_optimality_data(
@@ -926,16 +905,14 @@ def run_robustness_only(router_name: str, robustness_path: Optional[str]) -> Non
926905

927906
try:
928907
robustness_predictions = load_predictions_from_path(target_path)
929-
except FileNotFoundError as error:
908+
except FileNotFoundError:
930909
raise FileNotFoundError(
931910
"Robustness predictions not found at "
932911
f"{target_path}. Generate them with "
933912
"router_inference/generate_prediction_file.py <router> robustness."
934-
) from error
935-
except Exception as exc:
936-
raise RuntimeError(
937-
f"Unable to load robustness predictions from {target_path}: {exc}"
938-
) from exc
913+
)
914+
except json.JSONDecodeError:
915+
raise RuntimeError(f"Unable to load robustness predictions from {target_path}")
939916

940917
score = compute_robustness_score(predictions, robustness_predictions)
941918
if score is None:
@@ -952,11 +929,7 @@ def run_robustness_only(router_name: str, robustness_path: Optional[str]) -> Non
952929
logger.info("Robustness metrics saved to %s", metrics_path)
953930

954931

955-
def compute_router_metrics(
956-
predictions: List[Dict[str, Any]],
957-
router_name: str,
958-
robustness_predictions: Optional[List[Dict[str, Any]]] = None,
959-
) -> None:
932+
def compute_router_metrics(predictions: List[Dict[str, Any]], router_name: str) -> None:
960933
"""
961934
Compute router-level metrics (accuracy, cost, RouterArena score, etc.) and display them.
962935
@@ -1102,23 +1075,6 @@ def compute_router_metrics(
11021075
"num_sub10_queries": optimality_scores["num_sub10_queries"],
11031076
}
11041077

1105-
robustness_score = None
1106-
if robustness_predictions:
1107-
logger.info("\n" + "-" * 80)
1108-
logger.info("Computing Robustness Score (model selection flip ratio)...")
1109-
logger.info("-" * 80)
1110-
robustness_score = compute_robustness_score(predictions, robustness_predictions)
1111-
if robustness_score is not None:
1112-
logger.info(
1113-
f"Robustness flip ratio: {robustness_score:.4f} "
1114-
f"({robustness_score * 100:.2f}% differing selections)"
1115-
)
1116-
metrics_dict["robustness_score"] = robustness_score
1117-
else:
1118-
logger.warning(
1119-
"Robustness score could not be computed because no overlapping entries were found."
1120-
)
1121-
11221078
# Save to metrics.json
11231079
metrics_path = "./metrics.json"
11241080
with open(metrics_path, "w") as f:
@@ -1216,9 +1172,6 @@ def main():
12161172
save_interval,
12171173
args.num_workers,
12181174
args.force,
1219-
default_robustness_path
1220-
if os.path.exists(default_robustness_path)
1221-
else None,
12221175
)
12231176
except KeyboardInterrupt:
12241177
logger.info("\nInterrupted by user. Saving partial results...")

scripts/process_datasets/prep_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
"global index": row.get("Global Index"),
3232
}
3333
)
34-
robustness_json_path = os.path.join(save_dir, "router_robustness.json")
34+
robustness_json_path = os.path.join(save_dir, "router_robustne-ss.json")
3535
with open(robustness_json_path, "w", encoding="utf-8") as f:
3636
json.dump(robustness_records, f, ensure_ascii=False, indent=2)
3737
print(f"[prep] Wrote {len(robustness_records)} items to {robustness_json_path}")

0 commit comments

Comments
 (0)