Skip to content

Commit 07cccc8

Browse files
authored
feat: adds oauth2 authentication (#40)
2 parents 3bc433a + a285af3 commit 07cccc8

File tree

6 files changed

+331
-132
lines changed

6 files changed

+331
-132
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ __pycache__/
2424
build/
2525
dist/
2626
.eggs/
27-
*.egg-info
27+
*.egg-info
28+
poetry.lock

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ pip install netboxlabs-diode-sdk
2020

2121
### Environment variables
2222

23-
* `DIODE_API_KEY` - API key for the Diode service
2423
* `DIODE_SDK_LOG_LEVEL` - Log level for the SDK (default: `INFO`)
2524
* `DIODE_SENTRY_DSN` - Optional Sentry DSN for error reporting
25+
* `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication
26+
* `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication
2627

2728
### Example
2829

@@ -94,8 +95,7 @@ if __name__ == "__main__":
9495

9596
## Development notes
9697

97-
Code in `netboxlabs/diode/sdk/diode/*` is generated from Protocol Buffers definitions (will be published and referred
98-
here soon).
98+
Code in `netboxlabs/diode/sdk/diode/*` is generated from Protocol Buffers definitions (will be published and referenced here soon).
9999

100100
#### Linting
101101

@@ -107,7 +107,7 @@ black netboxlabs/
107107
#### Testing
108108

109109
```shell
110-
pytest tests/
110+
PYTHONPATH=$(pwd) pytest
111111
```
112112

113113
## License

netboxlabs/diode/sdk/client.py

Lines changed: 116 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
#!/usr/bin/env python
22
# Copyright 2024 NetBox Labs Inc
33
"""NetBox Labs, Diode - SDK - Client."""
4+
45
import collections
6+
import http.client
7+
import json
58
import logging
69
import os
710
import platform
11+
import ssl
812
import uuid
913
from collections.abc import Iterable
10-
from urllib.parse import urlparse
14+
from urllib.parse import urlencode, urlparse
1115

1216
import certifi
1317
import grpc
@@ -18,9 +22,11 @@
1822
from netboxlabs.diode.sdk.ingester import Entity
1923
from netboxlabs.diode.sdk.version import version_semver
2024

21-
_DIODE_API_KEY_ENVVAR_NAME = "DIODE_API_KEY"
25+
_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"
2226
_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
2327
_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
28+
_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
29+
_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
2430
_DEFAULT_STREAM = "latest"
2531
_LOGGER = logging.getLogger(__name__)
2632

@@ -31,17 +37,6 @@ def _load_certs() -> bytes:
3137
return f.read()
3238

3339

34-
def _get_api_key(api_key: str | None = None) -> str:
35-
"""Get API Key either from provided value or environment variable."""
36-
if api_key is None:
37-
api_key = os.getenv(_DIODE_API_KEY_ENVVAR_NAME)
38-
if api_key is None:
39-
raise DiodeConfigError(
40-
f"api_key param or {_DIODE_API_KEY_ENVVAR_NAME} environment variable required"
41-
)
42-
return api_key
43-
44-
4540
def parse_target(target: str) -> tuple[str, str, bool]:
4641
"""Parse the target into authority, path and tls_verify."""
4742
parsed_target = urlparse(target)
@@ -66,6 +61,21 @@ def _get_sentry_dsn(sentry_dsn: str | None = None) -> str | None:
6661
return sentry_dsn
6762

6863

64+
def _get_required_config_value(env_var_name: str, value: str | None = None) -> str:
65+
"""Get required config value either from provided value or environment variable."""
66+
if value is None:
67+
value = os.getenv(env_var_name)
68+
if value is None:
69+
raise DiodeConfigError(f"parameter or {env_var_name} environment variable required")
70+
return value
71+
72+
def _get_optional_config_value(env_var_name: str, value: str | None = None) -> str | None:
73+
"""Get optional config value either from provided value or environment variable."""
74+
if value is None:
75+
value = os.getenv(env_var_name)
76+
return value
77+
78+
6979
class DiodeClient:
7080
"""Diode Client."""
7181

@@ -81,30 +91,40 @@ def __init__(
8191
target: str,
8292
app_name: str,
8393
app_version: str,
84-
api_key: str | None = None,
94+
client_id: str | None = None,
95+
client_secret: str | None = None,
8596
sentry_dsn: str = None,
8697
sentry_traces_sample_rate: float = 1.0,
8798
sentry_profiles_sample_rate: float = 1.0,
99+
max_auth_retries: int = 3,
88100
):
89101
"""Initiate a new client."""
90102
log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper()
91103
logging.basicConfig(level=log_level)
92104

105+
self._max_auth_retries = _get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, max_auth_retries)
93106
self._target, self._path, self._tls_verify = parse_target(target)
94107
self._app_name = app_name
95108
self._app_version = app_version
96109
self._platform = platform.platform()
97110
self._python_version = platform.python_version()
98111

99-
api_key = _get_api_key(api_key)
112+
# Read client credentials from environment variables
113+
self._client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id)
114+
self._client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret)
115+
100116
self._metadata = (
101-
("diode-api-key", api_key),
102117
("platform", self._platform),
103118
("python-version", self._python_version),
104119
)
105120

121+
self._authenticate()
122+
106123
channel_opts = (
107-
("grpc.primary_user_agent", f"{self._name}/{self._version} {self._app_name}/{self._app_version}"),
124+
(
125+
"grpc.primary_user_agent",
126+
f"{self._name}/{self._version} {self._app_name}/{self._app_version}",
127+
),
108128
)
109129

110130
if self._tls_verify:
@@ -129,9 +149,7 @@ def __init__(
129149
_LOGGER.debug(f"Setting up gRPC interceptor for path: {self._path}")
130150
rpc_method_interceptor = DiodeMethodClientInterceptor(subpath=self._path)
131151

132-
intercept_channel = grpc.intercept_channel(
133-
self._channel, rpc_method_interceptor
134-
)
152+
intercept_channel = grpc.intercept_channel(self._channel, rpc_method_interceptor)
135153
channel = intercept_channel
136154

137155
self._stub = ingester_pb2_grpc.IngesterServiceStub(channel)
@@ -140,9 +158,7 @@ def __init__(
140158

141159
if self._sentry_dsn is not None:
142160
_LOGGER.debug("Setting up Sentry")
143-
self._setup_sentry(
144-
self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate
145-
)
161+
self._setup_sentry(self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate)
146162

147163
@property
148164
def name(self) -> str:
@@ -202,24 +218,28 @@ def ingest(
202218
stream: str | None = _DEFAULT_STREAM,
203219
) -> ingester_pb2.IngestResponse:
204220
"""Ingest entities."""
205-
try:
206-
request = ingester_pb2.IngestRequest(
207-
stream=stream,
208-
id=str(uuid.uuid4()),
209-
entities=entities,
210-
sdk_name=self.name,
211-
sdk_version=self.version,
212-
producer_app_name=self.app_name,
213-
producer_app_version=self.app_version,
214-
)
215-
216-
return self._stub.Ingest(request, metadata=self._metadata)
217-
except grpc.RpcError as err:
218-
raise DiodeClientError(err) from err
219-
220-
def _setup_sentry(
221-
self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float
222-
):
221+
for attempt in range(self._max_auth_retries):
222+
try:
223+
request = ingester_pb2.IngestRequest(
224+
stream=stream,
225+
id=str(uuid.uuid4()),
226+
entities=entities,
227+
sdk_name=self.name,
228+
sdk_version=self.version,
229+
producer_app_name=self.app_name,
230+
producer_app_version=self.app_version,
231+
)
232+
return self._stub.Ingest(request, metadata=self._metadata)
233+
except grpc.RpcError as err:
234+
if err.code() == grpc.StatusCode.UNAUTHENTICATED:
235+
if attempt < self._max_auth_retries - 1:
236+
_LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}")
237+
self._authenticate()
238+
continue
239+
raise DiodeClientError(err) from err
240+
return RuntimeError("Max retries exceeded")
241+
242+
def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float):
223243
sentry_sdk.init(
224244
dsn=dsn,
225245
release=self.version,
@@ -234,6 +254,59 @@ def _setup_sentry(
234254
sentry_sdk.set_tag("platform", self._platform)
235255
sentry_sdk.set_tag("python_version", self._python_version)
236256

257+
def _authenticate(self):
258+
authentication_client = _DiodeAuthentication(self._target, self._path, self._tls_verify, self._client_id, self._client_secret)
259+
access_token = authentication_client.authenticate()
260+
self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + \
261+
[("authorization", f"Bearer {access_token}")]
262+
263+
264+
class _DiodeAuthentication:
265+
def __init__(self, target: str, path: str, tls_verify: bool, client_id: str, client_secret: str):
266+
self._target = target
267+
self._tls_verify = tls_verify
268+
self._client_id = client_id
269+
self._client_secret = client_secret
270+
self._path = path
271+
272+
def authenticate(self) -> str:
273+
"""Request an OAuth2 token using client credentials and return it."""
274+
if self._tls_verify:
275+
conn = http.client.HTTPSConnection(
276+
self._target,
277+
context=None if self._tls_verify else ssl._create_unverified_context(),
278+
)
279+
else:
280+
conn = http.client.HTTPConnection(
281+
self._target,
282+
)
283+
headers = {"Content-type": "application/x-www-form-urlencoded"}
284+
data = urlencode(
285+
{
286+
"grant_type": "client_credentials",
287+
"client_id": self._client_id,
288+
"client_secret": self._client_secret,
289+
}
290+
)
291+
url = self._get_auth_url()
292+
conn.request("POST", url, data, headers)
293+
response = conn.getresponse()
294+
if response.status != 200:
295+
raise DiodeConfigError(f"Failed to obtain access token: {response.reason}")
296+
token_info = json.loads(response.read().decode())
297+
access_token = token_info.get("access_token")
298+
if not access_token:
299+
raise DiodeConfigError(f"Failed to obtain access token for client {self._client_id}")
300+
301+
_LOGGER.debug(f"Access token obtained for client {self._client_id}")
302+
return access_token
303+
304+
def _get_auth_url(self) -> str:
305+
"""Construct the authentication URL, handling trailing slashes in the path."""
306+
# Ensure the path does not have trailing slashes
307+
path = self._path.rstrip('/') if self._path else ''
308+
return f"{path}/auth/token"
309+
237310

238311
class _ClientCallDetails(
239312
collections.namedtuple(
@@ -259,9 +332,7 @@ class _ClientCallDetails(
259332
pass
260333

261334

262-
class DiodeMethodClientInterceptor(
263-
grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor
264-
):
335+
class DiodeMethodClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor):
265336
"""
266337
Diode Method Client Interceptor class.
267338
@@ -300,8 +371,6 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
300371
"""Intercept unary unary."""
301372
return self._intercept_call(continuation, client_call_details, request)
302373

303-
def intercept_stream_unary(
304-
self, continuation, client_call_details, request_iterator
305-
):
374+
def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
306375
"""Intercept stream unary."""
307376
return self._intercept_call(continuation, client_call_details, request_iterator)

netboxlabs/diode/sdk/ingester.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
import datetime
1212
import re
1313
from typing import Any
14+
1415
from google.protobuf import timestamp_pb2 as _timestamp_pb2
16+
1517
import netboxlabs.diode.sdk.diode.v1.ingester_pb2 as pb
1618

1719
PRIMARY_VALUE_MAP = {

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ dependencies = [
3232

3333
[project.optional-dependencies] # Optional
3434
dev = ["black", "check-manifest", "ruff"]
35-
test = ["coverage", "pytest", "pytest-cov"]
35+
test = ["coverage", "pytest", "pytest-cov==6.0.0"]
3636

3737
[tool.coverage.run]
3838
omit = [
39+
"*/netboxlabs/diode/sdk/ingester.py",
3940
"*/netboxlabs/diode/sdk/diode/*",
4041
"*/netboxlabs/diode/sdk/validate/*",
4142
"*/tests/*",

0 commit comments

Comments
 (0)