1
1
"""Client-side gRPC interceptors."""
2
2
3
3
import abc
4
+ from collections import namedtuple
4
5
import logging
5
6
6
- from grpc import StatusCode , UnaryUnaryClientInterceptor
7
+ from grpc import (
8
+ ClientCallDetails ,
9
+ StatusCode ,
10
+ UnaryStreamClientInterceptor ,
11
+ UnaryUnaryClientInterceptor ,
12
+ )
7
13
8
14
from ansys .edb .core .inner .exceptions import EDBSessionException , ErrorCode , InvalidArgumentException
15
+ from ansys .edb .core .utility .cache import get_cache
9
16
10
17
11
- class Interceptor (UnaryUnaryClientInterceptor , metaclass = abc .ABCMeta ):
18
+ class Interceptor (UnaryUnaryClientInterceptor , UnaryStreamClientInterceptor , metaclass = abc .ABCMeta ):
12
19
"""Provides the base interceptor class."""
13
20
14
21
def __init__ (self , logger ):
@@ -20,14 +27,21 @@ def __init__(self, logger):
20
27
def _post_process (self , response ):
21
28
pass
22
29
30
+ def _continue_unary_unary (self , continuation , client_call_details , request ):
31
+ return continuation (client_call_details , request )
32
+
23
33
def intercept_unary_unary (self , continuation , client_call_details , request ):
24
34
"""Intercept a gRPC call."""
25
- response = continuation ( client_call_details , request )
35
+ response = self . _continue_unary_unary ( continuation , client_call_details , request )
26
36
27
37
self ._post_process (response )
28
38
29
39
return response
30
40
41
+ def intercept_unary_stream (self , continuation , client_call_details , request ):
42
+ """Intercept a gRPC streaming call."""
43
+ return continuation (client_call_details , request )
44
+
31
45
32
46
class LoggingInterceptor (Interceptor ):
33
47
"""Logs EDB errors on each request."""
@@ -76,3 +90,78 @@ def _post_process(self, response):
76
90
77
91
if exception is not None :
78
92
raise exception
93
+
94
+
95
+ class CachingInterceptor (Interceptor ):
96
+ """Returns cached values if a given request has already been made and caching is enabled."""
97
+
98
+ def __init__ (self , logger , rpc_counter ):
99
+ """Initialize a caching interceptor with a logger and rpc counter."""
100
+ super ().__init__ (logger )
101
+ self ._rpc_counter = rpc_counter
102
+ self ._reset_cache_entry_data ()
103
+
104
+ def _reset_cache_entry_data (self ):
105
+ self ._current_rpc_method = ""
106
+ self ._current_cache_key_details = None
107
+
108
+ def _should_log_traffic (self ):
109
+ return self ._rpc_counter is not None
110
+
111
+ class _ClientCallDetails (
112
+ namedtuple ("_ClientCallDetails" , ("method" , "timeout" , "metadata" , "credentials" )),
113
+ ClientCallDetails ,
114
+ ):
115
+ pass
116
+
117
+ @classmethod
118
+ def _get_client_call_details_with_caching_options (cls , client_call_details ):
119
+ if get_cache () is None :
120
+ return client_call_details
121
+ metadata = []
122
+ if client_call_details .metadata is not None :
123
+ metadata = list (client_call_details .metadata )
124
+ metadata .append (("enable-caching" , "1" ))
125
+ return cls ._ClientCallDetails (
126
+ client_call_details .method ,
127
+ client_call_details .timeout ,
128
+ metadata ,
129
+ client_call_details .credentials ,
130
+ )
131
+
132
+ def _continue_unary_unary (self , continuation , client_call_details , request ):
133
+ if self ._should_log_traffic ():
134
+ self ._current_rpc_method = client_call_details .method
135
+ cache = get_cache ()
136
+ if cache is not None :
137
+ method_tokens = client_call_details .method .strip ("/" ).split ("/" )
138
+ cache_key_details = method_tokens [0 ], method_tokens [1 ], request
139
+ cached_response = cache .get (* cache_key_details )
140
+ if cached_response is not None :
141
+ return cached_response
142
+ else :
143
+ self ._current_cache_key_details = cache_key_details
144
+ return super ()._continue_unary_unary (
145
+ continuation ,
146
+ self ._get_client_call_details_with_caching_options (client_call_details ),
147
+ request ,
148
+ )
149
+
150
+ def _cache_missed (self ):
151
+ return self ._current_cache_key_details is not None
152
+
153
+ def _post_process (self , response ):
154
+ cache = get_cache ()
155
+ if cache is not None and self ._cache_missed ():
156
+ cache .add (* self ._current_cache_key_details , response .result ())
157
+ if self ._should_log_traffic () and (cache is None or self ._cache_missed ()):
158
+ self ._rpc_counter [self ._current_rpc_method ] += 1
159
+ self ._reset_cache_entry_data ()
160
+
161
+ def intercept_unary_stream (self , continuation , client_call_details , request ):
162
+ """Intercept a gRPC streaming call."""
163
+ return super ().intercept_unary_stream (
164
+ continuation ,
165
+ self ._get_client_call_details_with_caching_options (client_call_details ),
166
+ request ,
167
+ )
0 commit comments