diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index f9f94b1..3c650f1 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -27,6 +27,7 @@ _DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN" _CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID" _CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET" +_INGEST_SCOPE = "diode:ingest" _DEFAULT_STREAM = "latest" _LOGGER = logging.getLogger(__name__) @@ -66,10 +67,15 @@ def _get_required_config_value(env_var_name: str, value: str | None = None) -> s if value is None: value = os.getenv(env_var_name) if value is None: - raise DiodeConfigError(f"parameter or {env_var_name} environment variable required") + raise DiodeConfigError( + f"parameter or {env_var_name} environment variable required" + ) return value -def _get_optional_config_value(env_var_name: str, value: str | None = None) -> str | None: + +def _get_optional_config_value( + env_var_name: str, value: str | None = None +) -> str | None: """Get optional config value either from provided value or environment variable.""" if value is None: value = os.getenv(env_var_name) @@ -102,7 +108,9 @@ def __init__( log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper() logging.basicConfig(level=log_level) - self._max_auth_retries = _get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, max_auth_retries) + self._max_auth_retries = _get_optional_config_value( + _MAX_RETRIES_ENVVAR_NAME, max_auth_retries + ) self._target, self._path, self._tls_verify = parse_target(target) self._app_name = app_name self._app_version = app_version @@ -111,14 +119,16 @@ def __init__( # Read client credentials from environment variables self._client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id) - self._client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret) + self._client_secret = _get_required_config_value( + _CLIENT_SECRET_ENVVAR_NAME, client_secret + ) self._metadata = ( ("platform", self._platform), ("python-version", self._python_version), ) - self._authenticate() + self._authenticate(_INGEST_SCOPE) channel_opts = ( ( @@ -149,7 +159,9 @@ def __init__( _LOGGER.debug(f"Setting up gRPC interceptor for path: {self._path}") rpc_method_interceptor = DiodeMethodClientInterceptor(subpath=self._path) - intercept_channel = grpc.intercept_channel(self._channel, rpc_method_interceptor) + intercept_channel = grpc.intercept_channel( + self._channel, rpc_method_interceptor + ) channel = intercept_channel self._stub = ingester_pb2_grpc.IngesterServiceStub(channel) @@ -158,7 +170,9 @@ def __init__( if self._sentry_dsn is not None: _LOGGER.debug("Setting up Sentry") - self._setup_sentry(self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate) + self._setup_sentry( + self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate + ) @property def name(self) -> str: @@ -233,13 +247,17 @@ def ingest( except grpc.RpcError as err: if err.code() == grpc.StatusCode.UNAUTHENTICATED: if attempt < self._max_auth_retries - 1: - _LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}") - self._authenticate() + _LOGGER.info( + f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}" + ) + self._authenticate(_INGEST_SCOPE) continue raise DiodeClientError(err) from err raise RuntimeError("Max retries exceeded") - def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float): + def _setup_sentry( + self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float + ): sentry_sdk.init( dsn=dsn, release=self.version, @@ -254,20 +272,37 @@ def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rat sentry_sdk.set_tag("platform", self._platform) sentry_sdk.set_tag("python_version", self._python_version) - def _authenticate(self): - authentication_client = _DiodeAuthentication(self._target, self._path, self._tls_verify, self._client_id, self._client_secret) + def _authenticate(self, scope: str): + authentication_client = _DiodeAuthentication( + self._target, + self._path, + self._tls_verify, + self._client_id, + self._client_secret, + scope, + ) access_token = authentication_client.authenticate() - self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + \ - [("authorization", f"Bearer {access_token}")] + self._metadata = list( + filter(lambda x: x[0] != "authorization", self._metadata) + ) + [("authorization", f"Bearer {access_token}")] class _DiodeAuthentication: - def __init__(self, target: str, path: str, tls_verify: bool, client_id: str, client_secret: str): + def __init__( + self, + target: str, + path: str, + tls_verify: bool, + client_id: str, + client_secret: str, + scope: str, + ): self._target = target self._tls_verify = tls_verify self._client_id = client_id self._client_secret = client_secret self._path = path + self._scope = scope def authenticate(self) -> str: """Request an OAuth2 token using client credentials and return it.""" @@ -286,6 +321,7 @@ def authenticate(self) -> str: "grant_type": "client_credentials", "client_id": self._client_id, "client_secret": self._client_secret, + "scope": self._scope, } ) url = self._get_auth_url() @@ -299,7 +335,9 @@ def authenticate(self) -> str: token_info = json.loads(response.read().decode()) access_token = token_info.get("access_token") if not access_token: - raise DiodeConfigError(f"Failed to obtain access token for client {self._client_id}") + raise DiodeConfigError( + f"Failed to obtain access token for client {self._client_id}" + ) _LOGGER.debug(f"Access token obtained for client {self._client_id}") return access_token @@ -307,7 +345,7 @@ def authenticate(self) -> str: def _get_auth_url(self) -> str: """Construct the authentication URL, handling trailing slashes in the path.""" # Ensure the path does not have trailing slashes - path = self._path.rstrip('/') if self._path else '' + path = self._path.rstrip("/") if self._path else "" return f"{path}/auth/token" @@ -332,10 +370,10 @@ class _ClientCallDetails( """ - pass - -class DiodeMethodClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor): +class DiodeMethodClientInterceptor( + grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor +): """ Diode Method Client Interceptor class. @@ -374,6 +412,8 @@ def intercept_unary_unary(self, continuation, client_call_details, request): """Intercept unary unary.""" return self._intercept_call(continuation, client_call_details, request) - def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): """Intercept stream unary.""" return self._intercept_call(continuation, client_call_details, request_iterator) diff --git a/tests/test_client.py b/tests/test_client.py index e8383eb..7af797d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -560,6 +560,7 @@ def test_diode_authentication_success(mock_diode_authentication): tls_verify=False, client_id="test_client_id", client_secret="test_client_secret", + scope="diode:ingest", ) with mock.patch("http.client.HTTPConnection") as mock_http_conn: mock_conn_instance = mock_http_conn.return_value @@ -578,6 +579,7 @@ def test_diode_authentication_failure(mock_diode_authentication): tls_verify=False, client_id="test_client_id", client_secret="test_client_secret", + scope="diode:ingest", ) with mock.patch("http.client.HTTPConnection") as mock_http_conn: mock_conn_instance = mock_http_conn.return_value @@ -605,6 +607,7 @@ def test_diode_authentication_url_with_path(mock_diode_authentication, path): tls_verify=False, client_id="test_client_id", client_secret="test_client_secret", + scope="diode:ingest", ) with mock.patch("http.client.HTTPConnection") as mock_http_conn: mock_conn_instance = mock_http_conn.return_value @@ -622,6 +625,7 @@ def test_diode_authentication_request_exception(mock_diode_authentication): tls_verify=False, client_id="test_client_id", client_secret="test_client_secret", + scope="diode:ingest", ) with mock.patch("http.client.HTTPConnection") as mock_http_conn: mock_conn_instance = mock_http_conn.return_value