Skip to content

Commit 71ecc84

Browse files
whoseoysterstainless-app[bot]
authored andcommitted
improvement: updates to custom metric runner
1 parent b5bec3a commit 71ecc84

File tree

2 files changed

+58
-24
lines changed

2 files changed

+58
-24
lines changed

src/openlayer/lib/core/base_model.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ class OpenlayerModel(abc.ABC):
4242
def run_from_cli(self) -> None:
4343
"""Run the model from the command line."""
4444
parser = argparse.ArgumentParser(description="Run data through a model.")
45-
parser.add_argument("--dataset-path", type=str, required=True, help="Path to the dataset")
45+
parser.add_argument(
46+
"--dataset-path", type=str, required=True, help="Path to the dataset"
47+
)
4648
parser.add_argument(
4749
"--output-dir",
4850
type=str,
@@ -61,14 +63,16 @@ def run_from_cli(self) -> None:
6163
def batch(self, dataset_path: str, output_dir: str) -> None:
6264
"""Reads the dataset from a file and runs the model on it."""
6365
# Load the dataset into a pandas DataFrame
66+
fmt = "csv"
6467
if dataset_path.endswith(".csv"):
6568
df = pd.read_csv(dataset_path)
6669
elif dataset_path.endswith(".json"):
6770
df = pd.read_json(dataset_path, orient="records")
71+
fmt = "json"
6872

6973
# Call the model's run_batch method, passing in the DataFrame
7074
output_df, config = self.run_batch_from_df(df)
71-
self.write_output_to_directory(output_df, config, output_dir)
75+
self.write_output_to_directory(output_df, config, output_dir, fmt)
7276

7377
def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
7478
"""Function that runs the model and returns the result."""
@@ -83,7 +87,9 @@ def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
8387
# Filter row_dict to only include keys that are valid parameters
8488
# for the 'run' method
8589
row_dict = row.to_dict()
86-
filtered_kwargs = {k: v for k, v in row_dict.items() if k in run_signature.parameters}
90+
filtered_kwargs = {
91+
k: v for k, v in row_dict.items() if k in run_signature.parameters
92+
}
8793

8894
# Call the run method with filtered kwargs
8995
output = self.run(**filtered_kwargs)

src/openlayer/lib/core/metrics.py

+49-21
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def __init__(self):
5959
self.config_path: str = ""
6060
self.config: Dict[str, Any] = {}
6161
self.datasets: List[Dataset] = []
62-
self.selected_metrics: Optional[List[str]] = None
62+
self.likely_dir: str = ""
6363

6464
def run_metrics(self, metrics: List[BaseMetric]) -> None:
6565
"""Run a list of metrics."""
@@ -87,30 +87,28 @@ def _parse_args(self) -> None:
8787
type=str,
8888
required=False,
8989
default="",
90-
help="The path to your openlayer.json. Uses working dir if not provided.",
90+
help=(
91+
"The path to your openlayer.json. Uses parent parent dir if not "
92+
"provided (assuming location is metrics/metric_name/run.py)."
93+
),
9194
)
9295

9396
# Parse the arguments
9497
args = parser.parse_args()
9598
self.config_path = args.config_path
99+
self.likely_dir = os.path.dirname(os.path.dirname(os.getcwd()))
96100

97101
def _load_openlayer_json(self) -> None:
98102
"""Load the openlayer.json file."""
99103

100104
if not self.config_path:
101-
openlayer_json_path = os.path.join(os.getcwd(), "openlayer.json")
105+
openlayer_json_path = os.path.join(self.likely_dir, "openlayer.json")
102106
else:
103107
openlayer_json_path = self.config_path
104108

105109
with open(openlayer_json_path, "r", encoding="utf-8") as f:
106110
self.config = json.load(f)
107111

108-
# Extract selected metrics
109-
if "metrics" in self.config and "settings" in self.config["metrics"]:
110-
self.selected_metrics = [
111-
metric["key"] for metric in self.config["metrics"]["settings"] if metric["selected"]
112-
]
113-
114112
def _load_datasets(self) -> None:
115113
"""Compute the metric from the command line."""
116114

@@ -125,20 +123,34 @@ def _load_datasets(self) -> None:
125123
# Read the outputs directory for dataset folders. For each, load
126124
# the config.json and the dataset.json files into a dict and a dataframe
127125

128-
for dataset_folder in os.listdir(output_directory):
126+
full_output_dir = os.path.join(self.likely_dir, output_directory)
127+
128+
for dataset_folder in os.listdir(full_output_dir):
129129
if dataset_folder not in dataset_names:
130130
continue
131-
dataset_path = os.path.join(output_directory, dataset_folder)
131+
dataset_path = os.path.join(full_output_dir, dataset_folder)
132132
config_path = os.path.join(dataset_path, "config.json")
133133
with open(config_path, "r", encoding="utf-8") as f:
134134
dataset_config = json.load(f)
135+
# Merge with the dataset fields from the openlayer.json
136+
dataset_dict = next(
137+
(
138+
item
139+
for item in datasets_list
140+
if item["name"] == dataset_folder
141+
),
142+
None,
143+
)
144+
dataset_config = {**dataset_dict, **dataset_config}
135145

136146
# Load the dataset into a pandas DataFrame
137147
if os.path.exists(os.path.join(dataset_path, "dataset.csv")):
138148
dataset_df = pd.read_csv(os.path.join(dataset_path, "dataset.csv"))
139149
data_format = "csv"
140150
elif os.path.exists(os.path.join(dataset_path, "dataset.json")):
141-
dataset_df = pd.read_json(os.path.join(dataset_path, "dataset.json"), orient="records")
151+
dataset_df = pd.read_json(
152+
os.path.join(dataset_path, "dataset.json"), orient="records"
153+
)
142154
data_format = "json"
143155
else:
144156
raise ValueError(f"No dataset found in {dataset_folder}.")
@@ -153,19 +165,20 @@ def _load_datasets(self) -> None:
153165
)
154166
)
155167
else:
156-
raise ValueError("No model found in the openlayer.json file. Cannot compute metric.")
168+
raise ValueError(
169+
"No model found in the openlayer.json file. Cannot compute metric."
170+
)
157171

158172
if not datasets:
159-
raise ValueError("No datasets found in the openlayer.json file. Cannot compute metric.")
173+
raise ValueError(
174+
"No datasets found in the openlayer.json file. Cannot compute metric."
175+
)
160176

161177
self.datasets = datasets
162178

163179
def _compute_metrics(self, metrics: List[BaseMetric]) -> None:
164180
"""Compute the metrics."""
165181
for metric in metrics:
166-
if self.selected_metrics and metric.key not in self.selected_metrics:
167-
print(f"Skipping metric {metric.key} as it is not a selected metric.")
168-
continue
169182
metric.compute(self.datasets)
170183

171184
def _write_updated_datasets_to_output(self) -> None:
@@ -200,10 +213,14 @@ class BaseMetric(abc.ABC):
200213
Your metric's class should inherit from this class and implement the compute method.
201214
"""
202215

216+
@abc.abstractmethod
217+
def get_key(self) -> str:
218+
"""Return the key of the metric. This should correspond to the folder name."""
219+
pass
220+
203221
@property
204222
def key(self) -> str:
205-
"""Return the key of the metric."""
206-
return self.__class__.__name__
223+
return self.get_key()
207224

208225
def compute(self, datasets: List[Dataset]) -> None:
209226
"""Compute the metric on the model outputs."""
@@ -226,15 +243,26 @@ def compute_on_dataset(self, dataset: Dataset) -> MetricReturn:
226243
"""Compute the metric on a specific dataset."""
227244
pass
228245

229-
def _write_metric_return_to_file(self, metric_return: MetricReturn, output_dir: str) -> None:
246+
def _write_metric_return_to_file(
247+
self, metric_return: MetricReturn, output_dir: str
248+
) -> None:
230249
"""Write the metric return to a file."""
231250

232251
# Create the directory if it doesn't exist
233252
os.makedirs(output_dir, exist_ok=True)
234253

235254
# Turn the metric return to a dict
236255
metric_return_dict = asdict(metric_return)
256+
# Convert the set to a list
257+
metric_return_dict["added_cols"] = list(metric_return.added_cols)
237258

238-
with open(os.path.join(output_dir, f"{self.key}.json"), "w", encoding="utf-8") as f:
259+
with open(
260+
os.path.join(output_dir, f"{self.key}.json"), "w", encoding="utf-8"
261+
) as f:
239262
json.dump(metric_return_dict, f, indent=4)
240263
print(f"Metric ({self.key}) value written to {output_dir}/{self.key}.json")
264+
265+
def run(self) -> None:
266+
"""Run the metric."""
267+
metric_runner = MetricRunner()
268+
metric_runner.run_metrics([self])

0 commit comments

Comments
 (0)