Add optional Flask app / Gunicorn worker-context plumbing to DRUM.

Threads `flask_app` and `worker_ctx` from main() through DrumRuntime, CMRunner,
PredictionServer, run_error_server() and get_flask_app(). When a worker context is
passed, teardown calls are registered through worker_ctx.defer_cleanup() instead of
being executed inline, and the SIGINT/SIGTERM handlers are not installed. When a Flask
app is passed, blueprints are registered on it and app.run() is skipped. The timeout
request handler is now used only when DRUM_CLIENT_REQUEST_TIMEOUT is present and > 0.

Changes relative to the patch as received:
* CMRunner.run(): the deferred stats_collector.disable() is registered only when
  stats_collector is set, matching the guard the non-worker branch already had. Without
  it, a None stats_collector makes the deferred callback raise at worker shutdown.
  The hunk line counts are unchanged.
* One diff line per line again (the received text had its newlines collapsed).

NOTE(review): indentation and blank context lines were reconstructed from the hunk
line counts; the post-image blob id on the drum.py `index` line predates the amendment.
Confirm with `git apply --check` before use.
NOTE(review): in main(), the worker_ctx cleanup registration runs right after
config_logging(); the visible DrumRuntime.__init__ context sets options and cm_runner
to None and the providers are assigned later in main(), so those guards look false at
registration time. Confirm whether the registration should move below
`runtime.cm_runner = CMRunner(...)`.

diff --git a/custom_model_runner/datarobot_drum/drum/drum.py b/custom_model_runner/datarobot_drum/drum/drum.py
index 37df6da5e..748b98dae 100644
--- a/custom_model_runner/datarobot_drum/drum/drum.py
+++ b/custom_model_runner/datarobot_drum/drum/drum.py
@@ -77,8 +77,12 @@
 
 
 class CMRunner:
-    def __init__(self, runtime):
+    def __init__(self, runtime, flask_app=None, worker_ctx=None):
         self.runtime = runtime
+        self.flask_app = (
+            flask_app  # This is the Flask app object, used when running the application via CLI
+        )
+        self.worker_ctx = worker_ctx  # This is the Gunicorn worker context object (WorkerCtx)
         self.options = runtime.options
         self.options.model_config = read_model_metadata_yaml(self.options.code_dir)
         self.options.default_parameter_values = (
@@ -497,8 +501,18 @@ def run(self):
                 with self._setup_output_if_not_exists():
                     self._run_predictions(stats_collector)
             finally:
                 if stats_collector:
-                    stats_collector.disable()
+                    if self.worker_ctx:
+                        # Perform cleanup specific to the Gunicorn worker being terminated.
+                        # Gunicorn spawns multiple worker processes to handle requests. Each worker has its own context,
+                        # and this ensures that only the resources associated with the current worker are released.
+                        # defer_cleanup simply saves methods to be executed during worker restart or shutdown.
+                        # More details in https://github.com/datarobot/datarobot-custom-templates/pull/419
+                        self.worker_ctx.defer_cleanup(
+                            lambda: stats_collector.disable(), desc="stats_collector.disable()"
+                        )
+                    else:
+                        stats_collector.disable()
             if stats_collector:
                 stats_collector.print_reports()
         elif self.run_mode == RunMode.SERVER:
@@ -826,7 +840,7 @@ def _run_predictions(self, stats_collector: Optional[StatsCollector] = None):
             if stats_collector:
                 stats_collector.mark("start")
             predictor = (
-                PredictionServer(params)
+                PredictionServer(params, self.flask_app)
                 if self.run_mode == RunMode.SERVER
                 else GenericPredictorComponent(params)
             )
@@ -836,16 +850,39 @@ def _run_predictions(self, stats_collector: Optional[StatsCollector] = None):
             if stats_collector:
                 stats_collector.mark("run")
         finally:
-            if predictor is not None:
-                predictor.terminate()
-            if stats_collector:
-                stats_collector.mark("end")
+            if self.worker_ctx:
+                # Perform cleanup specific to the Gunicorn worker being terminated.
+                # Gunicorn spawns multiple worker processes to handle requests. Each worker has its own context,
+                # and this ensures that only the resources associated with the current worker are released.
+                # defer_cleanup simply saves methods to be executed during worker restart or shutdown.
+                # More details in https://github.com/datarobot/datarobot-custom-templates/pull/419
+                if predictor is not None:
+                    self.worker_ctx.defer_cleanup(
+                        lambda: predictor.terminate(), desc="predictor.terminate()"
+                    )
+                if stats_collector:
+                    self.worker_ctx.defer_cleanup(
+                        lambda: stats_collector.mark("end"), desc="stats_collector.mark('end')"
+                    )
+                self.worker_ctx.defer_cleanup(
+                    lambda: self.logger.info(
+                        "<<< Finish {} in the {} mode".format(
+                            ArgumentsOptions.MAIN_COMMAND, self.run_mode.value
+                        )
+                    ),
+                    desc="logger.info(...)",
+                )
 
-        self.logger.info(
-            "<<< Finish {} in the {} mode".format(
-                ArgumentsOptions.MAIN_COMMAND, self.run_mode.value
-            )
-        )
+            else:
+                if predictor is not None:
+                    predictor.terminate()
+                if stats_collector:
+                    stats_collector.mark("end")
+                self.logger.info(
+                    "<<< Finish {} in the {} mode".format(
+                        ArgumentsOptions.MAIN_COMMAND, self.run_mode.value
+                    )
+                )
 
     @contextlib.contextmanager
     def _setup_output_if_not_exists(self):
diff --git a/custom_model_runner/datarobot_drum/drum/main.py b/custom_model_runner/datarobot_drum/drum/main.py
index 94aa0e5b5..a6522329b 100644
--- a/custom_model_runner/datarobot_drum/drum/main.py
+++ b/custom_model_runner/datarobot_drum/drum/main.py
@@ -54,10 +54,50 @@
 )
 
 
-def main():
-    with DrumRuntime() as runtime:
+def main(flask_app=None, worker_ctx=None):
+    """
+    The main entry point for the custom model runner.
+
+    This function initializes the runtime environment, sets up logging, handles
+    signal interruptions, and starts the CMRunner for executing user-defined models.
+
+    Args:
+        flask_app: Optional[Flask] Flask application instance, used when running using command line.
+        worker_ctx: Optional gunicorn worker context (WorkerCtx), used for managing cleanup tasks in a
+            multi-worker setup (e.g., Gunicorn).
+
+    Returns:
+        None
+    """
+    with DrumRuntime(flask_app) as runtime:
         config_logging()
 
+        if worker_ctx:
+            # Perform cleanup specific to the Gunicorn worker being terminated.
+            # Gunicorn spawns multiple worker processes to handle requests. Each worker has its own context,
+            # and this ensures that only the resources associated with the current worker are released.
+            # defer_cleanup simply saves methods to be executed during worker restart or shutdown.
+            # More details in https://github.com/datarobot/datarobot-custom-templates/pull/419
+            if runtime.options and RunMode(runtime.options.subparser_name) == RunMode.SERVER:
+                if runtime.cm_runner:
+                    worker_ctx.defer_cleanup(
+                        lambda: runtime.cm_runner.terminate(), desc="runtime.cm_runner.terminate()"
+                    )
+            if runtime.trace_provider is not None:
+                worker_ctx.defer_cleanup(
+                    lambda: runtime.trace_provider.shutdown(),
+                    desc="runtime.trace_provider.shutdown()",
+                )
+            if runtime.metric_provider is not None:
+                worker_ctx.defer_cleanup(
+                    lambda: runtime.metric_provider.shutdown(),
+                    desc="runtime.metric_provider.shutdown()",
+                )
+            if runtime.log_provider is not None:
+                worker_ctx.defer_cleanup(
+                    lambda: runtime.log_provider.shutdown(), desc="runtime.log_provider.shutdown()"
+                )
+
         def signal_handler(sig, frame):
             # The signal is assigned so the stacktrace is not presented when Ctrl-C is pressed.
             # The cleanup itself is done only if we are NOT running in performance test mode which
@@ -89,13 +129,14 @@ def signal_handler(sig, frame):
         runtime.metric_provider = metric_provider
         runtime.log_provider = log_provider
 
-        signal.signal(signal.SIGINT, signal_handler)
-        signal.signal(signal.SIGTERM, signal_handler)
+        if worker_ctx is None:
+            signal.signal(signal.SIGINT, signal_handler)
+            signal.signal(signal.SIGTERM, signal_handler)
 
         from datarobot_drum.drum.drum import CMRunner
 
         try:
-            runtime.cm_runner = CMRunner(runtime)
+            runtime.cm_runner = CMRunner(runtime, flask_app, worker_ctx)
             runtime.cm_runner.run()
         except DrumSchemaValidationException:
             sys.exit(ExitCodes.SCHEMA_VALIDATION_ERROR.value)
diff --git a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py
index e7fefc2b4..5758d6bf9 100644
--- a/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py
+++ b/custom_model_runner/datarobot_drum/drum/root_predictors/prediction_server.py
@@ -68,8 +68,11 @@ class TimeoutWSGIRequestHandler(WSGIRequestHandler):
 
 
 class PredictionServer(PredictMixin):
-    def __init__(self, params: dict):
+    def __init__(self, params: dict, flask_app=None):
         self._params = params
+        self.flask_app = (
+            flask_app  # This is the Flask app object, used when running the application via CLI
+        )
         self._show_perf = self._params.get("show_perf")
         self._resource_monitor = ResourceMonitor(monitor_current_process=True)
         self._run_language = RunLanguage(params.get("run_language"))
@@ -310,7 +313,7 @@ def handle_exception(e):
         cli = sys.modules["flask.cli"]
         cli.show_server_banner = lambda *x: None
 
-        app = get_flask_app(model_api)
+        app = get_flask_app(model_api, self.flask_app)
         self.load_flask_extensions(app)
         self._run_flask_app(app)
 
@@ -319,6 +322,19 @@ def handle_exception(e):
 
         return []
 
+    def is_client_request_timeout_enabled(self):
+        if (
+            RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT")
+            and int(RuntimeParameters.get("DRUM_CLIENT_REQUEST_TIMEOUT")) > 0
+        ):
+            logger.info(
+                "Client request timeout is enabled, timeout: %s",
+                str(int(TimeoutWSGIRequestHandler.timeout)),
+            )
+            return True
+        else:
+            return False
+
     def _run_flask_app(self, app):
         host = self._params.get("host", None)
         port = self._params.get("port", None)
@@ -328,30 +344,34 @@ def _run_flask_app(self, app):
             processes = self._params.get("processes")
             logger.info("Number of webserver processes: %s", processes)
         try:
-            if RuntimeParameters.has("USE_NIM_WATCHDOG") and str(
-                RuntimeParameters.get("USE_NIM_WATCHDOG")
-            ).lower() in ["true", "1", "yes"]:
-                # Start the watchdog thread before running the app
-                self._server_watchdog = Thread(
-                    target=self.watchdog,
-                    args=(port,),
-                    daemon=True,
-                    name="NIM Sidecar Watchdog",
+            if self.flask_app:
+                # when running application via the command line (e.g., gunicorn worker)
+                pass
+            else:
+                if RuntimeParameters.has("USE_NIM_WATCHDOG") and str(
+                    RuntimeParameters.get("USE_NIM_WATCHDOG")
+                ).lower() in ["true", "1", "yes"]:
+                    # Start the watchdog thread before running the app
+                    self._server_watchdog = Thread(
+                        target=self.watchdog,
+                        args=(port,),
+                        daemon=True,
+                        name="NIM Sidecar Watchdog",
+                    )
+                    self._server_watchdog.start()
+
+                # Configure the server with timeout settings
+                app.run(
+                    host=host,
+                    port=port,
+                    threaded=False,
+                    processes=processes,
+                    **(
+                        {"request_handler": TimeoutWSGIRequestHandler}
+                        if self.is_client_request_timeout_enabled()
+                        else {}
+                    ),
                 )
-                self._server_watchdog.start()
-
-            # Configure the server with timeout settings
-            app.run(
-                host=host,
-                port=port,
-                threaded=False,
-                processes=processes,
-                **(
-                    {"request_handler": TimeoutWSGIRequestHandler}
-                    if RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT")
-                    else {}
-                ),
-            )
         except OSError as e:
             raise DrumCommonException("{}: host: {}; port: {}".format(e, host, port))
 
diff --git a/custom_model_runner/datarobot_drum/drum/runtime.py b/custom_model_runner/datarobot_drum/drum/runtime.py
index ea3f197b1..ebef0e0a4 100644
--- a/custom_model_runner/datarobot_drum/drum/runtime.py
+++ b/custom_model_runner/datarobot_drum/drum/runtime.py
@@ -5,6 +5,7 @@
 Released under the terms of DataRobot Tool and Utility Agreement.
 """
 import logging
+from typing import Optional
 
 from datarobot_drum.drum.server import (
     empty_api_blueprint,
@@ -13,6 +14,7 @@
 )
 from datarobot_drum.drum.common import verbose_stdout, get_drum_logger
 from datarobot_drum.drum.enum import LOGGER_NAME_PREFIX, RunMode
+from flask import Flask
 
 from datarobot_drum.drum.exceptions import DrumCommonException
 
@@ -23,7 +25,7 @@
 
 
 class DrumRuntime:
-    def __init__(self):
+    def __init__(self, flask_app: Optional[Flask] = None):
         self.initialization_succeeded = False
         self.options = None
         self.cm_runner = None
@@ -31,6 +33,9 @@ def __init__(self):
         self.trace_provider = None
         self.metric_provider = None
         self.log_provider = None
+        self.flask_app = (
+            flask_app  # This is the Flask app object, used when running the application via CLI
+        )
 
     def __enter__(self):
         return self
@@ -83,12 +88,12 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
         port = int(host_port_list[1]) if len(host_port_list) == 2 else None
 
         with verbose_stdout(self.options.verbose):
-            run_error_server(host, port, exc_value)
+            run_error_server(host, port, exc_value, self.flask_app)
 
         return False  # propagate exception further
 
 
-def run_error_server(host, port, exc_value):
+def run_error_server(host, port, exc_value, flask_app: Optional[Flask] = None):
     model_api = empty_api_blueprint()
 
     @model_api.route("/", methods=["GET"])
@@ -109,5 +114,9 @@ def predict():
     def transform():
         return {"message": "ERROR: {}".format(exc_value)}, HTTP_513_DRUM_PIPELINE_ERROR
 
-    app = get_flask_app(model_api)
-    app.run(host, port)
+    app = get_flask_app(model_api, flask_app)
+    if flask_app:
+        # when running application via the command line (e.g., gunicorn worker)
+        pass
+    else:
+        app.run(host, port)
diff --git a/custom_model_runner/datarobot_drum/drum/server.py b/custom_model_runner/datarobot_drum/drum/server.py
index 1fd30e9b5..012d39d6b 100644
--- a/custom_model_runner/datarobot_drum/drum/server.py
+++ b/custom_model_runner/datarobot_drum/drum/server.py
@@ -5,6 +5,8 @@
 Released under the terms of DataRobot Tool and Utility Agreement.
 """
 import datetime
+from typing import Optional
+
 import flask
 import os
 import uuid
@@ -29,8 +31,9 @@
 logger = get_drum_logger(LOGGER_NAME_PREFIX)
 
 
-def get_flask_app(api_blueprint):
-    app = create_flask_app()
+def get_flask_app(api_blueprint, app: Optional[Flask] = None):
+    if app is None:
+        app = create_flask_app()
     url_prefix = os.environ.get(URL_PREFIX_ENV_VAR_NAME, "")
     app.register_blueprint(api_blueprint, url_prefix=url_prefix)
     return app
diff --git a/tests/unit/datarobot_drum/drum/test_prediction_server.py b/tests/unit/datarobot_drum/drum/test_prediction_server.py
index 7e57687f0..5f21846fe 100644
--- a/tests/unit/datarobot_drum/drum/test_prediction_server.py
+++ b/tests/unit/datarobot_drum/drum/test_prediction_server.py
@@ -235,7 +235,8 @@ def chat_hook(completion_request, model):
 
 
 @pytest.mark.parametrize(
-    "processes_param, expected_processes, request_timeout", [(None, 1, None), (10, 10, 600)]
+    "processes_param, expected_processes, request_timeout",
+    [(None, 1, None), (None, 1, 0), (10, 10, 600)],
 )
 def test_run_flask_app(processes_param, expected_processes, request_timeout):
     if request_timeout: