Merged

38 commits
9ba4e58
working
s-gavrenkov Aug 13, 2025
209c4f9
fixed
s-gavrenkov Aug 20, 2025
1251e1a
removed rr
s-gavrenkov Aug 20, 2025
7a2b0bc
Reconcile dependencies, updated IDs, tags
svc-harness-git2 Aug 20, 2025
e75d3e4
working
s-gavrenkov Aug 21, 2025
52541cd
working
s-gavrenkov Aug 21, 2025
039611f
Reconcile dependencies, updated IDs, tags
svc-harness-git2 Aug 21, 2025
7e45b62
[BUZZOK-27100] Fix Drum Inline Runner and streamline DRUM options gen…
mjnitz02 Aug 12, 2025
ae5245c
[-] (Auto) Bump env_info versions (#1622)
svc-harness-git2 Aug 13, 2025
e9b3163
Bump keras in /public_dropin_environments/python3_keras (#1624)
dependabot[bot] Aug 13, 2025
cfa570e
Configure OTel metrics by default in drum. (#1620)
nickolai-dr Aug 13, 2025
eadfebe
[RAPTOR-14453] Regen requirements.txt to fix CVE-2025-8747 (#1623)
nullspoon Aug 13, 2025
db6ce4c
[BUZZOK-27241] Update DRUM version to 1.16.22 to add support for `def…
mjnitz02 Aug 13, 2025
f8adbb2
[BUZZOK-27241] [BUZZOK-27421] Bump requirements in GenAI Agents envir…
mjnitz02 Aug 14, 2025
32420b3
[CFX-3334] Update to latest drgithelper and properly set permissions …
c-h-russell-walker Aug 16, 2025
7b847f6
Add OTEL logging configuration, refactor traces and metrics. (#1626)
nickolai-dr Aug 18, 2025
df74c73
Bump DRUM version. (#1631)
nickolai-dr Aug 19, 2025
98d7b14
[RAPTOR-13851] pytorch: rebuild requirements to pull in updates (#1637)
nullspoon Aug 22, 2025
bd34a1e
Avoid infinite recursion in logs. (#1638)
nickolai-dr Aug 25, 2025
f17efc8
Update version for new release. (#1639)
nickolai-dr Aug 25, 2025
f5e64fd
[RAPTOR-14353] add client and NIM timeouts (#1640)
s-gavrenkov Aug 29, 2025
6b41abc
[RAPTOR-14353] Add nim watchdog (#1632)
s-gavrenkov Aug 29, 2025
eef913e
updated version (#1643)
s-gavrenkov Aug 29, 2025
b0bd728
Reconcile dependencies, updated IDs, tags
svc-harness-git2 Aug 30, 2025
55c8763
working
s-gavrenkov Sep 1, 2025
0971bc0
Merge remote-tracking branch 'origin/master' into gavrenkov/POC-gunicorn
s-gavrenkov Sep 1, 2025
a690532
Reconcile dependencies, updated IDs, tags
svc-harness-git2 Sep 1, 2025
170a04c
refactoring
s-gavrenkov Sep 1, 2025
f180994
Merge remote-tracking branch 'origin/gavrenkov/POC-gunicorn' into gav…
s-gavrenkov Sep 1, 2025
3b1967c
revert nim sidecar
s-gavrenkov Sep 1, 2025
c625d6d
revrt env_info.json
s-gavrenkov Sep 1, 2025
4ef3541
delint
s-gavrenkov Sep 1, 2025
77f6ad2
revert
s-gavrenkov Sep 1, 2025
81eeb91
added comments
s-gavrenkov Sep 1, 2025
b7c58fa
added is_client_request_timeout_enabled
s-gavrenkov Sep 3, 2025
8d1c1ae
added comments
s-gavrenkov Sep 3, 2025
c2b1ef7
added comments
s-gavrenkov Sep 3, 2025
b976d6b
fix signal termination
s-gavrenkov Sep 4, 2025
63 changes: 50 additions & 13 deletions custom_model_runner/datarobot_drum/drum/drum.py
@@ -77,8 +77,12 @@


class CMRunner:
def __init__(self, runtime):
def __init__(self, runtime, flask_app=None, worker_ctx=None):
self.runtime = runtime
self.flask_app = (
flask_app # This is the Flask app object, used when running the application via CLI
)
self.worker_ctx = worker_ctx # This is the Gunicorn worker context object (WorkerCtx)
self.options = runtime.options
self.options.model_config = read_model_metadata_yaml(self.options.code_dir)
self.options.default_parameter_values = (
@@ -497,8 +501,18 @@ def run(self):
with self._setup_output_if_not_exists():
self._run_predictions(stats_collector)
finally:
if stats_collector:
stats_collector.disable()
if self.worker_ctx:
# Perform cleanup specific to the Gunicorn worker being terminated.
# Gunicorn spawns multiple worker processes to handle requests. Each worker has its own context,
# and this ensures that only the resources associated with the current worker are released.
# defer_cleanup simply saves methods to be executed during worker restart or shutdown.
# More details in https://github.com/datarobot/datarobot-custom-templates/pull/419
self.worker_ctx.defer_cleanup(
lambda: stats_collector.disable(), desc="stats_collector.disable()"
Review comment (Contributor): Probably let's add short but genuine descriptions rather than simply the method name?
Reply (Author): Added.

)
else:
if stats_collector:
stats_collector.disable()
if stats_collector:
stats_collector.print_reports()
elif self.run_mode == RunMode.SERVER:
@@ -826,7 +840,7 @@ def _run_predictions(self, stats_collector: Optional[StatsCollector] = None):
if stats_collector:
stats_collector.mark("start")
predictor = (
PredictionServer(params)
PredictionServer(params, self.flask_app)
if self.run_mode == RunMode.SERVER
else GenericPredictorComponent(params)
)
@@ -836,16 +850,39 @@ def _run_predictions(self, stats_collector: Optional[StatsCollector] = None):
if stats_collector:
stats_collector.mark("run")
finally:
if predictor is not None:
predictor.terminate()
if stats_collector:
stats_collector.mark("end")
if self.worker_ctx:
# Perform cleanup specific to the Gunicorn worker being terminated.
# Gunicorn spawns multiple worker processes to handle requests. Each worker has its own context,
# and this ensures that only the resources associated with the current worker are released.
# defer_cleanup simply saves methods to be executed during worker restart or shutdown.
# More details in https://github.com/datarobot/datarobot-custom-templates/pull/419
if predictor is not None:
self.worker_ctx.defer_cleanup(
lambda: predictor.terminate(), desc="predictor.terminate()"
)
if stats_collector:
self.worker_ctx.defer_cleanup(
lambda: stats_collector.mark("end"), desc="stats_collector.mark('end')"
)
self.worker_ctx.defer_cleanup(
lambda: self.logger.info(
"<<< Finish {} in the {} mode".format(
ArgumentsOptions.MAIN_COMMAND, self.run_mode.value
)
),
desc="logger.info(...)",
)

self.logger.info(
"<<< Finish {} in the {} mode".format(
ArgumentsOptions.MAIN_COMMAND, self.run_mode.value
)
)
else:
if predictor is not None:
predictor.terminate()
if stats_collector:
stats_collector.mark("end")
self.logger.info(
"<<< Finish {} in the {} mode".format(
ArgumentsOptions.MAIN_COMMAND, self.run_mode.value
)
)

@contextlib.contextmanager
def _setup_output_if_not_exists(self):
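The WorkerCtx object referenced throughout this file is not part of the diff; it lives in the gunicorn configuration (see the datarobot-custom-templates PR linked in the code comments). As a rough illustration of the deferred-cleanup pattern those comments describe, a minimal stand-in might look like the sketch below; everything beyond the defer_cleanup(fn, desc=...) signature is an assumption.

# Hypothetical sketch of a per-worker cleanup context; the real WorkerCtx is
# defined in the gunicorn config (datarobot-custom-templates PR #419) and may differ.
class WorkerCtx:
    def __init__(self):
        # Each entry is a (callable, description) pair; nothing runs at registration time.
        self._cleanups = []

    def defer_cleanup(self, fn, desc=""):
        # Only record the callable; execution is deferred until the worker
        # restarts or shuts down.
        self._cleanups.append((fn, desc))

    def run_cleanups(self, logger=None):
        # Would be called from a gunicorn hook such as worker_exit(server, worker).
        while self._cleanups:
            fn, desc = self._cleanups.pop()
            try:
                fn()
            except Exception:
                if logger is not None:
                    logger.exception("Deferred cleanup failed: %s", desc)
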
51 changes: 46 additions & 5 deletions custom_model_runner/datarobot_drum/drum/main.py
@@ -54,10 +54,50 @@
)


def main():
with DrumRuntime() as runtime:
def main(flask_app=None, worker_ctx=None):
"""
The main entry point for the custom model runner.

This function initializes the runtime environment, sets up logging, handles
signal interruptions, and starts the CMRunner for executing user-defined models.

Args:
flask_app: Optional[Flask] Flask application instance, used when running using command line.
Review comment (Contributor), on "..., used when running using command line.":
It's always a command line, whether running locally or running DRUM as a sidecar container. Is the plan to set up the Flask app when running DRUM as a sidecar?

Reply (Author, s-gavrenkov, Sep 1, 2025): The old logic remains. This was added so that drum can be run with the command gunicorn -c gunicorn.conf.py app:app:

if [ "$SERVER_TYPE" == "gunicorn" ]; then
    echo "Starting gunicorn server..."
    exec gunicorn -c gunicorn.conf.py app:app
else
    echo "Starting werkzeug (dev Flask) server..."
    exec drum server --sidecar --gpu-predictor=nim --logging-level=info "$@"
fi

Review comment (Contributor): flask_app is not optional for DrumRuntime, so it can't be optional for main().

Reply (Author): If flask_app is None, it is created via app = create_flask_app() inside get_flask_app(api_blueprint):
https://github.com/datarobot/datarobot-user-models/pull/1633/files#diff-1d592639344a14e290f1bd489098ef8322773e22f75ab54e7cd5280177bbfd43R34-R39

worker_ctx: Optional gunicorn worker context (WorkerCtx), used for managing cleanup tasks in a
multi-worker setup (e.g., Gunicorn).

Returns:
None
"""
with DrumRuntime(flask_app) as runtime:
config_logging()

if worker_ctx:
# Perform cleanup specific to the Gunicorn worker being terminated.
# Gunicorn spawns multiple worker processes to handle requests. Each worker has its own context,
# and this ensures that only the resources associated with the current worker are released.
# defer_cleanup simply saves methods to be executed during worker restart or shutdown.
# More details in https://github.com/datarobot/datarobot-custom-templates/pull/419
if runtime.options and RunMode(runtime.options.subparser_name) == RunMode.SERVER:
if runtime.cm_runner:
worker_ctx.defer_cleanup(
lambda: runtime.cm_runner.terminate(), desc="runtime.cm_runner.terminate()"
)
if runtime.trace_provider is not None:
worker_ctx.defer_cleanup(
lambda: runtime.trace_provider.shutdown(),
desc="runtime.trace_provider.shutdown()",
)
if runtime.metric_provider is not None:
worker_ctx.defer_cleanup(
lambda: runtime.metric_provider.shutdown(),
desc="runtime.metric_provider.shutdown()",
)
if runtime.log_provider is not None:
worker_ctx.defer_cleanup(
lambda: runtime.log_provider.shutdown(), desc="runtime.log_provider.shutdown()"
)

def signal_handler(sig, frame):
# The signal is assigned so the stacktrace is not presented when Ctrl-C is pressed.
# The cleanup itself is done only if we are NOT running in performance test mode which
@@ -89,13 +129,14 @@ def signal_handler(sig, frame):
runtime.metric_provider = metric_provider
runtime.log_provider = log_provider

signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
if worker_ctx is None:
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

from datarobot_drum.drum.drum import CMRunner

try:
runtime.cm_runner = CMRunner(runtime)
runtime.cm_runner = CMRunner(runtime, flask_app, worker_ctx)
runtime.cm_runner.run()
except DrumSchemaValidationException:
sys.exit(ExitCodes.SCHEMA_VALIDATION_ERROR.value)
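For context on how main(flask_app=..., worker_ctx=...) is expected to be invoked under gunicorn -c gunicorn.conf.py app:app (per the review discussion above), a rough sketch of an entry module follows. The module name app.py, the _WorkerCtx stand-in, and the omitted option handling are all assumptions; the real wiring lives outside this diff.

# app.py -- hypothetical entry module for `gunicorn -c gunicorn.conf.py app:app`.
# DRUM's normal CLI option parsing (the `drum server ...` arguments) is assumed
# to be configured elsewhere and is omitted here.
from flask import Flask

from datarobot_drum.drum.main import main


class _WorkerCtx:
    """Stand-in for the real per-worker context shipped with the gunicorn config."""

    def __init__(self):
        self.cleanups = []

    def defer_cleanup(self, fn, desc=""):
        self.cleanups.append((fn, desc))


app = Flask(__name__)          # the pre-built app that gunicorn will serve
worker_ctx = _WorkerCtx()      # per-worker cleanup context

# With flask_app passed in, DRUM registers its blueprint on `app` and skips
# app.run(); with worker_ctx passed in, it defers cleanup through the context
# instead of installing its own SIGINT/SIGTERM handlers, since gunicorn owns
# worker signal handling.
main(flask_app=app, worker_ctx=worker_ctx)
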
@@ -68,8 +68,11 @@ class TimeoutWSGIRequestHandler(WSGIRequestHandler):


class PredictionServer(PredictMixin):
def __init__(self, params: dict):
def __init__(self, params: dict, flask_app=None):
self._params = params
self.flask_app = (
flask_app # This is the Flask app object, used when running the application via CLI
)
self._show_perf = self._params.get("show_perf")
self._resource_monitor = ResourceMonitor(monitor_current_process=True)
self._run_language = RunLanguage(params.get("run_language"))
@@ -310,7 +313,7 @@ def handle_exception(e):
cli = sys.modules["flask.cli"]
cli.show_server_banner = lambda *x: None

app = get_flask_app(model_api)
app = get_flask_app(model_api, self.flask_app)
self.load_flask_extensions(app)
self._run_flask_app(app)

@@ -319,6 +322,19 @@

return []

def is_client_request_timeout_enabled(self):
if (
RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT")
and int(RuntimeParameters.get("DRUM_CLIENT_REQUEST_TIMEOUT")) > 0
):
logger.info(
"Client request timeout is enabled, timeout: %s",
str(int(TimeoutWSGIRequestHandler.timeout)),
)
return True
else:
return False

def _run_flask_app(self, app):
host = self._params.get("host", None)
port = self._params.get("port", None)
@@ -328,30 +344,34 @@ def _run_flask_app(self, app):
processes = self._params.get("processes")
logger.info("Number of webserver processes: %s", processes)
try:
if RuntimeParameters.has("USE_NIM_WATCHDOG") and str(
RuntimeParameters.get("USE_NIM_WATCHDOG")
).lower() in ["true", "1", "yes"]:
# Start the watchdog thread before running the app
self._server_watchdog = Thread(
target=self.watchdog,
args=(port,),
daemon=True,
name="NIM Sidecar Watchdog",
if self.flask_app:
# when running application via the command line (e.g., gunicorn worker)
pass
else:
if RuntimeParameters.has("USE_NIM_WATCHDOG") and str(
Review comment (Contributor): This check for has(<attribute>) is redundant; you could use RuntimeParameters.get("USE_NIM_WATCHDOG", "no") with a default value instead.
Follow-up (Contributor): Ah, I see it was defined in the original version, so it's not critical to apply in the current PR; just a suggestion.
Reply (Author): The user should explicitly set this value.

RuntimeParameters.get("USE_NIM_WATCHDOG")
).lower() in ["true", "1", "yes"]:
# Start the watchdog thread before running the app
self._server_watchdog = Thread(
target=self.watchdog,
args=(port,),
daemon=True,
name="NIM Sidecar Watchdog",
)
self._server_watchdog.start()

# Configure the server with timeout settings
app.run(
host=host,
port=port,
threaded=False,
processes=processes,
**(
{"request_handler": TimeoutWSGIRequestHandler}
if self.is_client_request_timeout_enabled()
else {}
),
)
self._server_watchdog.start()

# Configure the server with timeout settings
app.run(
host=host,
port=port,
threaded=False,
processes=processes,
**(
{"request_handler": TimeoutWSGIRequestHandler}
if RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT")
else {}
),
)
except OSError as e:
raise DrumCommonException("{}: host: {}; port: {}".format(e, host, port))

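The reviewer's suggestion above, replacing the has()/get() pair with a single get() carrying a default, would look roughly like the sketch below. It assumes RuntimeParameters.get accepts a default value, as the review comment states; the merged code keeps the explicit check so that users opt in deliberately.

# Sketch of the suggested simplification; `runtime_params` stands in for the
# RuntimeParameters class, and the default-argument form is assumed from the review comment.
def nim_watchdog_enabled(runtime_params):
    value = str(runtime_params.get("USE_NIM_WATCHDOG", "no")).lower()
    return value in ("true", "1", "yes")
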
19 changes: 14 additions & 5 deletions custom_model_runner/datarobot_drum/drum/runtime.py
@@ -5,6 +5,7 @@
Released under the terms of DataRobot Tool and Utility Agreement.
"""
import logging
from typing import Optional

from datarobot_drum.drum.server import (
empty_api_blueprint,
@@ -13,6 +14,7 @@
)
from datarobot_drum.drum.common import verbose_stdout, get_drum_logger
from datarobot_drum.drum.enum import LOGGER_NAME_PREFIX, RunMode
from flask import Flask

from datarobot_drum.drum.exceptions import DrumCommonException

@@ -23,14 +25,17 @@


class DrumRuntime:
def __init__(self):
def __init__(self, flask_app: Optional[Flask] = None):
self.initialization_succeeded = False
self.options = None
self.cm_runner = None
# OTEL services
self.trace_provider = None
self.metric_provider = None
self.log_provider = None
self.flask_app = (
flask_app # This is the Flask app object, used when running the application via CLI
)

def __enter__(self):
return self
@@ -83,12 +88,12 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
port = int(host_port_list[1]) if len(host_port_list) == 2 else None

with verbose_stdout(self.options.verbose):
run_error_server(host, port, exc_value)
run_error_server(host, port, exc_value, self.flask_app)

return False # propagate exception further


def run_error_server(host, port, exc_value):
def run_error_server(host, port, exc_value, flask_app: Optional[Flask] = None):
model_api = empty_api_blueprint()

@model_api.route("/", methods=["GET"])
@@ -109,5 +114,9 @@ def predict():
def transform():
return {"message": "ERROR: {}".format(exc_value)}, HTTP_513_DRUM_PIPELINE_ERROR

app = get_flask_app(model_api)
app.run(host, port)
app = get_flask_app(model_api, flask_app)
if flask_app:
# when running application via the command line (e.g., gunicorn worker)
pass
else:
app.run(host, port)
7 changes: 5 additions & 2 deletions custom_model_runner/datarobot_drum/drum/server.py
@@ -5,6 +5,8 @@
Released under the terms of DataRobot Tool and Utility Agreement.
"""
import datetime
from typing import Optional

import flask
import os
import uuid
@@ -29,8 +31,9 @@
logger = get_drum_logger(LOGGER_NAME_PREFIX)


def get_flask_app(api_blueprint):
app = create_flask_app()
def get_flask_app(api_blueprint, app: Optional[Flask] = None):
if app is None:
app = create_flask_app()
url_prefix = os.environ.get(URL_PREFIX_ENV_VAR_NAME, "")
app.register_blueprint(api_blueprint, url_prefix=url_prefix)
return app
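Taken together, this change means get_flask_app() either builds a fresh app (CLI / dev-server path) or registers DRUM's blueprint on an app supplied from outside (gunicorn path). A small illustration, with a hypothetical blueprint name:

from flask import Blueprint, Flask

from datarobot_drum.drum.server import get_flask_app

model_api = Blueprint("model_api", __name__)  # hypothetical blueprint for illustration

# CLI / dev-server path: no app supplied, so create_flask_app() builds one.
cli_app = get_flask_app(model_api)

# Gunicorn path: the externally created app is reused; only the blueprint is registered.
external_app = Flask(__name__)
served_app = get_flask_app(model_api, external_app)
assert served_app is external_app
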
3 changes: 2 additions & 1 deletion tests/unit/datarobot_drum/drum/test_prediction_server.py
@@ -235,7 +235,8 @@ def chat_hook(completion_request, model):


@pytest.mark.parametrize(
"processes_param, expected_processes, request_timeout", [(None, 1, None), (10, 10, 600)]
"processes_param, expected_processes, request_timeout",
[(None, 1, None), (None, 1, 0), (10, 10, 600)],
)
def test_run_flask_app(processes_param, expected_processes, request_timeout):
if request_timeout:
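The new (None, 1, 0) case exercises the boundary of is_client_request_timeout_enabled(): a timeout of 0 must leave the default WSGI request handler in place. A rough standalone sketch of that predicate, mirroring the server code above:

# Mirrors the check in PredictionServer.is_client_request_timeout_enabled():
# the runtime parameter must be present AND strictly positive.
def timeout_enabled(raw_value):
    if raw_value is None:  # RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT") would be False
        return False
    return int(raw_value) > 0

assert timeout_enabled(None) is False   # parameter absent: default request handler
assert timeout_enabled(0) is False      # new test case: 0 keeps the timeout disabled
assert timeout_enabled(600) is True     # positive value: TimeoutWSGIRequestHandler is used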