Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
LMMasters committed Feb 20, 2025
1 parent 40e6a0d commit 596057a
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions llm_extractinator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def run_tasks(self) -> None:
f"Running short cases with max_context_len: {self.config.max_context_len}"
)

self._run_with_manager()
self.short_path = self._run_with_manager()

# Run long cases
self.config.max_context_len = self.data_loader.get_max_input_tokens(
Expand All @@ -174,9 +174,9 @@ def run_tasks(self) -> None:
f"Running long cases with max_context_len: {self.config.max_context_len}"
)

self._run_with_manager()

self.long_path = self._run_with_manager()
self._combine_results()

elif self.config.max_context_len == "max":
self.config.max_context_len = self.data_loader.get_max_input_tokens(
df=self.test,
Expand All @@ -190,11 +190,11 @@ def run_tasks(self) -> None:
f"Running cases with max_context_len: {self.config.max_context_len}"
)

self._run_with_manager()
_ = self._run_with_manager()
else:
self.config.train = self.train
self.config.test = self.test
self._run_with_manager()
_ = self._run_with_manager()

total_time = timedelta(seconds=time.time() - start_time)
logging.info(f"Task execution completed in {total_time}")
Expand All @@ -203,8 +203,9 @@ def _run_with_manager(self) -> None:
"""Runs the task with a managed Ollama server."""
with OllamaServerManager() as manager:
manager.pull_model(self.config.model_name)
_ = self._run_task()
filepath = self._run_task()
manager.stop(self.config.model_name)
return filepath

def _run_task(self) -> bool:
"""Executes a single prediction task in parallel."""
Expand Down Expand Up @@ -269,14 +270,14 @@ def _combine_results(self) -> None:
"""
try:
logging.info("Combining results from short and long cases.")
if not self.short_paths:
if not self.short_path:
raise ValueError(
"No paths found for short cases. Something went wrong."
)
if not self.long_paths:
if not self.long_path:
raise ValueError("No paths found for long cases. Something went wrong.")

for short_path, long_path in zip(self.short_paths, self.long_paths):
for short_path, long_path in zip(self.short_path, self.long_path):
short_df = pd.read_json(short_path, orient="records")
long_df = pd.read_json(long_path, orient="records")
combined_df = pd.concat([short_df, long_df], ignore_index=True)
Expand All @@ -288,7 +289,7 @@ def _combine_results(self) -> None:
)

# Remove the individual files
for short_path, long_path in zip(self.short_paths, self.long_paths):
for short_path, long_path in zip(self.short_path, self.long_path):
short_path.unlink()
long_path.unlink()

Expand Down

0 comments on commit 596057a

Please sign in to comment.