Skip to content

Commit 5390507

Browse files
authored
feat: add OAuth2 diode:ingest scope to authentication process (#43)
Signed-off-by: Michal Fiedorowicz <[email protected]>
1 parent 6c374d1 commit 5390507

File tree

2 files changed

+65
-21
lines changed

2 files changed

+65
-21
lines changed

netboxlabs/diode/sdk/client.py

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
2828
_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
2929
_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
30+
_INGEST_SCOPE = "diode:ingest"
3031
_DEFAULT_STREAM = "latest"
3132
_LOGGER = logging.getLogger(__name__)
3233

@@ -66,10 +67,15 @@ def _get_required_config_value(env_var_name: str, value: str | None = None) -> s
6667
if value is None:
6768
value = os.getenv(env_var_name)
6869
if value is None:
69-
raise DiodeConfigError(f"parameter or {env_var_name} environment variable required")
70+
raise DiodeConfigError(
71+
f"parameter or {env_var_name} environment variable required"
72+
)
7073
return value
7174

72-
def _get_optional_config_value(env_var_name: str, value: str | None = None) -> str | None:
75+
76+
def _get_optional_config_value(
77+
env_var_name: str, value: str | None = None
78+
) -> str | None:
7379
"""Get optional config value either from provided value or environment variable."""
7480
if value is None:
7581
value = os.getenv(env_var_name)
@@ -102,7 +108,9 @@ def __init__(
102108
log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper()
103109
logging.basicConfig(level=log_level)
104110

105-
self._max_auth_retries = _get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, max_auth_retries)
111+
self._max_auth_retries = _get_optional_config_value(
112+
_MAX_RETRIES_ENVVAR_NAME, max_auth_retries
113+
)
106114
self._target, self._path, self._tls_verify = parse_target(target)
107115
self._app_name = app_name
108116
self._app_version = app_version
@@ -111,14 +119,16 @@ def __init__(
111119

112120
# Read client credentials from environment variables
113121
self._client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id)
114-
self._client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret)
122+
self._client_secret = _get_required_config_value(
123+
_CLIENT_SECRET_ENVVAR_NAME, client_secret
124+
)
115125

116126
self._metadata = (
117127
("platform", self._platform),
118128
("python-version", self._python_version),
119129
)
120130

121-
self._authenticate()
131+
self._authenticate(_INGEST_SCOPE)
122132

123133
channel_opts = (
124134
(
@@ -149,7 +159,9 @@ def __init__(
149159
_LOGGER.debug(f"Setting up gRPC interceptor for path: {self._path}")
150160
rpc_method_interceptor = DiodeMethodClientInterceptor(subpath=self._path)
151161

152-
intercept_channel = grpc.intercept_channel(self._channel, rpc_method_interceptor)
162+
intercept_channel = grpc.intercept_channel(
163+
self._channel, rpc_method_interceptor
164+
)
153165
channel = intercept_channel
154166

155167
self._stub = ingester_pb2_grpc.IngesterServiceStub(channel)
@@ -158,7 +170,9 @@ def __init__(
158170

159171
if self._sentry_dsn is not None:
160172
_LOGGER.debug("Setting up Sentry")
161-
self._setup_sentry(self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate)
173+
self._setup_sentry(
174+
self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate
175+
)
162176

163177
@property
164178
def name(self) -> str:
@@ -233,13 +247,17 @@ def ingest(
233247
except grpc.RpcError as err:
234248
if err.code() == grpc.StatusCode.UNAUTHENTICATED:
235249
if attempt < self._max_auth_retries - 1:
236-
_LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}")
237-
self._authenticate()
250+
_LOGGER.info(
251+
f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}"
252+
)
253+
self._authenticate(_INGEST_SCOPE)
238254
continue
239255
raise DiodeClientError(err) from err
240256
raise RuntimeError("Max retries exceeded")
241257

242-
def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float):
258+
def _setup_sentry(
259+
self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float
260+
):
243261
sentry_sdk.init(
244262
dsn=dsn,
245263
release=self.version,
@@ -254,20 +272,37 @@ def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rat
254272
sentry_sdk.set_tag("platform", self._platform)
255273
sentry_sdk.set_tag("python_version", self._python_version)
256274

257-
def _authenticate(self):
258-
authentication_client = _DiodeAuthentication(self._target, self._path, self._tls_verify, self._client_id, self._client_secret)
275+
def _authenticate(self, scope: str):
276+
authentication_client = _DiodeAuthentication(
277+
self._target,
278+
self._path,
279+
self._tls_verify,
280+
self._client_id,
281+
self._client_secret,
282+
scope,
283+
)
259284
access_token = authentication_client.authenticate()
260-
self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + \
261-
[("authorization", f"Bearer {access_token}")]
285+
self._metadata = list(
286+
filter(lambda x: x[0] != "authorization", self._metadata)
287+
) + [("authorization", f"Bearer {access_token}")]
262288

263289

264290
class _DiodeAuthentication:
265-
def __init__(self, target: str, path: str, tls_verify: bool, client_id: str, client_secret: str):
291+
def __init__(
292+
self,
293+
target: str,
294+
path: str,
295+
tls_verify: bool,
296+
client_id: str,
297+
client_secret: str,
298+
scope: str,
299+
):
266300
self._target = target
267301
self._tls_verify = tls_verify
268302
self._client_id = client_id
269303
self._client_secret = client_secret
270304
self._path = path
305+
self._scope = scope
271306

272307
def authenticate(self) -> str:
273308
"""Request an OAuth2 token using client credentials and return it."""
@@ -286,6 +321,7 @@ def authenticate(self) -> str:
286321
"grant_type": "client_credentials",
287322
"client_id": self._client_id,
288323
"client_secret": self._client_secret,
324+
"scope": self._scope,
289325
}
290326
)
291327
url = self._get_auth_url()
@@ -299,15 +335,17 @@ def authenticate(self) -> str:
299335
token_info = json.loads(response.read().decode())
300336
access_token = token_info.get("access_token")
301337
if not access_token:
302-
raise DiodeConfigError(f"Failed to obtain access token for client {self._client_id}")
338+
raise DiodeConfigError(
339+
f"Failed to obtain access token for client {self._client_id}"
340+
)
303341

304342
_LOGGER.debug(f"Access token obtained for client {self._client_id}")
305343
return access_token
306344

307345
def _get_auth_url(self) -> str:
308346
"""Construct the authentication URL, handling trailing slashes in the path."""
309347
# Ensure the path does not have trailing slashes
310-
path = self._path.rstrip('/') if self._path else ''
348+
path = self._path.rstrip("/") if self._path else ""
311349
return f"{path}/auth/token"
312350

313351

@@ -332,10 +370,10 @@ class _ClientCallDetails(
332370
333371
"""
334372

335-
pass
336-
337373

338-
class DiodeMethodClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor):
374+
class DiodeMethodClientInterceptor(
375+
grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor
376+
):
339377
"""
340378
Diode Method Client Interceptor class.
341379
@@ -374,6 +412,8 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
374412
"""Intercept unary unary."""
375413
return self._intercept_call(continuation, client_call_details, request)
376414

377-
def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
415+
def intercept_stream_unary(
416+
self, continuation, client_call_details, request_iterator
417+
):
378418
"""Intercept stream unary."""
379419
return self._intercept_call(continuation, client_call_details, request_iterator)

tests/test_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@ def test_diode_authentication_success(mock_diode_authentication):
560560
tls_verify=False,
561561
client_id="test_client_id",
562562
client_secret="test_client_secret",
563+
scope="diode:ingest",
563564
)
564565
with mock.patch("http.client.HTTPConnection") as mock_http_conn:
565566
mock_conn_instance = mock_http_conn.return_value
@@ -578,6 +579,7 @@ def test_diode_authentication_failure(mock_diode_authentication):
578579
tls_verify=False,
579580
client_id="test_client_id",
580581
client_secret="test_client_secret",
582+
scope="diode:ingest",
581583
)
582584
with mock.patch("http.client.HTTPConnection") as mock_http_conn:
583585
mock_conn_instance = mock_http_conn.return_value
@@ -605,6 +607,7 @@ def test_diode_authentication_url_with_path(mock_diode_authentication, path):
605607
tls_verify=False,
606608
client_id="test_client_id",
607609
client_secret="test_client_secret",
610+
scope="diode:ingest",
608611
)
609612
with mock.patch("http.client.HTTPConnection") as mock_http_conn:
610613
mock_conn_instance = mock_http_conn.return_value
@@ -622,6 +625,7 @@ def test_diode_authentication_request_exception(mock_diode_authentication):
622625
tls_verify=False,
623626
client_id="test_client_id",
624627
client_secret="test_client_secret",
628+
scope="diode:ingest",
625629
)
626630
with mock.patch("http.client.HTTPConnection") as mock_http_conn:
627631
mock_conn_instance = mock_http_conn.return_value

0 commit comments

Comments
 (0)