diff --git a/polytope_server/broker/broker.py b/polytope_server/broker/broker.py index 9ba1b53c..356cbf08 100644 --- a/polytope_server/broker/broker.py +++ b/polytope_server/broker/broker.py @@ -30,8 +30,8 @@ class Broker: def __init__(self, config: dict): - queue_config = config.get("queue") - self.queue = queue.create_queue(queue_config) + self.queue_config = config.get("queue") + self.queue = None self.max_queue_size = config.get("deployment", {}).get("worker", {}).get("replicas", 40) @@ -47,8 +47,14 @@ def run(self): logging.info("Starting broker...") logging.info("Maximum Queue Size: {}".format(self.max_queue_size)) - while not time.sleep(self.scheduling_interval): - self.check_requests() + q = queue.create_queue(self.queue_config) + self.queue = q + try: + while not time.sleep(self.scheduling_interval): + self.check_requests() + finally: + q.close_connection() + self.request_store.close() def check_requests(self): diff --git a/polytope_server/common/metric_store/dynamodb_metric_store.py b/polytope_server/common/metric_store/dynamodb_metric_store.py index 8b925837..efb88a3b 100644 --- a/polytope_server/common/metric_store/dynamodb_metric_store.py +++ b/polytope_server/common/metric_store/dynamodb_metric_store.py @@ -136,14 +136,14 @@ def __init__(self, config=None): table_name = config.get("table_name", "metrics") dynamodb = boto3.resource("dynamodb", region_name=region, endpoint_url=endpoint_url) - client = dynamodb.meta.client + self.client = dynamodb.meta.client self.table = dynamodb.Table(table_name) try: - response = client.describe_table(TableName=table_name) + response = self.client.describe_table(TableName=table_name) if response["Table"]["TableStatus"] != "ACTIVE": raise RuntimeError(f"DynamoDB table {table_name} is not active.") - except client.exceptions.ResourceNotFoundException: + except self.client.exceptions.ResourceNotFoundException: _create_table(dynamodb, table_name) def get_type(self): @@ -253,6 +253,10 @@ def remove_old_metrics(self, cutoff): return len(items_to_delete) + def close(self) -> None: + if self.client is not None: + self.client.close() + def get_usage_metrics_aggregated(self, cutoff_timestamps): """ Fetch and aggregate usage metrics from DynamoDB. diff --git a/polytope_server/common/metric_store/metric_store.py b/polytope_server/common/metric_store/metric_store.py index 5cc772e4..0084118d 100644 --- a/polytope_server/common/metric_store/metric_store.py +++ b/polytope_server/common/metric_store/metric_store.py @@ -76,6 +76,10 @@ def remove_old_metrics(self, cutoff: datetime.datetime) -> int: int: Number of removed metrics. """ + @abstractmethod + def close(self) -> None: + """Close any resources held by the metric store.""" + type_to_class_map = {"mongodb": "MongoMetricStore", "dynamodb": "DynamoDBMetricStore"} diff --git a/polytope_server/common/metric_store/mongodb_metric_store.py b/polytope_server/common/metric_store/mongodb_metric_store.py index 5dc2b9b7..7ada4793 100644 --- a/polytope_server/common/metric_store/mongodb_metric_store.py +++ b/polytope_server/common/metric_store/mongodb_metric_store.py @@ -180,3 +180,7 @@ def remove_old_metrics(self, cutoff): cutoff = cutoff.timestamp() result = self.store.delete_many({"timestamp": {"$lt": cutoff}}) return result.deleted_count + + def close(self) -> None: + if self.mongo_client is not None: + self.mongo_client.close() diff --git a/polytope_server/common/queue/__init__.py b/polytope_server/common/queue/__init__.py index 45a707a8..25099cba 100644 --- a/polytope_server/common/queue/__init__.py +++ b/polytope_server/common/queue/__init__.py @@ -18,4 +18,13 @@ # does it submit to any jurisdiction. # -from .queue import * +from .queue import Message, Queue # noqa: F401 +from .rabbitmq_queue import RabbitmqQueue +from .sqs_queue import SQSQueue + +queue_types = {"rabbitmq": RabbitmqQueue, "sqs": SQSQueue} + + +def create_queue(queue_config) -> Queue: + queue_type = next(iter(queue_config.keys()), "rabbitmq") + return queue_types[queue_type](queue_config[queue_type]) diff --git a/polytope_server/common/queue/queue.py b/polytope_server/common/queue/queue.py index 64f5ffc3..4f422ca9 100644 --- a/polytope_server/common/queue/queue.py +++ b/polytope_server/common/queue/queue.py @@ -18,7 +18,6 @@ # does it submit to any jurisdiction. # -import importlib from abc import ABC, abstractmethod @@ -68,14 +67,3 @@ def count(self) -> int: @abstractmethod def get_type(self) -> str: """Get the implementation type""" - - -queue_dict = {"rabbitmq": "RabbitmqQueue", "sqs": "SQSQueue"} - - -def create_queue(queue_config) -> Queue: - - queue_type = next(iter(queue_config.keys()), "rabbitmq") - - QueueClass = importlib.import_module("polytope_server.common.queue." + queue_type + "_queue") - return getattr(QueueClass, queue_dict[queue_type])(queue_config[queue_type]) diff --git a/polytope_server/common/request_store/dynamodb_request_store.py b/polytope_server/common/request_store/dynamodb_request_store.py index 025ccef4..bbbf6b36 100644 --- a/polytope_server/common/request_store/dynamodb_request_store.py +++ b/polytope_server/common/request_store/dynamodb_request_store.py @@ -136,14 +136,14 @@ def __init__(self, config=None, metric_store_config=None): table_name = config.get("table_name", "requests") dynamodb = boto3.resource("dynamodb", region_name=region, endpoint_url=endpoint_url) - client = dynamodb.meta.client + self.client = dynamodb.meta.client self.table = dynamodb.Table(table_name) try: - response = client.describe_table(TableName=table_name) + response = self.client.describe_table(TableName=table_name) if response["Table"]["TableStatus"] != "ACTIVE": raise RuntimeError(f"DynamoDB table {table_name} is not active.") - except client.exceptions.ResourceNotFoundException: + except self.client.exceptions.ResourceNotFoundException: _create_table(dynamodb, table_name) self.metric_store = None @@ -355,3 +355,9 @@ def remove_old_requests(self, cutoff: dt.datetime): batch.delete_item(Key={"id": id}) logger.info("Deleting request %s because it is older than cutoff.", id) return len(items_to_delete) + + def close(self) -> None: + if self.metric_store is not None: + self.metric_store.close() + if self.client is not None: + self.client.close() diff --git a/polytope_server/common/request_store/mongodb_request_store.py b/polytope_server/common/request_store/mongodb_request_store.py index 5d4332e3..f6dfcd36 100644 --- a/polytope_server/common/request_store/mongodb_request_store.py +++ b/polytope_server/common/request_store/mongodb_request_store.py @@ -219,3 +219,9 @@ def remove_old_requests(self, cutoff): ) logging.info("Removed {} old requests from request store.".format(result.deleted_count)) return result.deleted_count + + def close(self) -> None: + if self.metric_store is not None: + self.metric_store.close() + if self.mongo_client is not None: + self.mongo_client.close() diff --git a/polytope_server/common/request_store/request_store.py b/polytope_server/common/request_store/request_store.py index e6a503c1..45058e17 100644 --- a/polytope_server/common/request_store/request_store.py +++ b/polytope_server/common/request_store/request_store.py @@ -116,6 +116,10 @@ def remove_old_requests(self, cutoff: datetime.datetime) -> int: int: Number of removed requests. """ + @abstractmethod + def close(self) -> None: + """Close any resources held by the request store.""" + type_to_class_map = {"mongodb": "MongoRequestStore", "dynamodb": "DynamoDBRequestStore"} diff --git a/polytope_server/frontend/frontend.py b/polytope_server/frontend/frontend.py index 97480ce3..9915879d 100644 --- a/polytope_server/frontend/frontend.py +++ b/polytope_server/frontend/frontend.py @@ -89,5 +89,8 @@ def run(self): self.config.get("frontend", {}).get("proxy_support", False), ) - logging.info("Starting frontend...") - handler_class.run_server(handler, self.server_type, self.host, self.port) + try: + logging.info("Starting frontend...") + handler_class.run_server(handler, self.server_type, self.host, self.port) + finally: + request_store.close() diff --git a/polytope_server/garbage_collector/garbage_collector.py b/polytope_server/garbage_collector/garbage_collector.py index fc0f2626..3b7102c3 100644 --- a/polytope_server/garbage_collector/garbage_collector.py +++ b/polytope_server/garbage_collector/garbage_collector.py @@ -59,11 +59,15 @@ def __init__(self, config): self.metric_store = create_metric_store(config.get("metric_store")) def run(self): - while not time.sleep(self.interval): - self.remove_old_requests() - self.remove_old_metrics() - self.remove_dangling_data() - self.remove_by_size() + try: + while not time.sleep(self.interval): + self.remove_old_requests() + self.remove_old_metrics() + self.remove_dangling_data() + self.remove_by_size() + finally: + self.request_store.close() + self.metric_store.close() def remove_old_requests(self): """Removes requests that are FAILED or PROCESSED after the configured time""" diff --git a/polytope_server/telemetry/telemetry_service.py b/polytope_server/telemetry/telemetry_service.py index 7f5ea90b..5050af7b 100644 --- a/polytope_server/telemetry/telemetry_service.py +++ b/polytope_server/telemetry/telemetry_service.py @@ -41,6 +41,11 @@ async def lifespan(self, app: FastAPI): # Attach resources to the app state app.state.resources = resources yield + # Close resources on shutdown + if resources.get("request_store") is not None: + resources["request_store"].close() + if resources.get("metric_store") is not None: + resources["metric_store"].close() def load_handler(self): handler_type = self.config.get("telemetry", {}).get("handler", "fastapi") diff --git a/polytope_server/worker/worker.py b/polytope_server/worker/worker.py index d8371460..baf21829 100644 --- a/polytope_server/worker/worker.py +++ b/polytope_server/worker/worker.py @@ -206,11 +206,14 @@ def handle_termination(group: aio.TaskGroup) -> None: def run(self): self.queue = polytope_queue.create_queue(self.config.get("queue")) + try: + self.update_status("idle", time_spent=0) - self.update_status("idle", time_spent=0) - - with ThreadPoolExecutor(max_workers=1) as executor: - aio.run(self.schedule(executor)) + with ThreadPoolExecutor(max_workers=1) as executor: + aio.run(self.schedule(executor)) + finally: + self.request_store.close() + self.queue.close_connection() def process_request( self,