Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEATURE: Add client side buffering support #488

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [


dependencies = [
"ansys-api-edb==1.0.11",
"ansys-api-edb==1.0.12",
"protobuf>=3.19.3,<5",
"grpcio>=1.44.0",
"Django>=4.2.16"
Expand Down
10 changes: 2 additions & 8 deletions src/ansys/edb/core/hierarchy/hierarchy_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,12 @@ def transform(self):
""":class:`.Transform`: \
Transformation information of the hierarchy object."""
transform_msg = self.__stub.GetTransform(self.msg)
return Transform.create(
transform_msg.scale,
transform_msg.angle,
transform_msg.mirror,
transform_msg.offset_x,
transform_msg.offset_y,
)
return Transform(transform_msg)

@transform.setter
def transform(self, value):
"""Set transform."""
self.__stub.SetTransform(messages.transform_property_message(self, value))
self.__stub.SetTransform(messages.pointer_property_message(self, value))

@property
def name(self):
Expand Down
28 changes: 20 additions & 8 deletions src/ansys/edb/core/inner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ansys.api.edb.v1.edb_messages_pb2 import EDBObjMessage

from ansys.edb.core.utility.cache import get_cache
from ansys.edb.core.utility.io_manager import get_buffer, get_cache, get_io_manager


class ObjBase:
Expand All @@ -15,10 +15,7 @@ def __init__(self, msg):
----------
msg : EDBObjMessage
"""
self._id = 0 if msg is None else msg.id
cache = get_cache()
if cache is not None:
cache.add_from_cache_msg(msg.cache)
self.msg = msg

@property
def is_null(self):
Expand All @@ -44,12 +41,27 @@ def msg(self):

This property can only be set to ``None``.
"""
return EDBObjMessage(id=self.id)
msg = EDBObjMessage(id=self.id)
io_mgr = get_io_manager()
if io_mgr.is_enabled:
if self._is_future:
msg.is_future = True
get_io_manager().active_request_edb_obj_msg_mgr.add_active_request_edb_obj_msg(msg)
return msg

@msg.setter
def msg(self, val):
if val is None:
def msg(self, msg):
if msg is None:
self._id = 0
return
self._id = msg.id
self._is_future = msg.is_future
if self._is_future:
if (buffer := get_buffer()) is not None:
buffer.add_future_ref(self)
else:
if (cache := get_cache()) is not None:
cache.add_from_cache_msg(msg)


class TypeField(object):
Expand Down
89 changes: 68 additions & 21 deletions src/ansys/edb/core/inner/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,22 @@
from grpc import (
ClientCallDetails,
StatusCode,
StreamStreamClientInterceptor,
UnaryStreamClientInterceptor,
UnaryUnaryClientInterceptor,
)

from ansys.edb.core.inner.exceptions import EDBSessionException, ErrorCode, InvalidArgumentException
from ansys.edb.core.utility.cache import get_cache
from ansys.edb.core.inner.rpc_info_utils import can_cache
from ansys.edb.core.utility.io_manager import ServerNotification, get_io_manager


class Interceptor(UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, metaclass=abc.ABCMeta):
class Interceptor(
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor,
StreamStreamClientInterceptor,
metaclass=abc.ABCMeta,
):
"""Provides the base interceptor class."""

def __init__(self, logger):
Expand All @@ -42,6 +49,10 @@ def intercept_unary_stream(self, continuation, client_call_details, request):
"""Intercept a gRPC streaming call."""
return continuation(client_call_details, request)

def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
"""Intercept a gRPC streaming call."""
return continuation(client_call_details, request_iterator)


class LoggingInterceptor(Interceptor):
"""Logs EDB errors on each request."""
Expand Down Expand Up @@ -92,7 +103,7 @@ def _post_process(self, response):
raise exception


class CachingInterceptor(Interceptor):
class IOInterceptor(Interceptor):
"""Returns cached values if a given request has already been made and caching is enabled."""

def __init__(self, logger, rpc_counter):
Expand All @@ -104,6 +115,7 @@ def __init__(self, logger, rpc_counter):
def _reset_cache_entry_data(self):
self._current_rpc_method = ""
self._current_cache_key_details = None
self._hijacked = False

def _should_log_traffic(self):
return self._rpc_counter is not None
Expand All @@ -114,54 +126,89 @@ class _ClientCallDetails(
):
pass

@staticmethod
def _add_caching_option_to_metadata(metadata, option, is_enabled):
metadata.append((option, "1" if is_enabled else "0"))

@classmethod
def _get_client_call_details_with_caching_options(cls, client_call_details):
if get_cache() is None:
io_mgr = get_io_manager()
if not io_mgr.get_notifications_for_server():
return client_call_details
metadata = []
if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)
metadata.append(("enable-caching", "1"))
for notification in io_mgr.get_notifications_for_server(True):
if notification == ServerNotification.INVALIDATE_CACHE:
cls._add_caching_option_to_metadata(metadata, "invalidate-cache", True)
elif notification == ServerNotification.FLUSH_BUFFER:
cls._add_caching_option_to_metadata(metadata, "flush-buffer", True)
return cls._ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
metadata,
client_call_details.credentials,
)

def _continue_unary_unary(self, continuation, client_call_details, request):
@staticmethod
def _attempt_hijack(*args):
io_manager = get_io_manager()
hijacked_response = None
if (buffer := io_manager.buffer) is not None:
hijacked_response = buffer.hijack_request(*args)
if hijacked_response is None and (cache := io_manager.cache) is not None:
hijacked_response = cache.hijack_request(*args)
return hijacked_response

def _hijack(self, client_call_details, request):
io_manager = get_io_manager()
if io_manager.is_enabled and not io_manager.is_blocking:
with io_manager.manage_io():
method_tokens = client_call_details.method.strip("/").split("/")
cache_key_details = method_tokens[0], method_tokens[1], request
if (hijacked_result := self._attempt_hijack(*cache_key_details)) is not None:
self._hijacked = True
return hijacked_result
if io_manager.cache is not None and can_cache(
cache_key_details[0], cache_key_details[1]
):
self._current_cache_key_details = cache_key_details
if self._should_log_traffic():
self._current_rpc_method = client_call_details.method
cache = get_cache()
if cache is not None:
method_tokens = client_call_details.method.strip("/").split("/")
cache_key_details = method_tokens[0], method_tokens[1], request
cached_response = cache.get(*cache_key_details)
if cached_response is not None:
return cached_response
else:
self._current_cache_key_details = cache_key_details

def _continue_unary_unary(self, continuation, client_call_details, request):
if (hijacked_result := self._hijack(client_call_details, request)) is not None:
return hijacked_result
return super()._continue_unary_unary(
continuation,
self._get_client_call_details_with_caching_options(client_call_details),
request,
)

def _cache_missed(self):
return self._current_cache_key_details is not None

def _post_process(self, response):
cache = get_cache()
if cache is not None and self._cache_missed():
io_manager = get_io_manager()
if (cache := io_manager.cache) is not None and self._current_cache_key_details is not None:
cache.add(*self._current_cache_key_details, response.result())
if self._should_log_traffic() and (cache is None or self._cache_missed()):
if self._should_log_traffic() and not self._hijacked:
self._rpc_counter[self._current_rpc_method] += 1
self._reset_cache_entry_data()

def intercept_unary_stream(self, continuation, client_call_details, request):
"""Intercept a gRPC streaming call."""
if (hijacked_result := self._hijack(client_call_details, request)) is not None:
return hijacked_result
return super().intercept_unary_stream(
continuation,
self._get_client_call_details_with_caching_options(client_call_details),
request,
)

def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
"""Intercept a gRPC streaming call."""
if (hijacked_result := self._hijack(client_call_details, request_iterator)) is not None:
return hijacked_result
return super().intercept_stream_stream(
continuation,
self._get_client_call_details_with_caching_options(client_call_details),
request_iterator,
)
Loading
Loading