diff --git a/src/elasticotel/distro/__init__.py b/src/elasticotel/distro/__init__.py index aed9299..329606c 100644 --- a/src/elasticotel/distro/__init__.py +++ b/src/elasticotel/distro/__init__.py @@ -55,7 +55,7 @@ from elasticotel.distro import version from elasticotel.distro.environment_variables import ELASTIC_OTEL_OPAMP_ENDPOINT, ELASTIC_OTEL_SYSTEM_METRICS_ENABLED from elasticotel.distro.resource_detectors import get_cloud_resource_detectors -from elasticotel.distro.config import opamp_handler +from elasticotel.distro.config import opamp_handler, DEFAULT_SAMPLING_RATE logger = logging.getLogger(__name__) @@ -152,7 +152,7 @@ def _configure(self, **kwargs): # preference to use DELTA temporality as we can handle only this kind of Histograms os.environ.setdefault(OTEL_EXPORTER_OTLP_METRICS_TEMPORALITY_PREFERENCE, "DELTA") os.environ.setdefault(OTEL_TRACES_SAMPLER, "parentbased_traceidratio") - os.environ.setdefault(OTEL_TRACES_SAMPLER_ARG, "1.0") + os.environ.setdefault(OTEL_TRACES_SAMPLER_ARG, str(DEFAULT_SAMPLING_RATE)) base_resource_detectors = [ "process_runtime", diff --git a/src/elasticotel/distro/config.py b/src/elasticotel/distro/config.py index 89fef39..0e68963 100644 --- a/src/elasticotel/distro/config.py +++ b/src/elasticotel/distro/config.py @@ -16,6 +16,9 @@ import logging +from opentelemetry import trace + +from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio from opentelemetry._opamp import messages from opentelemetry._opamp.agent import OpAMPAgent from opentelemetry._opamp.client import OpAMPClient @@ -34,32 +37,79 @@ "off": 1000, } +DEFAULT_SAMPLING_RATE = 1.0 + + +def _handle_logging_level(config) -> str: + error_message = "" + # when config option has default value you don't get it so need to handle the default + config_logging_level = config.get("logging_level") + if config_logging_level is not None: + logging_level = _LOG_LEVELS_MAP.get(config_logging_level) # type: ignore[reportArgumentType] + else: + logging_level = logging.INFO + + if logging_level is None: + logger.error("Logging level not handled: %s", config_logging_level) + error_message = f"Logging level not handled: {config_logging_level}" + else: + # update upstream and distro logging levels + logging.getLogger("opentelemetry").setLevel(logging_level) + logging.getLogger("elasticotel").setLevel(logging_level) + return error_message + + +def _handle_sampling_rate(config) -> str: + config_sampling_rate = config.get("sampling_rate") + sampling_rate = DEFAULT_SAMPLING_RATE + if config_sampling_rate is not None: + try: + sampling_rate = float(config_sampling_rate) + if sampling_rate < 0 or sampling_rate > 1.0: + raise ValueError() + except ValueError: + logger.error("Invalid `sampling_rate` from config `%s`", config_sampling_rate) + return f"Invalid sampling_rate {config_sampling_rate}" + + sampler = getattr(trace.get_tracer_provider(), "sampler", None) + if sampler is None: + logger.debug("Cannot get sampler from tracer provider.") + return "" + + # FIXME: this needs to be updated for the consistent probability samplers + if not isinstance(sampler, ParentBasedTraceIdRatio): + logger.warning("Sampler %s is not supported, not applying sampling_rate.", type(sampler)) + return "" + + # since sampler is parent based we need to update its root sampler + root_sampler = sampler._root # type: ignore[reportAttributeAccessIssue] + if root_sampler.rate != sampling_rate: # type: ignore[reportAttributeAccessIssue] + # we don't have a proper way to update it :) + root_sampler._rate = sampling_rate # type: ignore[reportAttributeAccessIssue] + root_sampler._bound = root_sampler.get_bound_for_rate(root_sampler._rate) # type: ignore[reportAttributeAccessIssue] + logger.debug("Updated sampler rate to %s", sampling_rate) + return "" + def opamp_handler(agent: OpAMPAgent, client: OpAMPClient, message: opamp_pb2.ServerToAgent): # we check config_hash because we need to track last received config and remote_config seems to be always truthy if not message.remote_config or not message.remote_config.config_hash: return - error_message = "" + error_messages = [] for config_filename, config in messages._decode_remote_config(message.remote_config): # we don't have standardized config values so limit to configs coming from our backend if config_filename == "elastic": logger.debug("Config %s: %s", config_filename, config) - # when config option has default value you don't get it so need to handle the default - config_logging_level = config.get("logging_level") - if config_logging_level is not None: - logging_level = _LOG_LEVELS_MAP.get(config_logging_level) # type: ignore[reportArgumentType] - else: - logging_level = logging.INFO - - if logging_level is None: - logger.warning("Logging level not handled: %s", config_logging_level) - error_message = f"Logging level not handled: {config_logging_level}" - else: - # update upstream and distro logging levels - logging.getLogger("opentelemetry").setLevel(logging_level) - logging.getLogger("elasticotel").setLevel(logging_level) + error_message = _handle_logging_level(config) + if error_message: + error_messages.append(error_message) + + error_message = _handle_sampling_rate(config) + if error_message: + error_messages.append(error_message) + error_message = "\n".join(error_messages) status = opamp_pb2.RemoteConfigStatuses_FAILED if error_message else opamp_pb2.RemoteConfigStatuses_APPLIED updated_remote_config = client._update_remote_config_status( remote_config_hash=message.remote_config.config_hash, status=status, error_message=error_message diff --git a/tests/distro/test_distro.py b/tests/distro/test_distro.py index bc3a2b0..5e1c23e 100644 --- a/tests/distro/test_distro.py +++ b/tests/distro/test_distro.py @@ -325,10 +325,134 @@ def test_warns_if_logging_level_does_not_match_our_map(self, get_logger_mock): remote_config = opamp_pb2.AgentRemoteConfig(config=config, config_hash=b"1234") message = opamp_pb2.ServerToAgent(remote_config=remote_config) - with self.assertLogs(config_logger, logging.WARNING): + with self.assertLogs(config_logger, logging.ERROR) as cm: opamp_handler(agent, client, message) + self.assertEqual(cm.output, ["ERROR:elasticotel.distro.config:Logging level not handled: unexpected"]) client._build_remote_config_status_response_message.assert_called_once_with( client._update_remote_config_status() ) agent.send.assert_called_once_with(payload=mock.ANY) + + @mock.patch("opentelemetry.trace.get_tracer_provider") + def test_sets_matching_sampling_rate(self, get_tracer_provider_mock): + sampler = sampling.ParentBasedTraceIdRatio(rate=1.0) + get_tracer_provider_mock.return_value.sampler = sampler + agent = mock.Mock() + client = mock.Mock() + config = opamp_pb2.AgentConfigMap() + config.config_map["elastic"].body = json.dumps({"sampling_rate": "0.5"}).encode() + config.config_map["elastic"].content_type = "application/json" + remote_config = opamp_pb2.AgentRemoteConfig(config=config, config_hash=b"1234") + message = opamp_pb2.ServerToAgent(remote_config=remote_config) + opamp_handler(agent, client, message) + + self.assertEqual(sampler._root.rate, 0.5) + + client._update_remote_config_status.assert_called_once_with( + remote_config_hash=b"1234", status=opamp_pb2.RemoteConfigStatuses_APPLIED, error_message="" + ) + client._build_remote_config_status_response_message.assert_called_once_with( + client._update_remote_config_status() + ) + agent.send.assert_called_once_with(payload=mock.ANY) + + @mock.patch("opentelemetry.trace.get_tracer_provider") + def test_sets_sampling_rate_to_default_info_without_sampling_rate_entry_in_config(self, get_tracer_provider_mock): + sampler = sampling.ParentBasedTraceIdRatio(rate=1.0) + get_tracer_provider_mock.return_value.sampler = sampler + agent = mock.Mock() + client = mock.Mock() + config = opamp_pb2.AgentConfigMap() + config.config_map["elastic"].body = json.dumps({}).encode() + config.config_map["elastic"].content_type = "application/json" + remote_config = opamp_pb2.AgentRemoteConfig(config=config, config_hash=b"1234") + message = opamp_pb2.ServerToAgent(remote_config=remote_config) + opamp_handler(agent, client, message) + + self.assertEqual(sampler._root.rate, 1.0) + + client._update_remote_config_status.assert_called_once_with( + remote_config_hash=b"1234", status=opamp_pb2.RemoteConfigStatuses_APPLIED, error_message="" + ) + client._build_remote_config_status_response_message.assert_called_once_with( + client._update_remote_config_status() + ) + agent.send.assert_called_once_with(payload=mock.ANY) + + @mock.patch("opentelemetry.trace.get_tracer_provider") + def test_warns_if_sampling_rate_value_is_invalid(self, get_tracer_provider_mock): + sampler = sampling.ParentBasedTraceIdRatio(rate=1.0) + get_tracer_provider_mock.return_value.sampler = sampler + agent = mock.Mock() + client = mock.Mock() + config = opamp_pb2.AgentConfigMap() + config.config_map["elastic"].body = json.dumps({"sampling_rate": "unexpected"}).encode() + config.config_map["elastic"].content_type = "application/json" + remote_config = opamp_pb2.AgentRemoteConfig(config=config, config_hash=b"1234") + message = opamp_pb2.ServerToAgent(remote_config=remote_config) + + with self.assertLogs(config_logger, logging.ERROR) as cm: + opamp_handler(agent, client, message) + self.assertEqual( + cm.output, ["ERROR:elasticotel.distro.config:Invalid `sampling_rate` from config `unexpected`"] + ) + + client._update_remote_config_status.assert_called_once_with( + remote_config_hash=b"1234", + status=opamp_pb2.RemoteConfigStatuses_FAILED, + error_message="Invalid sampling_rate unexpected", + ) + client._build_remote_config_status_response_message.assert_called_once_with( + client._update_remote_config_status() + ) + agent.send.assert_called_once_with(payload=mock.ANY) + + @mock.patch("opentelemetry.trace.get_tracer_provider") + def test_warns_if_sampler_is_not_what_we_expect(self, get_tracer_provider_mock): + get_tracer_provider_mock.return_value.sampler = 5 + agent = mock.Mock() + client = mock.Mock() + config = opamp_pb2.AgentConfigMap() + config.config_map["elastic"].body = json.dumps({"sampling_rate": "1.0"}).encode() + config.config_map["elastic"].content_type = "application/json" + remote_config = opamp_pb2.AgentRemoteConfig(config=config, config_hash=b"1234") + message = opamp_pb2.ServerToAgent(remote_config=remote_config) + + with self.assertLogs(config_logger, logging.WARNING) as cm: + opamp_handler(agent, client, message) + self.assertEqual( + cm.output, + ["WARNING:elasticotel.distro.config:Sampler is not supported, not applying sampling_rate."], + ) + + client._update_remote_config_status.assert_called_once_with( + remote_config_hash=b"1234", status=opamp_pb2.RemoteConfigStatuses_APPLIED, error_message="" + ) + client._build_remote_config_status_response_message.assert_called_once_with( + client._update_remote_config_status() + ) + agent.send.assert_called_once_with(payload=mock.ANY) + + @mock.patch("opentelemetry.trace.get_tracer_provider") + def test_ignores_tracer_provider_without_a_sampler(self, get_tracer_provider_mock): + get_tracer_provider_mock.return_value.sampler = None + agent = mock.Mock() + client = mock.Mock() + config = opamp_pb2.AgentConfigMap() + config.config_map["elastic"].body = json.dumps({"sampling_rate": "1.0"}).encode() + config.config_map["elastic"].content_type = "application/json" + remote_config = opamp_pb2.AgentRemoteConfig(config=config, config_hash=b"1234") + message = opamp_pb2.ServerToAgent(remote_config=remote_config) + + with self.assertLogs(config_logger, logging.DEBUG) as cm: + opamp_handler(agent, client, message) + self.assertIn("DEBUG:elasticotel.distro.config:Cannot get sampler from tracer provider.", cm.output) + + client._update_remote_config_status.assert_called_once_with( + remote_config_hash=b"1234", status=opamp_pb2.RemoteConfigStatuses_APPLIED, error_message="" + ) + client._build_remote_config_status_response_message.assert_called_once_with( + client._update_remote_config_status() + ) + agent.send.assert_called_once_with(payload=mock.ANY)