Skip to content

Commit 786bafd

Browse files
committed
fix: custom args with run_batch_from_df
1 parent 74e74bc commit 786bafd

File tree

2 files changed

+3
-11
lines changed

2 files changed

+3
-11
lines changed

src/openlayer/lib/core/base_model.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,10 @@ def batch(self, dataset_path: str, output_dir: str) -> None:
9191
raise ValueError(f"Unsupported dataset format: {dataset_path}")
9292

9393
# Call the model's run_batch method, passing in the DataFrame
94-
output_df, config = self.run_batch_from_df(df, custom_args=self.custom_args)
94+
output_df, config = self.run_batch_from_df(df)
9595
self.write_output_to_directory(output_df, config, output_dir, fmt)
9696

97-
def run_batch_from_df(
98-
self, df: pd.DataFrame, custom_args: dict = None
99-
) -> Tuple[pd.DataFrame, dict]:
97+
def run_batch_from_df(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
10098
"""Function that runs the model and returns the result."""
10199
# Ensure the 'output' column exists
102100
if "output" not in df.columns:
@@ -105,10 +103,6 @@ def run_batch_from_df(
105103
# Get the signature of the 'run' method
106104
run_signature = inspect.signature(self.run)
107105

108-
# If the model has a custom_args attribute, update it
109-
if hasattr(self, "custom_args") and custom_args is not None:
110-
self.custom_args.update(custom_args)
111-
112106
for index, row in df.iterrows():
113107
# Filter row_dict to only include keys that are valid parameters
114108
# for the 'run' method

src/openlayer/lib/integrations/bedrock_tracer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@
2525
logger = logging.getLogger(__name__)
2626

2727

28-
def trace_bedrock(
29-
client: "boto3.client",
30-
) -> "boto3.client":
28+
def trace_bedrock(client: "boto3.client") -> "boto3.client":
3129
"""Patch the Bedrock client to trace model invocations.
3230
3331
The following information is collected for each model invocation:

0 commit comments

Comments
 (0)