27
27
_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
28
28
_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
29
29
_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
30
+ _INGEST_SCOPE = "diode:ingest"
30
31
_DEFAULT_STREAM = "latest"
31
32
_LOGGER = logging .getLogger (__name__ )
32
33
@@ -66,10 +67,15 @@ def _get_required_config_value(env_var_name: str, value: str | None = None) -> s
66
67
if value is None :
67
68
value = os .getenv (env_var_name )
68
69
if value is None :
69
- raise DiodeConfigError (f"parameter or { env_var_name } environment variable required" )
70
+ raise DiodeConfigError (
71
+ f"parameter or { env_var_name } environment variable required"
72
+ )
70
73
return value
71
74
72
- def _get_optional_config_value (env_var_name : str , value : str | None = None ) -> str | None :
75
+
76
+ def _get_optional_config_value (
77
+ env_var_name : str , value : str | None = None
78
+ ) -> str | None :
73
79
"""Get optional config value either from provided value or environment variable."""
74
80
if value is None :
75
81
value = os .getenv (env_var_name )
@@ -102,7 +108,9 @@ def __init__(
102
108
log_level = os .getenv (_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME , "INFO" ).upper ()
103
109
logging .basicConfig (level = log_level )
104
110
105
- self ._max_auth_retries = _get_optional_config_value (_MAX_RETRIES_ENVVAR_NAME , max_auth_retries )
111
+ self ._max_auth_retries = _get_optional_config_value (
112
+ _MAX_RETRIES_ENVVAR_NAME , max_auth_retries
113
+ )
106
114
self ._target , self ._path , self ._tls_verify = parse_target (target )
107
115
self ._app_name = app_name
108
116
self ._app_version = app_version
@@ -111,14 +119,16 @@ def __init__(
111
119
112
120
# Read client credentials from environment variables
113
121
self ._client_id = _get_required_config_value (_CLIENT_ID_ENVVAR_NAME , client_id )
114
- self ._client_secret = _get_required_config_value (_CLIENT_SECRET_ENVVAR_NAME , client_secret )
122
+ self ._client_secret = _get_required_config_value (
123
+ _CLIENT_SECRET_ENVVAR_NAME , client_secret
124
+ )
115
125
116
126
self ._metadata = (
117
127
("platform" , self ._platform ),
118
128
("python-version" , self ._python_version ),
119
129
)
120
130
121
- self ._authenticate ()
131
+ self ._authenticate (_INGEST_SCOPE )
122
132
123
133
channel_opts = (
124
134
(
@@ -149,7 +159,9 @@ def __init__(
149
159
_LOGGER .debug (f"Setting up gRPC interceptor for path: { self ._path } " )
150
160
rpc_method_interceptor = DiodeMethodClientInterceptor (subpath = self ._path )
151
161
152
- intercept_channel = grpc .intercept_channel (self ._channel , rpc_method_interceptor )
162
+ intercept_channel = grpc .intercept_channel (
163
+ self ._channel , rpc_method_interceptor
164
+ )
153
165
channel = intercept_channel
154
166
155
167
self ._stub = ingester_pb2_grpc .IngesterServiceStub (channel )
@@ -158,7 +170,9 @@ def __init__(
158
170
159
171
if self ._sentry_dsn is not None :
160
172
_LOGGER .debug ("Setting up Sentry" )
161
- self ._setup_sentry (self ._sentry_dsn , sentry_traces_sample_rate , sentry_profiles_sample_rate )
173
+ self ._setup_sentry (
174
+ self ._sentry_dsn , sentry_traces_sample_rate , sentry_profiles_sample_rate
175
+ )
162
176
163
177
@property
164
178
def name (self ) -> str :
@@ -233,13 +247,17 @@ def ingest(
233
247
except grpc .RpcError as err :
234
248
if err .code () == grpc .StatusCode .UNAUTHENTICATED :
235
249
if attempt < self ._max_auth_retries - 1 :
236
- _LOGGER .info (f"Retrying ingestion due to UNAUTHENTICATED error, attempt { attempt + 1 } " )
237
- self ._authenticate ()
250
+ _LOGGER .info (
251
+ f"Retrying ingestion due to UNAUTHENTICATED error, attempt { attempt + 1 } "
252
+ )
253
+ self ._authenticate (_INGEST_SCOPE )
238
254
continue
239
255
raise DiodeClientError (err ) from err
240
256
raise RuntimeError ("Max retries exceeded" )
241
257
242
- def _setup_sentry (self , dsn : str , traces_sample_rate : float , profiles_sample_rate : float ):
258
+ def _setup_sentry (
259
+ self , dsn : str , traces_sample_rate : float , profiles_sample_rate : float
260
+ ):
243
261
sentry_sdk .init (
244
262
dsn = dsn ,
245
263
release = self .version ,
@@ -254,20 +272,37 @@ def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rat
254
272
sentry_sdk .set_tag ("platform" , self ._platform )
255
273
sentry_sdk .set_tag ("python_version" , self ._python_version )
256
274
257
- def _authenticate (self ):
258
- authentication_client = _DiodeAuthentication (self ._target , self ._path , self ._tls_verify , self ._client_id , self ._client_secret )
275
+ def _authenticate (self , scope : str ):
276
+ authentication_client = _DiodeAuthentication (
277
+ self ._target ,
278
+ self ._path ,
279
+ self ._tls_verify ,
280
+ self ._client_id ,
281
+ self ._client_secret ,
282
+ scope ,
283
+ )
259
284
access_token = authentication_client .authenticate ()
260
- self ._metadata = list (filter (lambda x : x [0 ] != "authorization" , self ._metadata )) + \
261
- [("authorization" , f"Bearer { access_token } " )]
285
+ self ._metadata = list (
286
+ filter (lambda x : x [0 ] != "authorization" , self ._metadata )
287
+ ) + [("authorization" , f"Bearer { access_token } " )]
262
288
263
289
264
290
class _DiodeAuthentication :
265
- def __init__ (self , target : str , path : str , tls_verify : bool , client_id : str , client_secret : str ):
291
+ def __init__ (
292
+ self ,
293
+ target : str ,
294
+ path : str ,
295
+ tls_verify : bool ,
296
+ client_id : str ,
297
+ client_secret : str ,
298
+ scope : str ,
299
+ ):
266
300
self ._target = target
267
301
self ._tls_verify = tls_verify
268
302
self ._client_id = client_id
269
303
self ._client_secret = client_secret
270
304
self ._path = path
305
+ self ._scope = scope
271
306
272
307
def authenticate (self ) -> str :
273
308
"""Request an OAuth2 token using client credentials and return it."""
@@ -286,6 +321,7 @@ def authenticate(self) -> str:
286
321
"grant_type" : "client_credentials" ,
287
322
"client_id" : self ._client_id ,
288
323
"client_secret" : self ._client_secret ,
324
+ "scope" : self ._scope ,
289
325
}
290
326
)
291
327
url = self ._get_auth_url ()
@@ -299,15 +335,17 @@ def authenticate(self) -> str:
299
335
token_info = json .loads (response .read ().decode ())
300
336
access_token = token_info .get ("access_token" )
301
337
if not access_token :
302
- raise DiodeConfigError (f"Failed to obtain access token for client { self ._client_id } " )
338
+ raise DiodeConfigError (
339
+ f"Failed to obtain access token for client { self ._client_id } "
340
+ )
303
341
304
342
_LOGGER .debug (f"Access token obtained for client { self ._client_id } " )
305
343
return access_token
306
344
307
345
def _get_auth_url (self ) -> str :
308
346
"""Construct the authentication URL, handling trailing slashes in the path."""
309
347
# Ensure the path does not have trailing slashes
310
- path = self ._path .rstrip ('/' ) if self ._path else ''
348
+ path = self ._path .rstrip ("/" ) if self ._path else ""
311
349
return f"{ path } /auth/token"
312
350
313
351
@@ -332,10 +370,10 @@ class _ClientCallDetails(
332
370
333
371
"""
334
372
335
- pass
336
-
337
373
338
- class DiodeMethodClientInterceptor (grpc .UnaryUnaryClientInterceptor , grpc .StreamUnaryClientInterceptor ):
374
+ class DiodeMethodClientInterceptor (
375
+ grpc .UnaryUnaryClientInterceptor , grpc .StreamUnaryClientInterceptor
376
+ ):
339
377
"""
340
378
Diode Method Client Interceptor class.
341
379
@@ -374,6 +412,8 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
374
412
"""Intercept unary unary."""
375
413
return self ._intercept_call (continuation , client_call_details , request )
376
414
377
- def intercept_stream_unary (self , continuation , client_call_details , request_iterator ):
415
+ def intercept_stream_unary (
416
+ self , continuation , client_call_details , request_iterator
417
+ ):
378
418
"""Intercept stream unary."""
379
419
return self ._intercept_call (continuation , client_call_details , request_iterator )
0 commit comments