3030# Add parent directory to path for imports
3131sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../" )))
3232
33+ # Load environment variables from .env file if it exists
34+ try :
35+ from dotenv import load_dotenv
36+
37+ load_dotenv ()
38+ except ImportError :
39+ # dotenv is optional
40+ pass
41+
3342from universal_model_names import ModelNameManager
3443
3544# Import model evaluator from current directory
@@ -119,39 +128,58 @@ def save_predictions_file(predictions: List[Dict[str, Any]], router_name: str) -
119128
120129def load_ground_truth_dataset (split : str ) -> Dict [str , Dict [str , Any ]]:
121130 """
122- Load ground truth dataset based on split from local disk.
131+ Load ground truth dataset based on split from local disk or private repo.
132+
133+ For "full" split: If HF_TOKEN or HUGGING_FACE_HUB_TOKEN is set, loads from RouteWorks/RouterEvalBenchmark
134+ (private repo with answers). Otherwise, or if that load fails, loads from local disk (public dataset without answers).
123135
124136 Args:
125137 split: Dataset split ("sub_10" for testing or "full" for submission)
126138
127139 Returns:
128140 Dictionary mapping global_index to ground truth data
129141 """
130- from datasets import load_from_disk # type: ignore[import-not-found,import-untyped]
142+ from datasets import load_dataset , load_from_disk # type: ignore[import-not-found,import-untyped]
131143 import pandas as pd # type: ignore[import-untyped]
132144
133145 if split not in ["sub_10" , "full" ]:
134146 raise ValueError (f"Invalid split: { split } . Must be 'sub_10' or 'full'" )
135147
136- logger .info (f"Loading ground truth dataset (split: { split } ) from local disk..." )
148+ router_eval_bench_df = None
149+ hf_token = os .getenv ("HF_TOKEN" ) or os .getenv ("HUGGING_FACE_HUB_TOKEN" )
137150
138- # Load the RouterArena dataset from local disk
139- dataset_path = "./dataset/routerarena"
140- if split == "sub_10" :
141- dataset_path = "./dataset/routerarena_10"
142- if not os .path .exists (dataset_path ):
143- raise FileNotFoundError (
144- f"Dataset not found at { dataset_path } . "
145- f"Please run the following command to download the dataset: python scripts/process_datasets/prep_datasets.py"
146- )
147-
148- router_arena_dataset = load_from_disk (dataset_path )
151+ # For "full" split, try private repo first if token is available
152+ if split == "full" and hf_token :
153+ logger .info ("Loading full dataset with answers from private repo..." )
154+ try :
155+ router_arena_dataset = load_dataset (
156+ "RouteWorks/RouterEvalBenchmark" ,
157+ split = "full" ,
158+ token = hf_token ,
159+ )
160+ router_eval_bench_df = pd .DataFrame (router_arena_dataset )
161+ logger .info ("Successfully loaded from private repo." )
162+ except Exception as e :
163+ logger .warning (
164+ f"Could not load from private repo: { e } . Falling back to local dataset."
165+ )
149166
150- router_eval_bench_df = pd .DataFrame (router_arena_dataset )
167+ # Load from local disk if not already loaded
168+ if router_eval_bench_df is None :
169+ dataset_path = (
170+ "./dataset/routerarena_10" if split == "sub_10" else "./dataset/routerarena"
171+ )
172+ if not os .path .exists (dataset_path ):
173+ raise FileNotFoundError (
174+ f"Dataset not found at { dataset_path } . "
175+ f"Please run: python scripts/process_datasets/prep_datasets.py"
176+ )
177+ logger .info (f"Loading dataset from { dataset_path } ..." )
178+ router_arena_dataset = load_from_disk (dataset_path )
179+ router_eval_bench_df = pd .DataFrame (router_arena_dataset )
151180
152- # Check if we have answers for the "full" split
181+ # Verify answers exist for "full" split
153182 if split == "full" :
154- # Sample a few rows to check if answers are empty
155183 sample_size = min (100 , len (router_eval_bench_df ))
156184 sample_answers = router_eval_bench_df .head (sample_size )["Answer" ]
157185 has_answers = any (
@@ -166,9 +194,11 @@ def load_ground_truth_dataset(split: str) -> Dict[str, Dict[str, Any]]:
166194 logger .error ("" )
167195 logger .error ("To submit predictions for the full dataset evaluation:" )
168196 logger .error ("1. Generate predictions for the full dataset" )
169- logger .error ("2. Create an issue in the RouterArena repository" )
170- logger .error ("3. Upload your predictions file" )
171- logger .error ("4. We will run the official evaluation for you" )
197+ logger .error ("2. Create a pull request to the RouterArena repository" )
198+ logger .error ("3. Include your predictions file in the PR" )
199+ logger .error (
200+ "4. The official evaluation would be automatically conducted for you"
201+ )
172202 logger .error ("=" * 80 )
173203 raise ValueError (
174204 "The 'full' split does not have ground truth answers. "
0 commit comments