Skip to content
Draft
16 changes: 16 additions & 0 deletions custom_model_runner/datarobot_drum/drum/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,22 @@
import os
import signal
import sys
from datarobot_drum import RuntimeParameters

# Monkey patching for gevent compatibility if running with gunicorn-gevent.
# Kept ahead of the remaining imports: gevent recommends patching as early as
# possible so that modules imported afterwards see the patched stdlib.
# NOTE(review): "gunicorn" and "gevent" are reused in several places; consider
# defining them as shared constants.
if RuntimeParameters.has("DRUM_SERVER_TYPE") and RuntimeParameters.has(
    "DRUM_GUNICORN_WORKER_CLASS"
):
    if (
        # Compare the parameter *values*: has() only reports whether the
        # parameter is set, so comparing its result to "gunicorn" never matches.
        str(RuntimeParameters.get("DRUM_SERVER_TYPE")).lower() == "gunicorn"
        and str(RuntimeParameters.get("DRUM_GUNICORN_WORKER_CLASS")).lower() == "gevent"
    ):
        try:
            from gevent import monkey

            monkey.patch_all()
        except ImportError:
            # Best effort: without gevent installed there is nothing to patch.
            pass

from datarobot_drum.drum.common import config_logging, setup_otel
from datarobot_drum.drum.utils.setup import setup_options
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,67 @@ def handle_exception(e):

return []

def get_gunicorn_config(self):
    """Collect validated gunicorn settings from DRUM_GUNICORN_* runtime parameters.

    Only parameters that are present and pass validation are included, so
    callers can apply their own defaults for missing keys.

    Returns
    -------
    dict
        Maps gunicorn setting names (``worker_class``, ``worker_connections``,
        ``backlog``, ``timeout``, ``graceful_timeout``, ``keepalive``,
        ``max_requests``, ``max_requests_jitter``, ``loglevel``) to values.

    Raises
    ------
    ValueError
        If a numeric parameter holds a value that ``int()`` cannot convert.
    """
    config = {}

    if RuntimeParameters.has("DRUM_GUNICORN_WORKER_CLASS"):
        # Store the normalized name: validation is case-insensitive, but the
        # value is later handed to gunicorn, which resolves worker aliases
        # such as "gevent" by their exact lowercase name.
        worker_class = str(RuntimeParameters.get("DRUM_GUNICORN_WORKER_CLASS")).lower()
        if worker_class in {"sync", "gevent"}:
            config["worker_class"] = worker_class

    # (runtime parameter, gunicorn setting, inclusive minimum, inclusive maximum)
    bounded_int_settings = (
        ("DRUM_GUNICORN_WORKER_CONNECTIONS", "worker_connections", 1, 10000),
        ("DRUM_GUNICORN_BACKLOG", "backlog", 1, 2048),
        ("DRUM_GUNICORN_TIMEOUT", "timeout", 1, 3600),
        ("DRUM_GUNICORN_GRACEFUL_TIMEOUT", "graceful_timeout", 1, 3600),
        ("DRUM_GUNICORN_KEEP_ALIVE", "keepalive", 1, 3600),
        ("DRUM_GUNICORN_MAX_REQUESTS", "max_requests", 100, 10000),
        ("DRUM_GUNICORN_MAX_REQUESTS_JITTER", "max_requests_jitter", 1, 10000),
    )
    for param_name, setting, minimum, maximum in bounded_int_settings:
        if not RuntimeParameters.has(param_name):
            continue
        value = int(RuntimeParameters.get(param_name))
        # Out-of-range values are ignored rather than clamped.
        if minimum <= value <= maximum:
            config[setting] = value

    if RuntimeParameters.has("DRUM_GUNICORN_LOG_LEVEL"):
        loglevel = str(RuntimeParameters.get("DRUM_GUNICORN_LOG_LEVEL"))
        if loglevel.lower() in {"debug", "info", "warning", "error", "critical"}:
            config["loglevel"] = loglevel

    return config

def get_server_type(self):
    """Return the server type selected by the DRUM_SERVER_TYPE runtime parameter.

    Returns "flask" when the parameter is not set. Recognized values
    ("flask", "gunicorn", in any letter case) are returned lower-cased;
    any other value is returned unchanged.
    """
    if not RuntimeParameters.has("DRUM_SERVER_TYPE"):
        return "flask"

    raw_value = str(RuntimeParameters.get("DRUM_SERVER_TYPE"))
    normalized = raw_value.lower()
    return normalized if normalized in {"flask", "gunicorn"} else raw_value

def _run_flask_app(self, app):
host = self._params.get("host", None)
port = self._params.get("port", None)

server_type = self.get_server_type()
processes = 1
if self._params.get("processes"):
processes = self._params.get("processes")
Expand All @@ -340,20 +397,82 @@ def _run_flask_app(self, app):
)
self._server_watchdog.start()

# Configure the server with timeout settings
app.run(
host=host,
port=port,
threaded=False,
processes=processes,
**(
{"request_handler": TimeoutWSGIRequestHandler}
if RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT")
else {}
),
)
if server_type == "gunicorn":
logger.info("Starting gunicorn server")
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably this can be removed, as we require gunicorn.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the import can even be moved to the beginning of the file.

from gunicorn.app.base import BaseApplication
except ImportError:
BaseApplication = None
raise DrumCommonException("gunicorn is not installed. Please install gunicorn.")

class GunicornApp(BaseApplication):
    """Gunicorn application wrapper that serves an already-built WSGI app.

    The bind address comes from ``host``/``port``, the worker count from
    ``params``, and the remaining settings from ``gunicorn_config`` with
    fallback defaults for any key that is absent.
    """

    def __init__(self, app, host, port, params, gunicorn_config):
        # Assigned before super().__init__(), following gunicorn's
        # custom-application example, so load_config() can read them.
        self.application = app
        self.host = host
        self.port = port
        self.params = params
        self.gunicorn_config = gunicorn_config
        super().__init__()

    def load_config(self):
        """Apply bind address, worker count and tuning options to ``self.cfg``."""
        self.cfg.set("bind", f"{self.host}:{self.port}")

        # Fall back to a single worker when no count is configured; passing
        # None to cfg.set("workers", ...) would fail validation.
        workers = (
            self.params.get("gunicorn_workers")
            or self.params.get("max_workers")
            or self.params.get("processes")
            or 1
        )
        self.cfg.set("workers", workers)

        # (setting name, default used when gunicorn_config lacks the key)
        defaults = (
            ("worker_class", "sync"),
            ("backlog", 2048),
            ("timeout", 120),
            ("graceful_timeout", 30),
            ("keepalive", 5),
            ("max_requests", 2000),
            ("max_requests_jitter", 500),
        )
        for setting, default in defaults:
            self.cfg.set(setting, self.gunicorn_config.get(setting, default))

        # worker_connections has no fallback here: it is only set when provided.
        worker_connections = self.gunicorn_config.get("worker_connections")
        if worker_connections:
            self.cfg.set("worker_connections", worker_connections)
        self.cfg.set("loglevel", self.gunicorn_config.get("loglevel", "info"))

        self.cfg.set("accesslog", "-")  # "-" sends the access log to stdout
        self.cfg.set("errorlog", "-")  # "-" sends the error log to stderr
        self.cfg.set(
            "access_log_format",
            '%(t)s %(h)s %(l)s %(u)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"',
        )
        # The keys access_logfile, error_logfile and access_logformat are not
        # supported by the config API, so they are deliberately not set here.

    def load(self):
        """Return the WSGI application for gunicorn to serve."""
        return self.application

gunicorn_config = self.get_gunicorn_config()
GunicornApp(app, host, port, self._params, gunicorn_config).run()
else:
# Configure the server with timeout settings
app.run(
host=host,
port=port,
threaded=False,
processes=processes,
**(
{"request_handler": TimeoutWSGIRequestHandler}
if RuntimeParameters.has("DRUM_CLIENT_REQUEST_TIMEOUT")
else {}
),
)
except OSError as e:
raise DrumCommonException("{}: host: {}; port: {}".format(e, host, port))
raise DrumCommonException(f"{e}: host: {host}; port: {port}")

def _kill_all_processes(self):
"""
Expand Down
2 changes: 2 additions & 0 deletions custom_model_runner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ argcomplete
trafaret>=2.0.0
docker>=4.2.2
flask
gevent>=22.10.2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We try to avoid pinning versions here; instead, we pin them in the environment.

gunicorn>=20.1.0
jinja2>=3.0.0
memory_profiler<1.0.0
numpy
Expand Down