Skip to content

feat: add OAuth2 diode:ingest scope to authentication process #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 61 additions & 21 deletions netboxlabs/diode/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
_INGEST_SCOPE = "diode:ingest"
_DEFAULT_STREAM = "latest"
_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,10 +67,15 @@ def _get_required_config_value(env_var_name: str, value: str | None = None) -> s
if value is None:
value = os.getenv(env_var_name)
if value is None:
raise DiodeConfigError(f"parameter or {env_var_name} environment variable required")
raise DiodeConfigError(
f"parameter or {env_var_name} environment variable required"
)
return value

def _get_optional_config_value(env_var_name: str, value: str | None = None) -> str | None:

def _get_optional_config_value(
env_var_name: str, value: str | None = None
) -> str | None:
"""Get optional config value either from provided value or environment variable."""
if value is None:
value = os.getenv(env_var_name)
Expand Down Expand Up @@ -102,7 +108,9 @@ def __init__(
log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper()
logging.basicConfig(level=log_level)

self._max_auth_retries = _get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, max_auth_retries)
self._max_auth_retries = _get_optional_config_value(
_MAX_RETRIES_ENVVAR_NAME, max_auth_retries
)
self._target, self._path, self._tls_verify = parse_target(target)
self._app_name = app_name
self._app_version = app_version
Expand All @@ -111,14 +119,16 @@ def __init__(

# Read client credentials from environment variables
self._client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id)
self._client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret)
self._client_secret = _get_required_config_value(
_CLIENT_SECRET_ENVVAR_NAME, client_secret
)

self._metadata = (
("platform", self._platform),
("python-version", self._python_version),
)

self._authenticate()
self._authenticate(_INGEST_SCOPE)

channel_opts = (
(
Expand Down Expand Up @@ -149,7 +159,9 @@ def __init__(
_LOGGER.debug(f"Setting up gRPC interceptor for path: {self._path}")
rpc_method_interceptor = DiodeMethodClientInterceptor(subpath=self._path)

intercept_channel = grpc.intercept_channel(self._channel, rpc_method_interceptor)
intercept_channel = grpc.intercept_channel(
self._channel, rpc_method_interceptor
)
channel = intercept_channel

self._stub = ingester_pb2_grpc.IngesterServiceStub(channel)
Expand All @@ -158,7 +170,9 @@ def __init__(

if self._sentry_dsn is not None:
_LOGGER.debug("Setting up Sentry")
self._setup_sentry(self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate)
self._setup_sentry(
self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate
)

@property
def name(self) -> str:
Expand Down Expand Up @@ -233,13 +247,17 @@ def ingest(
except grpc.RpcError as err:
if err.code() == grpc.StatusCode.UNAUTHENTICATED:
if attempt < self._max_auth_retries - 1:
_LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}")
self._authenticate()
_LOGGER.info(
f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}"
)
self._authenticate(_INGEST_SCOPE)
continue
raise DiodeClientError(err) from err
raise RuntimeError("Max retries exceeded")

def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float):
def _setup_sentry(
self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float
):
sentry_sdk.init(
dsn=dsn,
release=self.version,
Expand All @@ -254,20 +272,37 @@ def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rat
sentry_sdk.set_tag("platform", self._platform)
sentry_sdk.set_tag("python_version", self._python_version)

def _authenticate(self):
authentication_client = _DiodeAuthentication(self._target, self._path, self._tls_verify, self._client_id, self._client_secret)
def _authenticate(self, scope: str):
authentication_client = _DiodeAuthentication(
self._target,
self._path,
self._tls_verify,
self._client_id,
self._client_secret,
scope,
)
access_token = authentication_client.authenticate()
self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + \
[("authorization", f"Bearer {access_token}")]
self._metadata = list(
filter(lambda x: x[0] != "authorization", self._metadata)
) + [("authorization", f"Bearer {access_token}")]


class _DiodeAuthentication:
def __init__(self, target: str, path: str, tls_verify: bool, client_id: str, client_secret: str):
def __init__(
self,
target: str,
path: str,
tls_verify: bool,
client_id: str,
client_secret: str,
scope: str,
):
self._target = target
self._tls_verify = tls_verify
self._client_id = client_id
self._client_secret = client_secret
self._path = path
self._scope = scope

def authenticate(self) -> str:
"""Request an OAuth2 token using client credentials and return it."""
Expand All @@ -286,6 +321,7 @@ def authenticate(self) -> str:
"grant_type": "client_credentials",
"client_id": self._client_id,
"client_secret": self._client_secret,
"scope": self._scope,
}
)
url = self._get_auth_url()
Expand All @@ -299,15 +335,17 @@ def authenticate(self) -> str:
token_info = json.loads(response.read().decode())
access_token = token_info.get("access_token")
if not access_token:
raise DiodeConfigError(f"Failed to obtain access token for client {self._client_id}")
raise DiodeConfigError(
f"Failed to obtain access token for client {self._client_id}"
)

_LOGGER.debug(f"Access token obtained for client {self._client_id}")
return access_token

def _get_auth_url(self) -> str:
"""Construct the authentication URL, handling trailing slashes in the path."""
# Ensure the path does not have trailing slashes
path = self._path.rstrip('/') if self._path else ''
path = self._path.rstrip("/") if self._path else ""
return f"{path}/auth/token"


Expand All @@ -332,10 +370,10 @@ class _ClientCallDetails(

"""

pass


class DiodeMethodClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor):
class DiodeMethodClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor
):
"""
Diode Method Client Interceptor class.

Expand Down Expand Up @@ -374,6 +412,8 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
"""Intercept unary unary."""
return self._intercept_call(continuation, client_call_details, request)

def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
def intercept_stream_unary(
self, continuation, client_call_details, request_iterator
):
"""Intercept stream unary."""
return self._intercept_call(continuation, client_call_details, request_iterator)
4 changes: 4 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def test_diode_authentication_success(mock_diode_authentication):
tls_verify=False,
client_id="test_client_id",
client_secret="test_client_secret",
scope="diode:ingest",
)
with mock.patch("http.client.HTTPConnection") as mock_http_conn:
mock_conn_instance = mock_http_conn.return_value
Expand All @@ -578,6 +579,7 @@ def test_diode_authentication_failure(mock_diode_authentication):
tls_verify=False,
client_id="test_client_id",
client_secret="test_client_secret",
scope="diode:ingest",
)
with mock.patch("http.client.HTTPConnection") as mock_http_conn:
mock_conn_instance = mock_http_conn.return_value
Expand Down Expand Up @@ -605,6 +607,7 @@ def test_diode_authentication_url_with_path(mock_diode_authentication, path):
tls_verify=False,
client_id="test_client_id",
client_secret="test_client_secret",
scope="diode:ingest",
)
with mock.patch("http.client.HTTPConnection") as mock_http_conn:
mock_conn_instance = mock_http_conn.return_value
Expand All @@ -622,6 +625,7 @@ def test_diode_authentication_request_exception(mock_diode_authentication):
tls_verify=False,
client_id="test_client_id",
client_secret="test_client_secret",
scope="diode:ingest",
)
with mock.patch("http.client.HTTPConnection") as mock_http_conn:
mock_conn_instance = mock_http_conn.return_value
Expand Down