diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py
index 5eba28271..e7fefc2b4 100644
--- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py
+++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py
@@ -7,7 +7,11 @@
 import logging
 import os
 import sys
+import time
 from pathlib import Path
+from threading import Thread
+import subprocess
+import signal
 
 import requests
 from flask import Response, jsonify, request
@@ -26,6 +30,7 @@
     ModelInfoKeys,
     RunLanguage,
     TargetType,
+    URL_PREFIX_ENV_VAR_NAME,
 )
 from datarobot_drum.drum.exceptions import DrumCommonException
 from datarobot_drum.drum.model_metadata import read_model_metadata_yaml
@@ -81,6 +86,7 @@ def __init__(self, params: dict):
             "run_predictor_total", "finish", StatsOperation.SUB, "start"
         )
         self._predictor = self._setup_predictor()
+        self._server_watchdog = None
 
     def _setup_predictor(self):
         if self._run_language == RunLanguage.PYTHON:
@@ -322,6 +328,18 @@ def _run_flask_app(self, app):
         processes = self._params.get("processes")
         logger.info("Number of webserver processes: %s", processes)
         try:
+            if RuntimeParameters.has("USE_NIM_WATCHDOG") and str(
+                RuntimeParameters.get("USE_NIM_WATCHDOG")
+            ).lower() in ["true", "1", "yes"]:
+                # Start the watchdog thread before running the app
+                self._server_watchdog = Thread(
+                    target=self.watchdog,
+                    args=(port,),
+                    daemon=True,
+                    name="NIM Sidecar Watchdog",
+                )
+                self._server_watchdog.start()
+
             # Configure the server with timeout settings
             app.run(
                 host=host,
@@ -337,6 +355,98 @@
         except OSError as e:
             raise DrumCommonException("{}: host: {}; port: {}".format(e, host, port))
 
+    def _kill_all_processes(self):
+        """
+        Forcefully terminates all running processes related to the server.
+        Attempts a clean termination first, then uses system commands to kill remaining processes.
+        Logs errors encountered during termination.
+        """
+
+        logger.error("All health check attempts failed. Forcefully killing all processes.")
+
+        # First try clean termination
+        try:
+            self._terminate()
+        except Exception as e:
+            logger.error(f"Error during clean termination: {str(e)}")
+
+        # Use more direct system commands to kill processes
+        try:
+            # Kill any remaining processes (more aggressive approach)
+            logger.info("Killing all remaining processes")
+            # Run `busybox ps` and capture output
+            result = subprocess.run(["busybox", "ps"], capture_output=True, text=True)
+            # Parse lines, skip the header
+            lines = result.stdout.strip().split("\n")[1:]
+            # Extract the PID (first column)
+            pids = [int(line.split()[0]) for line in lines]
+            for pid in pids:
+                logger.info("Killing pid: %s", pid)
+                os.kill(pid, signal.SIGTERM)
+        except Exception as kill_error:
+            logger.error(f"Error during process killing: {str(kill_error)}")
+
+    def watchdog(self, port):
+        """
+        Watchdog thread that periodically checks whether the server is alive by making
+        GET requests to the /info/ endpoint. Makes up to a configurable number of attempts
+        (default 3) with quadratic backoff before forcefully killing the server processes.
+        """
+
+        logger.info("Starting watchdog to monitor server health...")
+
+        import os
+
+        url_host = os.environ.get("TEST_URL_HOST", "localhost")
+        url_prefix = os.environ.get(URL_PREFIX_ENV_VAR_NAME, "")
+        health_url = f"http://{url_host}:{port}{url_prefix}/info/"
+
+        request_timeout = 120
+        if RuntimeParameters.has("NIM_WATCHDOG_REQUEST_TIMEOUT"):
+            try:
+                request_timeout = int(RuntimeParameters.get("NIM_WATCHDOG_REQUEST_TIMEOUT"))
+            except ValueError:
+                logger.warning(
+                    "Invalid value for NIM_WATCHDOG_REQUEST_TIMEOUT, using default of 120 seconds"
+                )
+        logger.info("NIM watchdog health check request timeout is %s", request_timeout)
+        check_interval = 10  # seconds
+        max_attempts = 3
+        if RuntimeParameters.has("NIM_WATCHDOG_MAX_ATTEMPTS"):
+            try:
+                max_attempts = int(RuntimeParameters.get("NIM_WATCHDOG_MAX_ATTEMPTS"))
+            except ValueError:
+                logger.warning("Invalid value for NIM_WATCHDOG_MAX_ATTEMPTS, using default of 3")
+        logger.info("NIM watchdog max attempts: %s", max_attempts)
+        attempt = 0
+        base_sleep_time = 4
+
+        while True:
+            try:
+                # Check if the server is responding to health checks
+                logger.debug("Server health check")
+                response = requests.get(health_url, timeout=request_timeout)
+                logger.debug(f"Server health check status: {response.status_code}")
+                # Connection succeeded, reset attempts and wait for the next check
+                attempt = 0
+                time.sleep(check_interval)  # Regular check interval
+                continue
+
+            except Exception as e:
+                attempt += 1
+                logger.warning(f"Health check URL: {health_url}")
+                logger.warning(
+                    f"Server health check failed (attempt {attempt}/{max_attempts}): {str(e)}"
+                )
+
+                if attempt >= max_attempts:
+                    self._kill_all_processes()
+
+                # Quadratic backoff: 4s, 16s, 36s for attempts 1, 2, 3
+                sleep_time = base_sleep_time * (attempt**2)
+                logger.info(f"Retrying in {sleep_time} seconds...")
+                time.sleep(sleep_time)
+
     def terminate(self):
         terminate_op = getattr(self._predictor, "terminate", None)
         if callable(terminate_op):