Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions src/custodian/custodian.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,51 +456,47 @@ def _run_job(self, job_n, job) -> None:
job.setup(self.directory)

attempt = 0
process = None # Initialize p to None in case the while loop never executes
while self.total_errors < self.max_errors and self.errors_current_job < self.max_errors_per_job:
attempt += 1
logger.info(
f"Starting job no. {job_n} ({job.name}) attempt no. {attempt}. Total errors and "
f"errors in job thus far = {self.total_errors}, {self.errors_current_job}."
)

p = job.run(directory=self.directory)
process = job.run(directory=self.directory)
# Check for errors using the error handlers and perform
# corrections.
has_error = False
zero_return_code = True

# Choose the terminate function to run. If a terminate_func exists, this
# should take priority, followed by Job.terminate if implemented, and finally
# subprocess.Popen.terminate if neither of the former exist.
terminate = self.terminate_func or job.terminate or p.terminate

# While the job is running, we use the handlers that are
# monitors to monitor the job.
if isinstance(p, subprocess.Popen):
if isinstance(process, subprocess.Popen):
if self.monitors:
n = 0
poll_idx = 0
while True:
n += 1
poll_idx += 1
time.sleep(self.polling_time_step)
# We poll the process p to check if it is still running.
# Note that the process here is not the actual calculation
# but whatever is used to control the execution of the
# calculation executable. For instance; mpirun, srun, and so on.
if p.poll() is not None:
if process.poll() is not None:
break
if n % self.monitor_freq == 0:
if poll_idx % self.monitor_freq == 0:
# At every self.polling_time_step * self.monitor_freq seconds,
# we check the job for errors using handlers that are monitors.
# In order to properly kill a running calculation, we use
# the appropriate implementation of terminate.
has_error = self._do_check(self.monitors, terminate)
has_error = self._do_check(self.monitors, process, job)
else:
p.wait()
if self.terminate_func is not None and self.terminate_func != p.terminate:
process.wait()
if self.terminate_func is not None and self.terminate_func != process.terminate:
self.terminate_func()
time.sleep(self.polling_time_step)

zero_return_code = p.returncode == 0
zero_return_code = process.returncode == 0

logger.info(f"{job.name}.run has completed. Checking remaining handlers")
# Check for errors again, since in some cases non-monitor
Expand All @@ -522,7 +518,7 @@ def _run_job(self, job_n, job) -> None:
if not zero_return_code:
if self.terminate_on_nonzero_returncode:
self.run_log[-1]["nonzero_return_code"] = True
msg = f"Job return code is {p.returncode}. Terminating..."
msg = f"Job return code is {process.returncode}. Terminating..."
logger.info(msg)
raise ReturnCodeError(msg, raises=True)
warnings.warn("subprocess returned a non-zero return code. Check outputs carefully...")
Expand All @@ -541,6 +537,11 @@ def _run_job(self, job_n, job) -> None:
msg = f"Unrecoverable error for handler: {corr['handler']}"
raise NonRecoverableError(msg, raises=False, handler=corr["handler"])

# Terminate any running process before raising max errors exceptions
if process is not None:
logger.warning("Max errors threshold reached.")
self._terminate_process(process, job)

if self.errors_current_job >= self.max_errors_per_job:
self.run_log[-1]["max_errors_per_job"] = True
msg = f"Max errors per job reached: {self.max_errors_per_job}."
Expand Down Expand Up @@ -653,8 +654,30 @@ def run_interrupted(self):
gzip_dir(self.directory)
return None

def _do_check(self, handlers, terminate_func=None):
"""Checks the specified handlers. Returns True iff errors caught."""
def _terminate_process(self, process, job) -> None:
"""Terminate a running subprocess using the job's terminate method or fallback."""
if not isinstance(process, subprocess.Popen) or process.poll() is not None:
return # Not a process or already finished

logger.warning("Terminating running process.")
terminate = self.terminate_func or job.terminate or process.terminate

try:
if terminate != process.terminate:
terminate(directory=self.directory)
else:
terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
logger.warning("Process did not terminate gracefully, force killing")
process.kill()
process.wait()
except Exception:
logger.exception("Error terminating process")

def _do_check(self, handlers, process=None, job=None):
"""Check handlers and return True if errors were caught."""
corrections = []
for handler in handlers:
try:
Expand All @@ -672,11 +695,11 @@ def _do_check(self, handlers, terminate_func=None):
)
logger.warning(f"{msg} Correction not applied.")
continue
if terminate_func is not None and handler.is_terminating:
if process is not None and job is not None and handler.is_terminating:
logger.info("Terminating job")
terminate_func(directory=self.directory)
# make sure we don't terminate twice
terminate_func = None
self._terminate_process(process, job)
# Make sure we don't terminate twice
process = None
dct = handler.correct(directory=self.directory)
logger.error(type(handler).__name__, extra=dct)
dct["handler"] = handler
Expand Down
71 changes: 71 additions & 0 deletions tests/test_custodian.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import random
import subprocess
import time
import unittest
from glob import glob

Expand Down Expand Up @@ -133,6 +134,49 @@ def check(self, directory="./") -> bool:
return True


class LongRunningJob(Job):
"""A job that spawns a long-running subprocess to test process termination."""

def __init__(self) -> None:
self.process = None

def setup(self, directory="./") -> None:
pass

def run(self, directory="./"):
"""Spawn a long-running sleep process."""
# Use sleep command to simulate a long-running VASP job
self.process = subprocess.Popen(
["sleep", "300"], # Sleep for 5 minutes
cwd=directory,
start_new_session=True,
)
return self.process

def postprocess(self, directory="./") -> None:
pass

def terminate(self, directory="./") -> None:
"""Kill the process and all its children."""
if self.process and self.process.poll() is None:
self.process.terminate()
try:
self.process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.process.kill()
self.process.wait()


class AlwaysFailingHandler(ErrorHandler):
"""Handler that always detects an error to trigger max_errors_per_job."""

def check(self, directory="./") -> bool:
return True

def correct(self, directory="./"):
return {"errors": "simulated error", "actions": "simulated correction"}


class CustodianTest(unittest.TestCase):
def setUp(self) -> None:
self.cwd = os.getcwd()
Expand Down Expand Up @@ -223,6 +267,33 @@ def test_max_errors_per_job(self) -> None:
c.run()
assert c.run_log[-1]["max_errors_per_job"]

def test_max_errors_per_job_terminates_process(self) -> None:
"""Test that processes are properly terminated when max_errors_per_job is reached."""
job = LongRunningJob()
handler = AlwaysFailingHandler()
c = Custodian(
[handler],
[job],
max_errors=10,
max_errors_per_job=2,
polling_time_step=1,
)

# Run custodian and expect it to raise MaxCorrectionsPerJobError
with pytest.raises(MaxCorrectionsPerJobError):
c.run()

# Verify the max_errors_per_job flag is set
assert c.run_log[-1]["max_errors_per_job"]

# Give the process a moment to fully terminate
time.sleep(0.5)

# Verify the process was actually terminated (not a zombie)
# If the process is still running, poll() will return None
assert job.process is not None
assert job.process.poll() is not None, "Process should be terminated, not left as zombie"

def test_max_errors_per_handler_raise(self) -> None:
n_jobs = 100
params = {"initial": 0, "total": 0}
Expand Down
Loading