From 7fc3800b55dcd30bd68a9f9253d9b483fe03fb5c Mon Sep 17 00:00:00 2001 From: janosh Date: Tue, 28 Oct 2025 13:31:57 -0700 Subject: [PATCH 1/2] Fix zombie VASP processes when max_errors thresholds are reached - Add process termination logic before raising MaxCorrectionsPerJobError and MaxCorrectionsError - Ensures running processes are properly killed using the multi-node compatible approach from PR #396 - Prevents orphaned VASP processes in both single-node and multi-node setups - Add comprehensive test case to verify process termination behavior --- src/custodian/custodian.py | 22 ++++++++++++ tests/test_custodian.py | 71 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/src/custodian/custodian.py b/src/custodian/custodian.py index f1b8b237..ae9715b4 100644 --- a/src/custodian/custodian.py +++ b/src/custodian/custodian.py @@ -456,6 +456,7 @@ def _run_job(self, job_n, job) -> None: job.setup(self.directory) attempt = 0 + p = None # Initialize p to None in case the while loop never executes while self.total_errors < self.max_errors and self.errors_current_job < self.max_errors_per_job: attempt += 1 logger.info( @@ -541,6 +542,27 @@ def _run_job(self, job_n, job) -> None: msg = f"Unrecoverable error for handler: {corr['handler']}" raise NonRecoverableError(msg, raises=False, handler=corr["handler"]) + # Terminate any running process before raising max errors exceptions + if isinstance(p, subprocess.Popen) and p.poll() is None: + logger.warning("Max errors threshold reached. Terminating running process.") + terminate = self.terminate_func or job.terminate or p.terminate + try: + # Call terminate with directory parameter if it's not the default Popen terminate + if terminate != p.terminate: + terminate(directory=self.directory) + else: + terminate() + # Wait briefly for process to terminate + if hasattr(p, "wait"): + try: + p.wait(timeout=10) + except subprocess.TimeoutExpired: + logger.warning("Process did not terminate gracefully, force killing") + p.kill() + p.wait() + except Exception: + logger.exception("Error terminating process") + if self.errors_current_job >= self.max_errors_per_job: self.run_log[-1]["max_errors_per_job"] = True msg = f"Max errors per job reached: {self.max_errors_per_job}." diff --git a/tests/test_custodian.py b/tests/test_custodian.py index 24307963..77a52600 100644 --- a/tests/test_custodian.py +++ b/tests/test_custodian.py @@ -1,6 +1,7 @@ import os import random import subprocess +import time import unittest from glob import glob @@ -133,6 +134,49 @@ def check(self, directory="./") -> bool: return True +class LongRunningJob(Job): + """A job that spawns a long-running subprocess to test process termination.""" + + def __init__(self) -> None: + self.process = None + + def setup(self, directory="./") -> None: + pass + + def run(self, directory="./"): + """Spawn a long-running sleep process.""" + # Use sleep command to simulate a long-running VASP job + self.process = subprocess.Popen( + ["sleep", "300"], # Sleep for 5 minutes + cwd=directory, + start_new_session=True, + ) + return self.process + + def postprocess(self, directory="./") -> None: + pass + + def terminate(self, directory="./") -> None: + """Kill the process and all its children.""" + if self.process and self.process.poll() is None: + self.process.terminate() + try: + self.process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.process.kill() + self.process.wait() + + +class AlwaysFailingHandler(ErrorHandler): + """Handler that always detects an error to trigger max_errors_per_job.""" + + def check(self, directory="./") -> bool: + return True + + def correct(self, directory="./"): + return {"errors": "simulated error", "actions": "simulated correction"} + + class CustodianTest(unittest.TestCase): def setUp(self) -> None: self.cwd = os.getcwd() @@ -223,6 +267,33 @@ def test_max_errors_per_job(self) -> None: c.run() assert c.run_log[-1]["max_errors_per_job"] + def test_max_errors_per_job_terminates_process(self) -> None: + """Test that processes are properly terminated when max_errors_per_job is reached.""" + job = LongRunningJob() + handler = AlwaysFailingHandler() + c = Custodian( + [handler], + [job], + max_errors=10, + max_errors_per_job=2, + polling_time_step=1, + ) + + # Run custodian and expect it to raise MaxCorrectionsPerJobError + with pytest.raises(MaxCorrectionsPerJobError): + c.run() + + # Verify the max_errors_per_job flag is set + assert c.run_log[-1]["max_errors_per_job"] + + # Give the process a moment to fully terminate + time.sleep(0.5) + + # Verify the process was actually terminated (not a zombie) + # If the process is still running, poll() will return None + assert job.process is not None + assert job.process.poll() is not None, "Process should be terminated, not left as zombie" + def test_max_errors_per_handler_raise(self) -> None: n_jobs = 100 params = {"initial": 0, "total": 0} From fee3e6203927eccfcc80367bccc5805a884a01c1 Mon Sep 17 00:00:00 2001 From: janosh Date: Tue, 28 Oct 2025 14:30:38 -0700 Subject: [PATCH 2/2] add _terminate_process() on class Custodian - Refactor _do_check() to use _terminate_process() instead of passing terminate_func - DRY: centralizes process termination logic in one place - uses multi-node compatible approach from PR #396 --- src/custodian/custodian.py | 85 +++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 42 deletions(-) diff --git a/src/custodian/custodian.py b/src/custodian/custodian.py index ae9715b4..39556769 100644 --- a/src/custodian/custodian.py +++ b/src/custodian/custodian.py @@ -456,7 +456,7 @@ def _run_job(self, job_n, job) -> None: job.setup(self.directory) attempt = 0 - p = None # Initialize p to None in case the while loop never executes + process = None # Initialize p to None in case the while loop never executes while self.total_errors < self.max_errors and self.errors_current_job < self.max_errors_per_job: attempt += 1 logger.info( @@ -464,44 +464,39 @@ def _run_job(self, job_n, job) -> None: f"errors in job thus far = {self.total_errors}, {self.errors_current_job}." ) - p = job.run(directory=self.directory) + process = job.run(directory=self.directory) # Check for errors using the error handlers and perform # corrections. has_error = False zero_return_code = True - # Choose the terminate function to run. If a terminate_func exists, this - # should take priority, followed by Job.terminate if implemented, and finally - # subprocess.Popen.terminate if neither of the former exist. - terminate = self.terminate_func or job.terminate or p.terminate - # While the job is running, we use the handlers that are # monitors to monitor the job. - if isinstance(p, subprocess.Popen): + if isinstance(process, subprocess.Popen): if self.monitors: - n = 0 + poll_idx = 0 while True: - n += 1 + poll_idx += 1 time.sleep(self.polling_time_step) # We poll the process p to check if it is still running. # Note that the process here is not the actual calculation # but whatever is used to control the execution of the # calculation executable. For instance; mpirun, srun, and so on. - if p.poll() is not None: + if process.poll() is not None: break - if n % self.monitor_freq == 0: + if poll_idx % self.monitor_freq == 0: # At every self.polling_time_step * self.monitor_freq seconds, # we check the job for errors using handlers that are monitors. # In order to properly kill a running calculation, we use # the appropriate implementation of terminate. - has_error = self._do_check(self.monitors, terminate) + has_error = self._do_check(self.monitors, process, job) else: - p.wait() - if self.terminate_func is not None and self.terminate_func != p.terminate: + process.wait() + if self.terminate_func is not None and self.terminate_func != process.terminate: self.terminate_func() time.sleep(self.polling_time_step) - zero_return_code = p.returncode == 0 + zero_return_code = process.returncode == 0 logger.info(f"{job.name}.run has completed. Checking remaining handlers") # Check for errors again, since in some cases non-monitor @@ -523,7 +518,7 @@ def _run_job(self, job_n, job) -> None: if not zero_return_code: if self.terminate_on_nonzero_returncode: self.run_log[-1]["nonzero_return_code"] = True - msg = f"Job return code is {p.returncode}. Terminating..." + msg = f"Job return code is {process.returncode}. Terminating..." logger.info(msg) raise ReturnCodeError(msg, raises=True) warnings.warn("subprocess returned a non-zero return code. Check outputs carefully...") @@ -543,25 +538,9 @@ def _run_job(self, job_n, job) -> None: raise NonRecoverableError(msg, raises=False, handler=corr["handler"]) # Terminate any running process before raising max errors exceptions - if isinstance(p, subprocess.Popen) and p.poll() is None: - logger.warning("Max errors threshold reached. Terminating running process.") - terminate = self.terminate_func or job.terminate or p.terminate - try: - # Call terminate with directory parameter if it's not the default Popen terminate - if terminate != p.terminate: - terminate(directory=self.directory) - else: - terminate() - # Wait briefly for process to terminate - if hasattr(p, "wait"): - try: - p.wait(timeout=10) - except subprocess.TimeoutExpired: - logger.warning("Process did not terminate gracefully, force killing") - p.kill() - p.wait() - except Exception: - logger.exception("Error terminating process") + if process is not None: + logger.warning("Max errors threshold reached.") + self._terminate_process(process, job) if self.errors_current_job >= self.max_errors_per_job: self.run_log[-1]["max_errors_per_job"] = True @@ -675,8 +654,30 @@ def run_interrupted(self): gzip_dir(self.directory) return None - def _do_check(self, handlers, terminate_func=None): - """Checks the specified handlers. Returns True iff errors caught.""" + def _terminate_process(self, process, job) -> None: + """Terminate a running subprocess using the job's terminate method or fallback.""" + if not isinstance(process, subprocess.Popen) or process.poll() is not None: + return # Not a process or already finished + + logger.warning("Terminating running process.") + terminate = self.terminate_func or job.terminate or process.terminate + + try: + if terminate != process.terminate: + terminate(directory=self.directory) + else: + terminate() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + logger.warning("Process did not terminate gracefully, force killing") + process.kill() + process.wait() + except Exception: + logger.exception("Error terminating process") + + def _do_check(self, handlers, process=None, job=None): + """Check handlers and return True if errors were caught.""" corrections = [] for handler in handlers: try: @@ -694,11 +695,11 @@ def _do_check(self, handlers, terminate_func=None): ) logger.warning(f"{msg} Correction not applied.") continue - if terminate_func is not None and handler.is_terminating: + if process is not None and job is not None and handler.is_terminating: logger.info("Terminating job") - terminate_func(directory=self.directory) - # make sure we don't terminate twice - terminate_func = None + self._terminate_process(process, job) + # Make sure we don't terminate twice + process = None dct = handler.correct(directory=self.directory) logger.error(type(handler).__name__, extra=dct) dct["handler"] = handler