diff --git a/.gitignore b/.gitignore index 257791d..65f0a15 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,5 @@ __pycache__/ build/ dist/ .eggs/ -*.egg-info \ No newline at end of file +*.egg-info +poetry.lock diff --git a/README.md b/README.md index 53cac60..c181a55 100644 --- a/README.md +++ b/README.md @@ -20,9 +20,10 @@ pip install netboxlabs-diode-sdk ### Environment variables -* `DIODE_API_KEY` - API key for the Diode service * `DIODE_SDK_LOG_LEVEL` - Log level for the SDK (default: `INFO`) * `DIODE_SENTRY_DSN` - Optional Sentry DSN for error reporting +* `DIODE_CLIENT_ID` - Client ID for OAuth2 authentication +* `DIODE_CLIENT_SECRET` - Client Secret for OAuth2 authentication ### Example @@ -94,8 +95,7 @@ if __name__ == "__main__": ## Development notes -Code in `netboxlabs/diode/sdk/diode/*` is generated from Protocol Buffers definitions (will be published and referred -here soon). +Code in `netboxlabs/diode/sdk/diode/*` is generated from Protocol Buffers definitions (will be published and referenced here soon). #### Linting @@ -107,7 +107,7 @@ black netboxlabs/ #### Testing ```shell -pytest tests/ +PYTHONPATH=$(pwd) pytest ``` ## License diff --git a/netboxlabs/diode/sdk/client.py b/netboxlabs/diode/sdk/client.py index 44049c8..429cab5 100644 --- a/netboxlabs/diode/sdk/client.py +++ b/netboxlabs/diode/sdk/client.py @@ -1,13 +1,17 @@ #!/usr/bin/env python # Copyright 2024 NetBox Labs Inc """NetBox Labs, Diode - SDK - Client.""" + import collections +import http.client +import json import logging import os import platform +import ssl import uuid from collections.abc import Iterable -from urllib.parse import urlparse +from urllib.parse import urlencode, urlparse import certifi import grpc @@ -18,9 +22,11 @@ from netboxlabs.diode.sdk.ingester import Entity from netboxlabs.diode.sdk.version import version_semver -_DIODE_API_KEY_ENVVAR_NAME = "DIODE_API_KEY" +_MAX_RETRIES_ENVVAR_NAME = "DIODE_MAX_AUTH_RETRIES" _DIODE_SDK_LOG_LEVEL_ENVVAR_NAME = "DIODE_SDK_LOG_LEVEL" _DIODE_SENTRY_DSN_ENVVAR_NAME = "DIODE_SENTRY_DSN" +_CLIENT_ID_ENVVAR_NAME = "DIODE_CLIENT_ID" +_CLIENT_SECRET_ENVVAR_NAME = "DIODE_CLIENT_SECRET" _DEFAULT_STREAM = "latest" _LOGGER = logging.getLogger(__name__) @@ -31,17 +37,6 @@ def _load_certs() -> bytes: return f.read() -def _get_api_key(api_key: str | None = None) -> str: - """Get API Key either from provided value or environment variable.""" - if api_key is None: - api_key = os.getenv(_DIODE_API_KEY_ENVVAR_NAME) - if api_key is None: - raise DiodeConfigError( - f"api_key param or {_DIODE_API_KEY_ENVVAR_NAME} environment variable required" - ) - return api_key - - def parse_target(target: str) -> tuple[str, str, bool]: """Parse the target into authority, path and tls_verify.""" parsed_target = urlparse(target) @@ -66,6 +61,21 @@ def _get_sentry_dsn(sentry_dsn: str | None = None) -> str | None: return sentry_dsn +def _get_required_config_value(env_var_name: str, value: str | None = None) -> str: + """Get required config value either from provided value or environment variable.""" + if value is None: + value = os.getenv(env_var_name) + if value is None: + raise DiodeConfigError(f"parameter or {env_var_name} environment variable required") + return value + +def _get_optional_config_value(env_var_name: str, value: str | None = None) -> str | None: + """Get optional config value either from provided value or environment variable.""" + if value is None: + value = os.getenv(env_var_name) + return value + + class DiodeClient: """Diode Client.""" @@ -81,30 +91,40 @@ def __init__( target: str, app_name: str, app_version: str, - api_key: str | None = None, + client_id: str | None = None, + client_secret: str | None = None, sentry_dsn: str = None, sentry_traces_sample_rate: float = 1.0, sentry_profiles_sample_rate: float = 1.0, + max_auth_retries: int = 3, ): """Initiate a new client.""" log_level = os.getenv(_DIODE_SDK_LOG_LEVEL_ENVVAR_NAME, "INFO").upper() logging.basicConfig(level=log_level) + self._max_auth_retries = _get_optional_config_value(_MAX_RETRIES_ENVVAR_NAME, max_auth_retries) self._target, self._path, self._tls_verify = parse_target(target) self._app_name = app_name self._app_version = app_version self._platform = platform.platform() self._python_version = platform.python_version() - api_key = _get_api_key(api_key) + # Read client credentials from environment variables + self._client_id = _get_required_config_value(_CLIENT_ID_ENVVAR_NAME, client_id) + self._client_secret = _get_required_config_value(_CLIENT_SECRET_ENVVAR_NAME, client_secret) + self._metadata = ( - ("diode-api-key", api_key), ("platform", self._platform), ("python-version", self._python_version), ) + self._authenticate() + channel_opts = ( - ("grpc.primary_user_agent", f"{self._name}/{self._version} {self._app_name}/{self._app_version}"), + ( + "grpc.primary_user_agent", + f"{self._name}/{self._version} {self._app_name}/{self._app_version}", + ), ) if self._tls_verify: @@ -129,9 +149,7 @@ def __init__( _LOGGER.debug(f"Setting up gRPC interceptor for path: {self._path}") rpc_method_interceptor = DiodeMethodClientInterceptor(subpath=self._path) - intercept_channel = grpc.intercept_channel( - self._channel, rpc_method_interceptor - ) + intercept_channel = grpc.intercept_channel(self._channel, rpc_method_interceptor) channel = intercept_channel self._stub = ingester_pb2_grpc.IngesterServiceStub(channel) @@ -140,9 +158,7 @@ def __init__( if self._sentry_dsn is not None: _LOGGER.debug("Setting up Sentry") - self._setup_sentry( - self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate - ) + self._setup_sentry(self._sentry_dsn, sentry_traces_sample_rate, sentry_profiles_sample_rate) @property def name(self) -> str: @@ -202,24 +218,28 @@ def ingest( stream: str | None = _DEFAULT_STREAM, ) -> ingester_pb2.IngestResponse: """Ingest entities.""" - try: - request = ingester_pb2.IngestRequest( - stream=stream, - id=str(uuid.uuid4()), - entities=entities, - sdk_name=self.name, - sdk_version=self.version, - producer_app_name=self.app_name, - producer_app_version=self.app_version, - ) - - return self._stub.Ingest(request, metadata=self._metadata) - except grpc.RpcError as err: - raise DiodeClientError(err) from err - - def _setup_sentry( - self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float - ): + for attempt in range(self._max_auth_retries): + try: + request = ingester_pb2.IngestRequest( + stream=stream, + id=str(uuid.uuid4()), + entities=entities, + sdk_name=self.name, + sdk_version=self.version, + producer_app_name=self.app_name, + producer_app_version=self.app_version, + ) + return self._stub.Ingest(request, metadata=self._metadata) + except grpc.RpcError as err: + if err.code() == grpc.StatusCode.UNAUTHENTICATED: + if attempt < self._max_auth_retries - 1: + _LOGGER.info(f"Retrying ingestion due to UNAUTHENTICATED error, attempt {attempt + 1}") + self._authenticate() + continue + raise DiodeClientError(err) from err + return RuntimeError("Max retries exceeded") + + def _setup_sentry(self, dsn: str, traces_sample_rate: float, profiles_sample_rate: float): sentry_sdk.init( dsn=dsn, release=self.version, @@ -234,6 +254,59 @@ def _setup_sentry( sentry_sdk.set_tag("platform", self._platform) sentry_sdk.set_tag("python_version", self._python_version) + def _authenticate(self): + authentication_client = _DiodeAuthentication(self._target, self._path, self._tls_verify, self._client_id, self._client_secret) + access_token = authentication_client.authenticate() + self._metadata = list(filter(lambda x: x[0] != "authorization", self._metadata)) + \ + [("authorization", f"Bearer {access_token}")] + + +class _DiodeAuthentication: + def __init__(self, target: str, path: str, tls_verify: bool, client_id: str, client_secret: str): + self._target = target + self._tls_verify = tls_verify + self._client_id = client_id + self._client_secret = client_secret + self._path = path + + def authenticate(self) -> str: + """Request an OAuth2 token using client credentials and return it.""" + if self._tls_verify: + conn = http.client.HTTPSConnection( + self._target, + context=None if self._tls_verify else ssl._create_unverified_context(), + ) + else: + conn = http.client.HTTPConnection( + self._target, + ) + headers = {"Content-type": "application/x-www-form-urlencoded"} + data = urlencode( + { + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_secret": self._client_secret, + } + ) + url = self._get_auth_url() + conn.request("POST", url, data, headers) + response = conn.getresponse() + if response.status != 200: + raise DiodeConfigError(f"Failed to obtain access token: {response.reason}") + token_info = json.loads(response.read().decode()) + access_token = token_info.get("access_token") + if not access_token: + raise DiodeConfigError(f"Failed to obtain access token for client {self._client_id}") + + _LOGGER.debug(f"Access token obtained for client {self._client_id}") + return access_token + + def _get_auth_url(self) -> str: + """Construct the authentication URL, handling trailing slashes in the path.""" + # Ensure the path does not have trailing slashes + path = self._path.rstrip('/') if self._path else '' + return f"{path}/auth/token" + class _ClientCallDetails( collections.namedtuple( @@ -259,9 +332,7 @@ class _ClientCallDetails( pass -class DiodeMethodClientInterceptor( - grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor -): +class DiodeMethodClientInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor): """ Diode Method Client Interceptor class. @@ -300,8 +371,6 @@ def intercept_unary_unary(self, continuation, client_call_details, request): """Intercept unary unary.""" return self._intercept_call(continuation, client_call_details, request) - def intercept_stream_unary( - self, continuation, client_call_details, request_iterator - ): + def intercept_stream_unary(self, continuation, client_call_details, request_iterator): """Intercept stream unary.""" return self._intercept_call(continuation, client_call_details, request_iterator) diff --git a/netboxlabs/diode/sdk/ingester.py b/netboxlabs/diode/sdk/ingester.py index c5c84ac..3bf7f0b 100644 --- a/netboxlabs/diode/sdk/ingester.py +++ b/netboxlabs/diode/sdk/ingester.py @@ -11,7 +11,9 @@ import datetime import re from typing import Any + from google.protobuf import timestamp_pb2 as _timestamp_pb2 + import netboxlabs.diode.sdk.diode.v1.ingester_pb2 as pb PRIMARY_VALUE_MAP = { diff --git a/pyproject.toml b/pyproject.toml index 3388bec..31b892b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,10 +32,11 @@ dependencies = [ [project.optional-dependencies] # Optional dev = ["black", "check-manifest", "ruff"] -test = ["coverage", "pytest", "pytest-cov"] +test = ["coverage", "pytest", "pytest-cov==6.0.0"] [tool.coverage.run] omit = [ + "*/netboxlabs/diode/sdk/ingester.py", "*/netboxlabs/diode/sdk/diode/*", "*/netboxlabs/diode/sdk/validate/*", "*/tests/*", diff --git a/tests/test_client.py b/tests/test_client.py index 02a319c..4cc510e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,19 +1,21 @@ #!/usr/bin/env python # Copyright 2024 NetBox Labs Inc """NetBox Labs - Tests.""" + +import json import os from unittest import mock +from unittest.mock import MagicMock, patch import grpc import pytest from netboxlabs.diode.sdk.client import ( - _DIODE_API_KEY_ENVVAR_NAME, _DIODE_SENTRY_DSN_ENVVAR_NAME, DiodeClient, DiodeMethodClientInterceptor, _ClientCallDetails, - _get_api_key, + _DiodeAuthentication, _get_sentry_dsn, _load_certs, parse_target, @@ -22,13 +24,14 @@ from netboxlabs.diode.sdk.version import version_semver -def test_init(): +def test_init(mock_diode_authentication): """Check we can initiate a client configuration.""" config = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) assert config.target == "localhost:8081" assert config.name == "diode-sdk-python" @@ -39,25 +42,36 @@ def test_init(): assert config.path == "" -def test_config_error(): +@pytest.mark.parametrize( + "client_id,client_secret,env_var_name", + [ + (None, "123", "DIODE_CLIENT_ID"), + ("123", None, "DIODE_CLIENT_SECRET"), + (None, None, "DIODE_CLIENT_ID"), + ], +) +def test_config_errors(client_id, client_secret, env_var_name): """Check we can raise a config error.""" with pytest.raises(DiodeConfigError) as err: DiodeClient( - target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1" + target="grpc://localhost:8081", + app_name="my-producer", + app_version="0.0.1", + client_id=client_id, + client_secret=client_secret, ) - assert ( - str(err.value) == "api_key param or DIODE_API_KEY environment variable required" - ) + assert str(err.value) == f"parameter or {env_var_name} environment variable required" -def test_client_error(): +def test_client_error(mock_diode_authentication): """Check we can raise a client error.""" with pytest.raises(DiodeClientError) as err: client = DiodeClient( target="grpc://invalid:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) client.ingest(entities=[]) assert err.value.status_code == grpc.StatusCode.UNAVAILABLE @@ -72,10 +86,7 @@ def test_diode_client_error_repr_returns_correct_string(): error = DiodeClientError(grpc_error) error._status_code = grpc.StatusCode.UNAVAILABLE error._details = "Some details about the error" - assert ( - repr(error) - == "" - ) + assert repr(error) == "" def test_load_certs_returns_bytes(): @@ -83,26 +94,6 @@ def test_load_certs_returns_bytes(): assert isinstance(_load_certs(), bytes) -def test_get_api_key_returns_env_var_when_no_input(): - """Check that _get_api_key returns the env var when no input is provided.""" - os.environ[_DIODE_API_KEY_ENVVAR_NAME] = "env_var_key" - assert _get_api_key() == "env_var_key" - - -def test_get_api_key_returns_input_when_provided(): - """Check that _get_api_key returns the input when provided.""" - os.environ[_DIODE_API_KEY_ENVVAR_NAME] = "env_var_key" - assert _get_api_key("input_key") == "input_key" - - -def test_get_api_key_raises_error_when_no_input_or_env_var(): - """Check that _get_api_key raises an error when no input or env var is provided.""" - if _DIODE_API_KEY_ENVVAR_NAME in os.environ: - del os.environ[_DIODE_API_KEY_ENVVAR_NAME] - with pytest.raises(DiodeConfigError): - _get_api_key() - - def test_parse_target_handles_http_prefix(): """Check that parse_target raises an error when the target contains http://.""" with pytest.raises(ValueError): @@ -166,13 +157,14 @@ def test_get_sentry_dsn_returns_none_when_no_input_or_env_var(): assert _get_sentry_dsn() is None -def test_setup_sentry_initializes_with_correct_parameters(): +def test_setup_sentry_initializes_with_correct_parameters(mock_diode_authentication): """Check that DiodeClient._setup_sentry() initializes with the correct parameters.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with mock.patch("sentry_sdk.init") as mock_init: client._setup_sentry("https://user@password.mock.dsn/123456", 0.5, 0.5) @@ -184,13 +176,14 @@ def test_setup_sentry_initializes_with_correct_parameters(): ) -def test_client_sets_up_secure_channel_when_grpcs_scheme_is_found_in_target(): +def test_client_sets_up_secure_channel_when_grpcs_scheme_is_found_in_target(mock_diode_authentication): """Check that DiodeClient.__init__() sets up the gRPC secure channel when grpcs:// scheme is found in the target.""" client = DiodeClient( target="grpcs://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with ( mock.patch("grpc.secure_channel") as mock_secure_channel, @@ -200,20 +193,22 @@ def test_client_sets_up_secure_channel_when_grpcs_scheme_is_found_in_target(): target="grpcs://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_debug.assert_called_once_with("Setting up gRPC secure channel") mock_secure_channel.assert_called_once() -def test_client_sets_up_insecure_channel_when_grpc_scheme_is_found_in_target(): +def test_client_sets_up_insecure_channel_when_grpc_scheme_is_found_in_target(mock_diode_authentication): """Check that DiodeClient.__init__() sets up the gRPC insecure channel when grpc:// scheme is found in the target.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with ( mock.patch("grpc.insecure_channel") as mock_insecure_channel, @@ -223,7 +218,8 @@ def test_client_sets_up_insecure_channel_when_grpc_scheme_is_found_in_target(): target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_debug.assert_called_with( @@ -232,14 +228,15 @@ def test_client_sets_up_insecure_channel_when_grpc_scheme_is_found_in_target(): mock_insecure_channel.assert_called_once() -def test_insecure_channel_options_with_primary_user_agent(): +def test_insecure_channel_options_with_primary_user_agent(mock_diode_authentication): """Check that DiodeClient.__init__() sets the gRPC primary_user_agent option for insecure channel.""" with mock.patch("grpc.insecure_channel") as mock_insecure_channel: client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_insecure_channel.assert_called_once() @@ -252,14 +249,15 @@ def test_insecure_channel_options_with_primary_user_agent(): ) -def test_secure_channel_options_with_primary_user_agent(): +def test_secure_channel_options_with_primary_user_agent(mock_diode_authentication): """Check that DiodeClient.__init__() sets the gRPC primary_user_agent option for secure channel.""" with mock.patch("grpc.secure_channel") as mock_secure_channel: client = DiodeClient( target="grpcs://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_secure_channel.assert_called_once() @@ -272,13 +270,14 @@ def test_secure_channel_options_with_primary_user_agent(): ) -def test_client_interceptor_setup_with_path(): +def test_client_interceptor_setup_with_path(mock_diode_authentication): """Check that DiodeClient.__init__() sets up the gRPC interceptor when a path is provided.""" client = DiodeClient( target="grpc://localhost:8081/my-path", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with ( mock.patch("grpc.intercept_channel") as mock_intercept_channel, @@ -288,7 +287,8 @@ def test_client_interceptor_setup_with_path(): target="grpc://localhost:8081/my-path", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_debug.assert_called_with( @@ -297,13 +297,14 @@ def test_client_interceptor_setup_with_path(): mock_intercept_channel.assert_called_once() -def test_client_interceptor_not_setup_without_path(): +def test_client_interceptor_not_setup_without_path(mock_diode_authentication): """Check that DiodeClient.__init__() does not set up the gRPC interceptor when no path is provided.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with ( mock.patch("grpc.intercept_channel") as mock_intercept_channel, @@ -313,7 +314,8 @@ def test_client_interceptor_not_setup_without_path(): target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_debug.assert_called_with( @@ -322,13 +324,14 @@ def test_client_interceptor_not_setup_without_path(): mock_intercept_channel.assert_not_called() -def test_client_setup_sentry_called_when_sentry_dsn_exists(): +def test_client_setup_sentry_called_when_sentry_dsn_exists(mock_diode_authentication): """Check that DiodeClient._setup_sentry() is called when sentry_dsn exists.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", sentry_dsn="https://user@password.mock.dsn/123456", ) with mock.patch.object(client, "_setup_sentry") as mock_setup_sentry: @@ -336,39 +339,41 @@ def test_client_setup_sentry_called_when_sentry_dsn_exists(): target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", sentry_dsn="https://user@password.mock.dsn/123456", ) - mock_setup_sentry.assert_called_once_with( - "https://user@password.mock.dsn/123456", 1.0, 1.0 - ) + mock_setup_sentry.assert_called_once_with("https://user@password.mock.dsn/123456", 1.0, 1.0) -def test_client_setup_sentry_not_called_when_sentry_dsn_not_exists(): +def test_client_setup_sentry_not_called_when_sentry_dsn_not_exists(mock_diode_authentication): """Check that DiodeClient._setup_sentry() is not called when sentry_dsn does not exist.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with mock.patch.object(client, "_setup_sentry") as mock_setup_sentry: client.__init__( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) mock_setup_sentry.assert_not_called() -def test_client_properties_return_expected_values(): +def test_client_properties_return_expected_values(mock_diode_authentication): """Check that DiodeClient properties return the expected values.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) assert client.name == "diode-sdk-python" assert client.version == version_semver() @@ -380,50 +385,54 @@ def test_client_properties_return_expected_values(): assert isinstance(client.channel, grpc.Channel) -def test_client_enter_returns_self(): +def test_client_enter_returns_self(mock_diode_authentication): """Check that DiodeClient.__enter__() returns self.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) assert client.__enter__() is client -def test_client_exit_closes_channel(): +def test_client_exit_closes_channel(mock_diode_authentication): """Check that DiodeClient.__exit__() closes the channel.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with mock.patch.object(client._channel, "close") as mock_close: client.__exit__(None, None, None) mock_close.assert_called_once() -def test_client_close_closes_channel(): +def test_client_close_closes_channel(mock_diode_authentication): """Check that DiodeClient.close() closes the channel.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with mock.patch.object(client._channel, "close") as mock_close: client.close() mock_close.assert_called_once() -def test_setup_sentry_sets_correct_tags(): +def test_setup_sentry_sets_correct_tags(mock_diode_authentication): """Check that DiodeClient._setup_sentry() sets the correct tags.""" client = DiodeClient( target="grpc://localhost:8081", app_name="my-producer", app_version="0.0.1", - api_key="abcde", + client_id="abcde", + client_secret="123456", ) with mock.patch("sentry_sdk.set_tag") as mock_set_tag: client._setup_sentry("https://user@password.mock.dsn/123456", 0.5, 0.5) @@ -458,10 +467,7 @@ def continuation(x, _): None, ) request = None - assert ( - interceptor.intercept_unary_unary(continuation, client_call_details, request) - == "/my/path/diode.v1.IngesterService/Ingest" - ) + assert interceptor.intercept_unary_unary(continuation, client_call_details, request) == "/my/path/diode.v1.IngesterService/Ingest" def test_interceptor_intercepts_stream_unary_calls(): @@ -481,8 +487,128 @@ def continuation(x, _): ) request_iterator = None assert ( - interceptor.intercept_stream_unary( - continuation, client_call_details, request_iterator - ) + interceptor.intercept_stream_unary(continuation, client_call_details, request_iterator) == "/my/path/diode.v1.IngesterService/Ingest" ) + + +@pytest.fixture +def mock_diode_authentication(): + """ + Fixture to mock the Diode authentication process. + + This mock replaces the _DiodeAuthentication class with a mock object + that returns a mocked token for authentication. + """ + with patch("netboxlabs.diode.sdk.client._DiodeAuthentication") as MockAuth: + mock_instance = MockAuth.return_value + mock_instance.authenticate.return_value = "mocked_token" + yield MockAuth + + +def test_diode_client_with_mocked_authentication(mock_diode_authentication): + """ + Test the DiodeClient initialization with mocked authentication. + + This test verifies that the client is initialized correctly with the mocked + authentication token and that the metadata includes the expected platform + and authorization headers. + """ + client = DiodeClient( + target="grpc://localhost:8080/diode", + app_name="my-test-app", + app_version="0.0.1", + client_id="test_client_id", + client_secret="test_client_secret", + ) + assert client._metadata[0] == ("platform", client._platform) + assert client._metadata[-1] == ("authorization", "Bearer mocked_token") + + +def test_ingest_retries_on_unauthenticated_error(mock_diode_authentication): + """Test that the ingest method retries on UNAUTHENTICATED error.""" + # Create a mock stub that raises UNAUTHENTICATED error + mock_stub = MagicMock() + mock_stub.Ingest.side_effect = grpc.RpcError() + mock_stub.Ingest.side_effect.code = lambda: grpc.StatusCode.UNAUTHENTICATED + mock_stub.Ingest.side_effect.details = lambda: "Something went wrong" + + client = DiodeClient( + target="grpc://localhost:8081", + app_name="my-producer", + app_version="0.0.1", + client_id="abcde", + client_secret="123456", + ) + + # Patch the DiodeClient to use the mock stub + client._stub = mock_stub + + # Attempt to ingest entities and expect a DiodeClientError after retries + with pytest.raises(DiodeClientError): + client.ingest(entities=[]) + + # Verify that the Ingest method was called the expected number of times + assert mock_stub.Ingest.call_count == client._max_auth_retries + + +def test_diode_authentication_success(mock_diode_authentication): + """Test successful authentication in _DiodeAuthentication.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path="/diode", + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + ) + with mock.patch("http.client.HTTPConnection") as mock_http_conn: + mock_conn_instance = mock_http_conn.return_value + mock_conn_instance.getresponse.return_value.status = 200 + mock_conn_instance.getresponse.return_value.read.return_value = json.dumps({"access_token": "mocked_token"}).encode() + + token = auth.authenticate() + assert token == "mocked_token" + + +def test_diode_authentication_failure(mock_diode_authentication): + """Test authentication failure in _DiodeAuthentication.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path="/diode", + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + ) + with mock.patch("http.client.HTTPConnection") as mock_http_conn: + mock_conn_instance = mock_http_conn.return_value + mock_conn_instance.getresponse.return_value.status = 401 + mock_conn_instance.getresponse.return_value.reason = "Unauthorized" + + with pytest.raises(DiodeConfigError) as excinfo: + auth.authenticate() + assert "Failed to obtain access token" in str(excinfo.value) + +@pytest.mark.parametrize("path", [ + "/diode", + "", + None, + "/diode/", + "diode", + "diode/", + ]) +def test_diode_authentication_url_with_path(mock_diode_authentication, path): + """Test that the authentication URL is correctly formatted with a path.""" + auth = _DiodeAuthentication( + target="localhost:8081", + path=path, + tls_verify=False, + client_id="test_client_id", + client_secret="test_client_secret", + ) + with mock.patch("http.client.HTTPConnection") as mock_http_conn: + mock_conn_instance = mock_http_conn.return_value + mock_conn_instance.getresponse.return_value.status = 200 + mock_conn_instance.getresponse.return_value.read.return_value = json.dumps({"access_token": "mocked_token"}).encode() + auth.authenticate() + mock_conn_instance.request.assert_called_once_with("POST", f"{(path or '').rstrip('/')}/auth/token", mock.ANY, mock.ANY) +