-
Notifications
You must be signed in to change notification settings - Fork 88
[RAPTOR-14353] Run flask server using CLI (gunicorn with gevent support) #1633
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9ba4e58
209c4f9
1251e1a
7a2b0bc
e75d3e4
52541cd
039611f
7e45b62
ae5245c
e9b3163
cfa570e
eadfebe
db6ce4c
f8adbb2
32420b3
7b847f6
df74c73
98d7b14
bd34a1e
f17efc8
f5e64fd
6b41abc
eef913e
b0bd728
55c8763
0971bc0
a690532
170a04c
f180994
3b1967c
c625d6d
4ef3541
77f6ad2
81eeb91
b7c58fa
8d1c1ae
c2b1ef7
b976d6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,10 +54,50 @@ | |
) | ||
|
||
|
||
def main(): | ||
with DrumRuntime() as runtime: | ||
def main(flask_app=None, worker_ctx=None): | ||
""" | ||
The main entry point for the custom model runner. | ||
|
||
This function initializes the runtime environment, sets up logging, handles | ||
signal interruptions, and starts the CMRunner for executing user-defined models. | ||
|
||
Args: | ||
flask_app: Optional[Flask] Flask application instance, used when running using command line. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It's always a command line, when running locally or when running DRUM as a sidecar container. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the old logic remains. It was made to run drum using command
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if flask_app is None it is created |
||
worker_ctx: Optional gunicorn worker context (WorkerCtx), used for managing cleanup tasks in a | ||
multi-worker setup (e.g., Gunicorn). | ||
|
||
Returns: | ||
None | ||
""" | ||
with DrumRuntime(flask_app) as runtime: | ||
config_logging() | ||
|
||
if worker_ctx: | ||
# Perform cleanup specific to the Gunicorn worker being terminated. | ||
# Gunicorn spawns multiple worker processes to handle requests. Each worker has its own context, | ||
# and this ensures that only the resources associated with the current worker are released. | ||
# defer_cleanup simply saves methods to be executed during worker restart or shutdown. | ||
# More details in https://github.com/datarobot/datarobot-custom-templates/pull/419 | ||
if runtime.options and RunMode(runtime.options.subparser_name) == RunMode.SERVER: | ||
if runtime.cm_runner: | ||
worker_ctx.defer_cleanup( | ||
lambda: runtime.cm_runner.terminate(), desc="runtime.cm_runner.terminate()" | ||
) | ||
if runtime.trace_provider is not None: | ||
worker_ctx.defer_cleanup( | ||
lambda: runtime.trace_provider.shutdown(), | ||
desc="runtime.trace_provider.shutdown()", | ||
) | ||
if runtime.metric_provider is not None: | ||
worker_ctx.defer_cleanup( | ||
lambda: runtime.metric_provider.shutdown(), | ||
desc="runtime.metric_provider.shutdown()", | ||
) | ||
if runtime.log_provider is not None: | ||
worker_ctx.defer_cleanup( | ||
lambda: runtime.log_provider.shutdown(), desc="runtime.log_provider.shutdown()" | ||
) | ||
|
||
def signal_handler(sig, frame): | ||
# The signal is assigned so the stacktrace is not presented when Ctrl-C is pressed. | ||
# The cleanup itself is done only if we are NOT running in performance test mode which | ||
|
@@ -89,13 +129,14 @@ def signal_handler(sig, frame): | |
runtime.metric_provider = metric_provider | ||
runtime.log_provider = log_provider | ||
|
||
signal.signal(signal.SIGINT, signal_handler) | ||
signal.signal(signal.SIGTERM, signal_handler) | ||
if worker_ctx is None: | ||
signal.signal(signal.SIGINT, signal_handler) | ||
signal.signal(signal.SIGTERM, signal_handler) | ||
|
||
from datarobot_drum.drum.drum import CMRunner | ||
|
||
try: | ||
runtime.cm_runner = CMRunner(runtime) | ||
runtime.cm_runner = CMRunner(runtime, flask_app, worker_ctx) | ||
runtime.cm_runner.run() | ||
except DrumSchemaValidationException: | ||
sys.exit(ExitCodes.SCHEMA_VALIDATION_ERROR.value) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,8 +68,11 @@ class TimeoutWSGIRequestHandler(WSGIRequestHandler): | |
|
||
|
||
class PredictionServer(PredictMixin): | ||
def __init__(self, params: dict): | ||
def __init__(self, params: dict, flask_app=None): | ||
self._params = params | ||
self.flask_app = ( | ||
flask_app # This is the Flask app object, used when running the application via CLI | ||
) | ||
self._show_perf = self._params.get("show_perf") | ||
self._resource_monitor = ResourceMonitor(monitor_current_process=True) | ||
self._run_language = RunLanguage(params.get("run_language")) | ||
|
@@ -310,7 +313,7 @@ def handle_exception(e): | |
cli = sys.modules["flask.cli"] | ||
cli.show_server_banner = lambda *x: None | ||
|
||
app = get_flask_app(model_api) | ||
app = get_flask_app(model_api, self.flask_app) | ||
self.load_flask_extensions(app) | ||
self._run_flask_app(app) | ||
|
||
|
@@ -319,6 +322,19 @@ def handle_exception(e): | |
|
||
return [] | ||
|
||
def is_client_request_timeout_enabled(self): | ||
if ( | ||
RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT") | ||
and int(RuntimeParameters.get("DRUM_CLIENT_REQUEST_TIMEOUT")) > 0 | ||
): | ||
logger.info( | ||
"Client request timeout is enabled, timeout: %s", | ||
str(int(TimeoutWSGIRequestHandler.timeout)), | ||
) | ||
return True | ||
else: | ||
return False | ||
|
||
def _run_flask_app(self, app): | ||
host = self._params.get("host", None) | ||
port = self._params.get("port", None) | ||
|
@@ -328,30 +344,34 @@ def _run_flask_app(self, app): | |
processes = self._params.get("processes") | ||
logger.info("Number of webserver processes: %s", processes) | ||
try: | ||
if RuntimeParameters.has("USE_NIM_WATCHDOG") and str( | ||
RuntimeParameters.get("USE_NIM_WATCHDOG") | ||
).lower() in ["true", "1", "yes"]: | ||
# Start the watchdog thread before running the app | ||
self._server_watchdog = Thread( | ||
target=self.watchdog, | ||
args=(port,), | ||
daemon=True, | ||
name="NIM Sidecar Watchdog", | ||
if self.flask_app: | ||
# when running application via the command line (e.g., gunicorn worker) | ||
pass | ||
else: | ||
if RuntimeParameters.has("USE_NIM_WATCHDOG") and str( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah, I see it was defined in the original version so not critical for applying in current PR, just a suggestion There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. user should explicitly set this value. |
||
RuntimeParameters.get("USE_NIM_WATCHDOG") | ||
).lower() in ["true", "1", "yes"]: | ||
# Start the watchdog thread before running the app | ||
self._server_watchdog = Thread( | ||
target=self.watchdog, | ||
args=(port,), | ||
daemon=True, | ||
name="NIM Sidecar Watchdog", | ||
) | ||
self._server_watchdog.start() | ||
|
||
# Configure the server with timeout settings | ||
app.run( | ||
host=host, | ||
port=port, | ||
threaded=False, | ||
processes=processes, | ||
**( | ||
{"request_handler": TimeoutWSGIRequestHandler} | ||
if self.is_client_request_timeout_enabled() | ||
else {} | ||
), | ||
) | ||
self._server_watchdog.start() | ||
|
||
# Configure the server with timeout settings | ||
app.run( | ||
host=host, | ||
port=port, | ||
threaded=False, | ||
processes=processes, | ||
**( | ||
{"request_handler": TimeoutWSGIRequestHandler} | ||
if RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT") | ||
else {} | ||
), | ||
) | ||
except OSError as e: | ||
raise DrumCommonException("{}: host: {}; port: {}".format(e, host, port)) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably let's add short but genuine descriptions, rather than simply the method name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added.