Skip to content

feat: adds oauth2 authentication #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ __pycache__/
build/
dist/
.eggs/
*.egg-info
*.egg-info
poetry.lock
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ pip install netboxlabs-diode-sdk

### Environment variables

* `DIODE_API_KEY` - API key for the Diode service
* `DIODE_SDK_LOG_LEVEL` - Log level for the SDK (default: `INFO`)
* `DIODE_SENTRY_DSN` - Optional Sentry DSN for error reporting
* `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication
* `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication

### Example

Expand Down Expand Up @@ -94,8 +95,7 @@ if __name__ == "__main__":

## Development notes

Code in `netboxlabs/diode/sdk/diode/*` is generated from Protocol Buffers definitions (will be published and referred
here soon).
Code in `netboxlabs/diode/sdk/diode/*` is generated from Protocol Buffers definitions (will be published and referenced here soon).

#### Linting

Expand All @@ -107,7 +107,7 @@ black netboxlabs/
#### Testing

```shell
pytest tests/
PYTHONPATH=$(pwd) pytest
```

## License
Expand Down
157 changes: 110 additions & 47 deletions netboxlabs/diode/sdk/client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
#!/usr/bin/env python
# Copyright 2024 NetBox Labs Inc
"""NetBox Labs, Diode - SDK - Client."""

import collections
import http.client
import json
import logging
import os
import platform
import ssl
import uuid
from collections.abc import Iterable
from urllib.parse import urlparse
from urllib.parse import urlencode, urlparse

import certifi
import grpc
Expand All @@ -18,9 +22,11 @@
from netboxlabs.diode.sdk.ingester import Entity
from netboxlabs.diode.sdk.version import version_semver

_DIODE_API_KEY_ENVVAR_NAME = "DIODE_API_KEY"
_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES"
_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL"
_DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN"
_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID"
_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET"
_DEFAULT_STREAM = "latest"
_LOGGER = logging.getLogger(__name__)

Expand All @@ -31,17 +37,6 @@ def _load_certs() -> bytes:
return f.read()


def _get_api_key(api_key: str | None = None) -> str:
"""Get API Key either from provided value or environment variable."""
if api_key is None:
api_key = os.getenv(_DIODE_API_KEY_ENVVAR_NAME)
if api_key is None:
raise DiodeConfigError(
f"api_key param or {_DIODE_API_KEY_ENVVAR_NAME} environment variable required"
)
return api_key


def parse_target(target: str) -> tuple[str, str, bool]:
"""Parse the target into authority, path and tls_verify."""
parsed_target = urlparse(target)
Expand All @@ -66,6 +61,21 @@ def _get_sentry_dsn(sentry_dsn: str | None = None) -> str | None:
return sentry_dsn


def _get_required_config_value(env_var_name: str, value: str | None = None) -> str:
"""Get required config value either from provided value or environment variable."""
if value is None:
value = os.getenv(env_var_name)
if value is None:
raise DiodeConfigError(f"parameter or {env_var_name} environment variable required")
return value

def _get_optional_config_value(env_var_name: str, value: str | None = None) -> str | None:
"""Get optional config value either from provided value or environment variable."""
if value is None:
value = os.getenv(env_var_name)
return value


class DiodeClient:
"""Diode Client."""

Expand All @@ -81,30 +91,41 @@ def __init__(
target: str,
app_name: str,
app_version: str,
api_key: str | None = None,
client_id: str | None = None,
client_secret: str | None = None,
sentry_dsn: str = None,
sentry_traces_sample_rate: float = 1.0,
sentry_profiles_sample_rate: float = 1.0,
max_auth_retries: int = 3,
):
"""Initiate a new client."""
log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper()
logging.basicConfig(level=log_level)

self._max_auth_retries = _get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, max_auth_retries)
self._target, self._path, self._tls_verify = parse_target(target)
self._app_name = app_name
self._app_version = app_version
self._platform = platform.platform()
self._python_version = platform.python_version()

api_key = _get_api_key(api_key)
# Read client credentials from environment variables
self._client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id)
self._client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret)


self._metadata = (
("diode-api-key", api_key),
("platform", self._platform),
("python-version", self._python_version),
)

self._authenticate()

channel_opts = (
("grpc.primary_user_agent", f"{self._name}/{self._version} {self._app_name}/{self._app_version}"),
(
"grpc.primary_user_agent",
f"{self._name}/{self._version} {self._app_name}/{self._app_version}",
),
)

if self._tls_verify:
Expand All @@ -129,9 +150,7 @@ def __init__(
_LOGGER.debug(f"Setting up gRPC interceptor for path: {self._path}")
rpc_method_interceptor = DiodeMethodClientInterceptor(subpath=self._path)

intercept_channel = grpc.intercept_channel(
self._channel, rpc_method_interceptor
)
intercept_channel = grpc.intercept_channel(self._channel, rpc_method_interceptor)
channel = intercept_channel

self._stub = ingester_pb2_grpc.IngesterServiceStub(channel)
Expand All @@ -140,9 +159,7 @@ def __init__(

if self._sentry_dsn is not None:
_LOGGER.debug("Setting up Sentry")
self._setup_sentry(
self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate
)
self._setup_sentry(self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate)

@property
def name(self) -> str:
Expand Down Expand Up @@ -202,24 +219,28 @@ def ingest(
stream: str | None = _DEFAULT_STREAM,
) -> ingester_pb2.IngestResponse:
"""Ingest entities."""
try:
request = ingester_pb2.IngestRequest(
stream=stream,
id=str(uuid.uuid4()),
entities=entities,
sdk_name=self.name,
sdk_version=self.version,
producer_app_name=self.app_name,
producer_app_version=self.app_version,
)

return self._stub.Ingest(request, metadata=self._metadata)
except grpc.RpcError as err:
raise DiodeClientError(err) from err

def _setup_sentry(
self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float
):
for attempt in range(self._max_auth_retries):
try:
request = ingester_pb2.IngestRequest(
stream=stream,
id=str(uuid.uuid4()),
entities=entities,
sdk_name=self.name,
sdk_version=self.version,
producer_app_name=self.app_name,
producer_app_version=self.app_version,
)
return self._stub.Ingest(request, metadata=self._metadata)
except grpc.RpcError as err:
if err.code() == grpc.StatusCode.UNAUTHENTICATED:
if attempt < self._max_auth_retries - 1:
_LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}")
self._authenticate()
continue
raise DiodeClientError(err) from err
return RuntimeError("Max retries exceeded")

def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float):
sentry_sdk.init(
dsn=dsn,
release=self.version,
Expand All @@ -234,6 +255,52 @@ def _setup_sentry(
sentry_sdk.set_tag("platform", self._platform)
sentry_sdk.set_tag("python_version", self._python_version)

def _authenticate(self):
authentication_client = _DiodeAuthentication(self._target, self._tls_verify, self._client_id, self._client_secret, self._path)
access_token = authentication_client.authenticate()
self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + \
[("authorization", f"Bearer {access_token}")]


class _DiodeAuthentication:
def __init__(self, target: str, path: str, tls_verify: bool, client_id: str, client_secret: str):
self._target = target
self._tls_verify = tls_verify
self._client_id = client_id
self._client_secret = client_secret
self._path = path

def authenticate(self) -> str:
"""Request an OAuth2 token using client credentials and return it."""
if self._tls_verify:
conn = http.client.HTTPSConnection(
self._target,
context=None if self._tls_verify else ssl._create_unverified_context(),
)
else:
conn = http.client.HTTPConnection(
self._target,
)
headers = {"Content-type": "application/x-www-form-urlencoded"}
data = urlencode(
{
"grant_type": "client_credentials",
"client_id": self._client_id,
"client_secret": self._client_secret,
}
)
conn.request("POST", f"{self._path}/auth/token", data, headers)
response = conn.getresponse()
if response.status != 200:
raise DiodeConfigError(f"Failed to obtain access token: {response.reason}")
token_info = json.loads(response.read().decode())
access_token = token_info.get("access_token")
if not access_token:
raise DiodeConfigError(f"Failed to obtain access token for client {self._client_id}")

_LOGGER.debug(f"Access token obtained for client {self._client_id}")
return access_token


class _ClientCallDetails(
collections.namedtuple(
Expand All @@ -259,9 +326,7 @@ class _ClientCallDetails(
pass


class DiodeMethodClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor
):
class DiodeMethodClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor):
"""
Diode Method Client Interceptor class.

Expand Down Expand Up @@ -300,8 +365,6 @@ def intercept_unary_unary(self, continuation, client_call_details, request):
"""Intercept unary unary."""
return self._intercept_call(continuation, client_call_details, request)

def intercept_stream_unary(
self, continuation, client_call_details, request_iterator
):
def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
"""Intercept stream unary."""
return self._intercept_call(continuation, client_call_details, request_iterator)
2 changes: 2 additions & 0 deletions netboxlabs/diode/sdk/ingester.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import datetime
import re
from typing import Any

from google.protobuf import timestamp_pb2 as _timestamp_pb2

import netboxlabs.diode.sdk.diode.v1.ingester_pb2 as pb

PRIMARY_VALUE_MAP = {
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [

[project.optional-dependencies] # Optional
dev = ["black", "check-manifest", "ruff"]
test = ["coverage", "pytest", "pytest-cov"]
test = ["coverage", "pytest", "pytest-cov==6.0.0"]

[tool.coverage.run]
omit = [
Expand Down
Loading