1
1
#!/usr/bin/env python
2
2
# Copyright 2024 NetBox Labs Inc
3
3
"""NetBox Labs, Diode - SDK - Client."""
4
+
4
5
import collections
6
+ import http .client
7
+ import json
5
8
import logging
6
9
import os
7
10
import platform
11
+ import ssl
8
12
import uuid
9
13
from collections .abc import Iterable
10
- from urllib .parse import urlparse
14
+ from urllib .parse import urlencode , urlparse
11
15
12
16
import certifi
13
17
import grpc
18
22
from netboxlabs .diode .sdk .ingester import Entity
19
23
from netboxlabs .diode .sdk .version import version_semver
20
24
21
- _DIODE_API_KEY_ENVVAR_NAME = "DIODE_API_KEY "
25
+ _MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES "
22
26
_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
23
27
_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
28
+ _CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
29
+ _CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
24
30
_DEFAULT_STREAM = "latest"
25
31
_LOGGER = logging .getLogger (__name__ )
26
32
@@ -31,17 +37,6 @@ def _load_certs() -> bytes:
31
37
return f .read ()
32
38
33
39
34
- def _get_api_key (api_key : str | None = None ) -> str :
35
- """Get API Key either from provided value or environment variable."""
36
- if api_key is None :
37
- api_key = os .getenv (_DIODE_API_KEY_ENVVAR_NAME )
38
- if api_key is None :
39
- raise DiodeConfigError (
40
- f"api_key param or { _DIODE_API_KEY_ENVVAR_NAME } environment variable required"
41
- )
42
- return api_key
43
-
44
-
45
40
def parse_target (target : str ) -> tuple [str , str , bool ]:
46
41
"""Parse the target into authority, path and tls_verify."""
47
42
parsed_target = urlparse (target )
@@ -66,6 +61,21 @@ def _get_sentry_dsn(sentry_dsn: str | None = None) -> str | None:
66
61
return sentry_dsn
67
62
68
63
64
+ def _get_required_config_value (env_var_name : str , value : str | None = None ) -> str :
65
+ """Get required config value either from provided value or environment variable."""
66
+ if value is None :
67
+ value = os .getenv (env_var_name )
68
+ if value is None :
69
+ raise DiodeConfigError (f"parameter or { env_var_name } environment variable required" )
70
+ return value
71
+
72
+ def _get_optional_config_value (env_var_name : str , value : str | None = None ) -> str | None :
73
+ """Get optional config value either from provided value or environment variable."""
74
+ if value is None :
75
+ value = os .getenv (env_var_name )
76
+ return value
77
+
78
+
69
79
class DiodeClient :
70
80
"""Diode Client."""
71
81
@@ -81,30 +91,40 @@ def __init__(
81
91
target : str ,
82
92
app_name : str ,
83
93
app_version : str ,
84
- api_key : str | None = None ,
94
+ client_id : str | None = None ,
95
+ client_secret : str | None = None ,
85
96
sentry_dsn : str = None ,
86
97
sentry_traces_sample_rate : float = 1.0 ,
87
98
sentry_profiles_sample_rate : float = 1.0 ,
99
+ max_auth_retries : int = 3 ,
88
100
):
89
101
"""Initiate a new client."""
90
102
log_level = os .getenv (_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME , "INFO" ).upper ()
91
103
logging .basicConfig (level = log_level )
92
104
105
+ self ._max_auth_retries = _get_optional_config_value (_MAX_RETRIES_ENVVAR_NAME , max_auth_retries )
93
106
self ._target , self ._path , self ._tls_verify = parse_target (target )
94
107
self ._app_name = app_name
95
108
self ._app_version = app_version
96
109
self ._platform = platform .platform ()
97
110
self ._python_version = platform .python_version ()
98
111
99
- api_key = _get_api_key (api_key )
112
+ # Read client credentials from environment variables
113
+ self ._client_id = _get_required_config_value (_CLIENT_ID_ENVVAR_NAME , client_id )
114
+ self ._client_secret = _get_required_config_value (_CLIENT_SECRET_ENVVAR_NAME , client_secret )
115
+
100
116
self ._metadata = (
101
- ("diode-api-key" , api_key ),
102
117
("platform" , self ._platform ),
103
118
("python-version" , self ._python_version ),
104
119
)
105
120
121
+ self ._authenticate ()
122
+
106
123
channel_opts = (
107
- ("grpc.primary_user_agent" , f"{ self ._name } /{ self ._version } { self ._app_name } /{ self ._app_version } " ),
124
+ (
125
+ "grpc.primary_user_agent" ,
126
+ f"{ self ._name } /{ self ._version } { self ._app_name } /{ self ._app_version } " ,
127
+ ),
108
128
)
109
129
110
130
if self ._tls_verify :
@@ -129,9 +149,7 @@ def __init__(
129
149
_LOGGER .debug (f"Setting up gRPC interceptor for path: { self ._path } " )
130
150
rpc_method_interceptor = DiodeMethodClientInterceptor (subpath = self ._path )
131
151
132
- intercept_channel = grpc .intercept_channel (
133
- self ._channel , rpc_method_interceptor
134
- )
152
+ intercept_channel = grpc .intercept_channel (self ._channel , rpc_method_interceptor )
135
153
channel = intercept_channel
136
154
137
155
self ._stub = ingester_pb2_grpc .IngesterServiceStub (channel )
@@ -140,9 +158,7 @@ def __init__(
140
158
141
159
if self ._sentry_dsn is not None :
142
160
_LOGGER .debug ("Setting up Sentry" )
143
- self ._setup_sentry (
144
- self ._sentry_dsn , sentry_traces_sample_rate , sentry_profiles_sample_rate
145
- )
161
+ self ._setup_sentry (self ._sentry_dsn , sentry_traces_sample_rate , sentry_profiles_sample_rate )
146
162
147
163
@property
148
164
def name (self ) -> str :
@@ -202,24 +218,28 @@ def ingest(
202
218
stream : str | None = _DEFAULT_STREAM ,
203
219
) -> ingester_pb2 .IngestResponse :
204
220
"""Ingest entities."""
205
- try :
206
- request = ingester_pb2 .IngestRequest (
207
- stream = stream ,
208
- id = str (uuid .uuid4 ()),
209
- entities = entities ,
210
- sdk_name = self .name ,
211
- sdk_version = self .version ,
212
- producer_app_name = self .app_name ,
213
- producer_app_version = self .app_version ,
214
- )
215
-
216
- return self ._stub .Ingest (request , metadata = self ._metadata )
217
- except grpc .RpcError as err :
218
- raise DiodeClientError (err ) from err
219
-
220
- def _setup_sentry (
221
- self , dsn : str , traces_sample_rate : float , profiles_sample_rate : float
222
- ):
221
+ for attempt in range (self ._max_auth_retries ):
222
+ try :
223
+ request = ingester_pb2 .IngestRequest (
224
+ stream = stream ,
225
+ id = str (uuid .uuid4 ()),
226
+ entities = entities ,
227
+ sdk_name = self .name ,
228
+ sdk_version = self .version ,
229
+ producer_app_name = self .app_name ,
230
+ producer_app_version = self .app_version ,
231
+ )
232
+ return self ._stub .Ingest (request , metadata = self ._metadata )
233
+ except grpc .RpcError as err :
234
+ if err .code () == grpc .StatusCode .UNAUTHENTICATED :
235
+ if attempt < self ._max_auth_retries - 1 :
236
+ _LOGGER .info (f"Retrying ingestion due to UNAUTHENTICATED error, attempt { attempt + 1 } " )
237
+ self ._authenticate ()
238
+ continue
239
+ raise DiodeClientError (err ) from err
240
+ return RuntimeError ("Max retries exceeded" )
241
+
242
+ def _setup_sentry (self , dsn : str , traces_sample_rate : float , profiles_sample_rate : float ):
223
243
sentry_sdk .init (
224
244
dsn = dsn ,
225
245
release = self .version ,
@@ -234,6 +254,59 @@ def _setup_sentry(
234
254
sentry_sdk .set_tag ("platform" , self ._platform )
235
255
sentry_sdk .set_tag ("python_version" , self ._python_version )
236
256
257
+ def _authenticate (self ):
258
+ authentication_client = _DiodeAuthentication (self ._target , self ._path , self ._tls_verify , self ._client_id , self ._client_secret )
259
+ access_token = authentication_client .authenticate ()
260
+ self ._metadata = list (filter (lambda x : x [0 ] != "authorization" , self ._metadata )) + \
261
+ [("authorization" , f"Bearer { access_token } " )]
262
+
263
+
264
+ class _DiodeAuthentication :
265
+ def __init__ (self , target : str , path : str , tls_verify : bool , client_id : str , client_secret : str ):
266
+ self ._target = target
267
+ self ._tls_verify = tls_verify
268
+ self ._client_id = client_id
269
+ self ._client_secret = client_secret
270
+ self ._path = path
271
+
272
+ def authenticate (self ) -> str :
273
+ """Request an OAuth2 token using client credentials and return it."""
274
+ if self ._tls_verify :
275
+ conn = http .client .HTTPSConnection (
276
+ self ._target ,
277
+ context = None if self ._tls_verify else ssl ._create_unverified_context (),
278
+ )
279
+ else :
280
+ conn = http .client .HTTPConnection (
281
+ self ._target ,
282
+ )
283
+ headers = {"Content-type" : "application/x-www-form-urlencoded" }
284
+ data = urlencode (
285
+ {
286
+ "grant_type" : "client_credentials" ,
287
+ "client_id" : self ._client_id ,
288
+ "client_secret" : self ._client_secret ,
289
+ }
290
+ )
291
+ url = self ._get_auth_url ()
292
+ conn .request ("POST" , url , data , headers )
293
+ response = conn .getresponse ()
294
+ if response .status != 200 :
295
+ raise DiodeConfigError (f"Failed to obtain access token: { response .reason } " )
296
+ token_info = json .loads (response .read ().decode ())
297
+ access_token = token_info .get ("access_token" )
298
+ if not access_token :
299
+ raise DiodeConfigError (f"Failed to obtain access token for client { self ._client_id } " )
300
+
301
+ _LOGGER .debug (f"Access token obtained for client { self ._client_id } " )
302
+ return access_token
303
+
304
+ def _get_auth_url (self ) -> str :
305
+ """Construct the authentication URL, handling trailing slashes in the path."""
306
+ # Ensure the path does not have trailing slashes
307
+ path = self ._path .rstrip ('/' ) if self ._path else ''
308
+ return f"{ path } /auth/token"
309
+
237
310
238
311
class _ClientCallDetails (
239
312
collections .namedtuple (
@@ -259,9 +332,7 @@ class _ClientCallDetails(
259
332
pass
260
333
261
334
262
- class DiodeMethodClientInterceptor (
263
- grpc .UnaryUnaryClientInterceptor , grpc .StreamUnaryClientInterceptor
264
- ):
335
+ class DiodeMethodClientInterceptor (grpc .UnaryUnaryClientInterceptor , grpc .StreamUnaryClientInterceptor ):
265
336
"""
266
337
Diode Method Client Interceptor class.
267
338
@@ -300,8 +371,6 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
300
371
"""Intercept unary unary."""
301
372
return self ._intercept_call (continuation , client_call_details , request )
302
373
303
- def intercept_stream_unary (
304
- self , continuation , client_call_details , request_iterator
305
- ):
374
+ def intercept_stream_unary (self , continuation , client_call_details , request_iterator ):
306
375
"""Intercept stream unary."""
307
376
return self ._intercept_call (continuation , client_call_details , request_iterator )
0 commit comments