diff --git a/.github/codeql-config.yml b/.github/codeql-config.yml index 0913a434136..8fd16afa079 100644 --- a/.github/codeql-config.yml +++ b/.github/codeql-config.yml @@ -1,3 +1,4 @@ name: "CodeQL config" paths-ignore: - - 'tests/appsec/iast_packages/packages/**' + - "tests/appsec/iast_packages/packages/**" + - "tests/appsec/contrib_appsec/**" diff --git a/ddtrace/_trace/span.py b/ddtrace/_trace/span.py index f275a2f2703..7f81e6dea68 100644 --- a/ddtrace/_trace/span.py +++ b/ddtrace/_trace/span.py @@ -128,6 +128,7 @@ class Span(object): "_context", "_parent_context", "_local_root_value", + "_service_entry_span_value", "_parent", "_ignored_exceptions", "_on_finish_callbacks", @@ -221,6 +222,7 @@ def __init__( self._parent: Optional["Span"] = None self._ignored_exceptions: Optional[List[Type[Exception]]] = None self._local_root_value: Optional["Span"] = None # None means this is the root span. + self._service_entry_span_value: Optional["Span"] = None # None means this is the service entry span. self._store: Optional[Dict[str, Any]] = None def _update_tags_from_context(self) -> None: @@ -710,21 +712,28 @@ def context(self) -> Context: @property def _local_root(self) -> "Span": - if self._local_root_value is None: - return self - return self._local_root_value + return self._local_root_value or self @_local_root.setter def _local_root(self, value: "Span") -> None: - if value is not self: - self._local_root_value = value - else: - self._local_root_value = None + self._local_root_value = value if value is not self else None @_local_root.deleter def _local_root(self) -> None: del self._local_root_value + @property + def _service_entry_span(self) -> "Span": + return self._service_entry_span_value or self + + @_service_entry_span.setter + def _service_entry_span(self, span: "Span") -> None: + self._service_entry_span_value = None if span is self else span + + @_service_entry_span.deleter + def _service_entry_span(self) -> None: + del self._service_entry_span_value + def link_span(self, context: Context, attributes: Optional[Dict[str, Any]] = None) -> None: """Defines a causal relationship between two spans""" if not context.trace_id or not context.span_id: @@ -859,7 +868,8 @@ def __repr__(self) -> str: f"metrics={self._metrics}, " f"links={self._links}, " f"events={self._events}, " - f"context={self._context})" + f"context={self._context}, " + f"service_entry_span_name={self._service_entry_span.name})" ) def __str__(self) -> str: diff --git a/ddtrace/_trace/tracer.py b/ddtrace/_trace/tracer.py index 8c6457a0a2a..a7fb81dc519 100644 --- a/ddtrace/_trace/tracer.py +++ b/ddtrace/_trace/tracer.py @@ -530,6 +530,8 @@ def _start_span( if parent: span._parent = parent span._local_root = parent._local_root + if span._parent.service == service: + span._service_entry_span = parent._service_entry_span for k, v in _get_metas_to_propagate(context): # We do not want to propagate AppSec propagation headers diff --git a/ddtrace/appsec/_api_security/api_manager.py b/ddtrace/appsec/_api_security/api_manager.py index 702e088425e..d05f07a1002 100644 --- a/ddtrace/appsec/_api_security/api_manager.py +++ b/ddtrace/appsec/_api_security/api_manager.py @@ -131,7 +131,7 @@ def _should_collect_schema(self, env, priority: int) -> Optional[bool]: def _schema_callback(self, env): if env.span is None or not asm_config._api_security_feature_active: return - root = env.span._local_root or env.span + root = env.entry_span collected = self.BLOCK_COLLECTED if env.blocked else self.COLLECTED if not root or any(meta_name in root._meta for _, meta_name, _ in collected): return diff --git a/ddtrace/appsec/_asm_request_context.py b/ddtrace/appsec/_asm_request_context.py index 5fca9d3448c..2921a2ba99b 100644 --- a/ddtrace/appsec/_asm_request_context.py +++ b/ddtrace/appsec/_asm_request_context.py @@ -64,16 +64,12 @@ class WARNING_TAGS(metaclass=Constant_Class): GLOBAL_CALLBACKS: Dict[str, List[Callable]] = {_CONTEXT_CALL: []} -def report_error_on_span(error: str, message: str) -> None: - span = getattr(_get_asm_context(), "span", None) or core.get_span() - if not span: - root_span = core.get_root_span() - else: - root_span = span._local_root or span - if not root_span: +def report_error_on_entry_span(error: str, message: str) -> None: + entry_span = get_entry_span() + if not entry_span: return - root_span.set_tag_str(APPSEC.ERROR_TYPE, error) - root_span.set_tag_str(APPSEC.ERROR_MESSAGE, message) + entry_span.set_tag_str(APPSEC.ERROR_TYPE, error) + entry_span.set_tag_str(APPSEC.ERROR_MESSAGE, message) class ASM_Environment: @@ -93,6 +89,7 @@ def __init__(self, span: Optional[Span] = None, rc_products: str = ""): logger.warning(WARNING_TAGS.ASM_ENV_NO_SPAN, extra=log_extra, stack_info=True) raise TypeError("ASM_Environment requires a span") self.span: Span = context_span + self.entry_span: Span = self.span._service_entry_span if self.span.name.endswith(".request"): self.framework = self.span.name[:-8] else: @@ -132,6 +129,17 @@ def get_blocked() -> Dict[str, Any]: return env.blocked or {} +def get_entry_span() -> Optional[Span]: + env = _get_asm_context() + if env is None: + span = core.get_span() + if span: + return span._service_entry_span + else: + return core.get_root_span() + return env.entry_span + + def get_framework() -> str: env = _get_asm_context() if env is None: @@ -193,42 +201,41 @@ def update_span_metrics(span: Span, name: str, value: Union[float, int]) -> None def flush_waf_triggers(env: ASM_Environment) -> None: from ddtrace.appsec._metrics import ddwaf_version - # Make sure we find a root span to attach the triggers to - root_span = env.span._local_root or env.span + entry_span = env.entry_span if env.waf_triggers: - report_list = get_triggers(root_span) + report_list = get_triggers(entry_span) if report_list is not None: report_list.extend(env.waf_triggers) else: report_list = env.waf_triggers if asm_config._use_metastruct_for_triggers: - root_span.set_struct_tag(APPSEC.STRUCT, {"triggers": report_list}) + entry_span.set_struct_tag(APPSEC.STRUCT, {"triggers": report_list}) else: - root_span.set_tag(APPSEC.JSON, json.dumps({"triggers": report_list}, separators=(",", ":"))) + entry_span.set_tag(APPSEC.JSON, json.dumps({"triggers": report_list}, separators=(",", ":"))) env.waf_triggers = [] telemetry_results: Telemetry_result = env.telemetry - root_span.set_tag_str(APPSEC.WAF_VERSION, ddwaf_version) + entry_span.set_tag_str(APPSEC.WAF_VERSION, ddwaf_version) if telemetry_results.total_duration: - update_span_metrics(root_span, APPSEC.WAF_DURATION, telemetry_results.duration) + update_span_metrics(entry_span, APPSEC.WAF_DURATION, telemetry_results.duration) telemetry_results.duration = 0.0 - update_span_metrics(root_span, APPSEC.WAF_DURATION_EXT, telemetry_results.total_duration) + update_span_metrics(entry_span, APPSEC.WAF_DURATION_EXT, telemetry_results.total_duration) telemetry_results.total_duration = 0.0 if telemetry_results.timeout: - update_span_metrics(root_span, APPSEC.WAF_TIMEOUTS, telemetry_results.timeout) + update_span_metrics(entry_span, APPSEC.WAF_TIMEOUTS, telemetry_results.timeout) rasp_timeouts = sum(telemetry_results.rasp.timeout.values()) if rasp_timeouts: - update_span_metrics(root_span, APPSEC.RASP_TIMEOUTS, rasp_timeouts) + update_span_metrics(entry_span, APPSEC.RASP_TIMEOUTS, rasp_timeouts) if telemetry_results.rasp.sum_eval: - update_span_metrics(root_span, APPSEC.RASP_DURATION, telemetry_results.rasp.duration) - update_span_metrics(root_span, APPSEC.RASP_DURATION_EXT, telemetry_results.rasp.total_duration) - update_span_metrics(root_span, APPSEC.RASP_RULE_EVAL, telemetry_results.rasp.sum_eval) + update_span_metrics(entry_span, APPSEC.RASP_DURATION, telemetry_results.rasp.duration) + update_span_metrics(entry_span, APPSEC.RASP_DURATION_EXT, telemetry_results.rasp.total_duration) + update_span_metrics(entry_span, APPSEC.RASP_RULE_EVAL, telemetry_results.rasp.sum_eval) if telemetry_results.truncation.string_length: - root_span.set_metric(APPSEC.TRUNCATION_STRING_LENGTH, max(telemetry_results.truncation.string_length)) + entry_span.set_metric(APPSEC.TRUNCATION_STRING_LENGTH, max(telemetry_results.truncation.string_length)) if telemetry_results.truncation.container_size: - root_span.set_metric(APPSEC.TRUNCATION_CONTAINER_SIZE, max(telemetry_results.truncation.container_size)) + entry_span.set_metric(APPSEC.TRUNCATION_CONTAINER_SIZE, max(telemetry_results.truncation.container_size)) if telemetry_results.truncation.container_depth: - root_span.set_metric(APPSEC.TRUNCATION_CONTAINER_DEPTH, max(telemetry_results.truncation.container_depth)) + entry_span.set_metric(APPSEC.TRUNCATION_CONTAINER_DEPTH, max(telemetry_results.truncation.container_depth)) def finalize_asm_env(env: ASM_Environment) -> None: @@ -240,31 +247,31 @@ def finalize_asm_env(env: ASM_Environment) -> None: flush_waf_triggers(env) for function in env.callbacks[_CONTEXT_CALL]: function(env) - root_span = env.span._local_root or env.span - if root_span: + entry_span = env.entry_span + if entry_span: if env.waf_info: info = env.waf_info() try: if info.errors: - root_span.set_tag_str(APPSEC.EVENT_RULE_ERRORS, info.errors) + entry_span.set_tag_str(APPSEC.EVENT_RULE_ERRORS, info.errors) extra = {"product": "appsec", "more_info": info.errors, "stack_limit": 4} logger.debug("asm_context::finalize_asm_env::waf_errors", extra=extra, stack_info=True) - root_span.set_tag_str(APPSEC.EVENT_RULE_VERSION, info.version) - root_span.set_metric(APPSEC.EVENT_RULE_LOADED, info.loaded) - root_span.set_metric(APPSEC.EVENT_RULE_ERROR_COUNT, info.failed) + entry_span.set_tag_str(APPSEC.EVENT_RULE_VERSION, info.version) + entry_span.set_metric(APPSEC.EVENT_RULE_LOADED, info.loaded) + entry_span.set_metric(APPSEC.EVENT_RULE_ERROR_COUNT, info.failed) except Exception: logger.debug("asm_context::finalize_asm_env::exception", extra=log_extra, exc_info=True) if asm_config._rc_client_id is not None: - root_span._local_root.set_tag(APPSEC.RC_CLIENT_ID, asm_config._rc_client_id) + entry_span.set_tag(APPSEC.RC_CLIENT_ID, asm_config._rc_client_id) waf_adresses = env.waf_addresses req_headers = waf_adresses.get(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES, {}) if req_headers: - _set_headers(root_span, req_headers, kind="request") + _set_headers(entry_span, req_headers, kind="request") res_headers = waf_adresses.get(SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES, {}) if res_headers: - _set_headers(root_span, res_headers, kind="response") + _set_headers(entry_span, res_headers, kind="response") if env.rc_products: - root_span.set_tag_str(APPSEC.RC_PRODUCTS, env.rc_products) + entry_span.set_tag_str(APPSEC.RC_PRODUCTS, env.rc_products) core.discard_local_item(_ASM_CONTEXT) @@ -364,7 +371,7 @@ def call_waf_callback(custom_data: Optional[Dict[str, Any]] = None, **kwargs) -> return callback(custom_data, **kwargs) else: logger.warning(WARNING_TAGS.CALL_WAF_CALLBACK_NOT_SET, extra=log_extra, stack_info=True) - report_error_on_span("appsec::instrumentation::diagnostic", WARNING_TAGS.CALL_WAF_CALLBACK_NOT_SET) + report_error_on_entry_span("appsec::instrumentation::diagnostic", WARNING_TAGS.CALL_WAF_CALLBACK_NOT_SET) return None @@ -661,7 +668,7 @@ def _set_headers(span: Span, headers: Any, kind: str, only_asm_enabled: bool = F value = value.decode() if key.lower() in (_COLLECTED_REQUEST_HEADERS_ASM_ENABLED if only_asm_enabled else _COLLECTED_REQUEST_HEADERS): # since the header value can be a list, use `set_tag()` to ensure it is converted to a string - (span._local_root or span).set_tag(_normalize_tag_name(kind, key), value) + span.set_tag(_normalize_tag_name(kind, key), value) def asm_listen(): diff --git a/ddtrace/appsec/_exploit_prevention/stack_traces.py b/ddtrace/appsec/_exploit_prevention/stack_traces.py index 186f1903209..391b8cabcc6 100644 --- a/ddtrace/appsec/_exploit_prevention/stack_traces.py +++ b/ddtrace/appsec/_exploit_prevention/stack_traces.py @@ -6,6 +6,7 @@ from typing import Optional from ddtrace._trace.span import Span +from ddtrace.appsec import _asm_request_context from ddtrace.appsec._constants import STACK_TRACE from ddtrace.internal import core from ddtrace.settings.asm import config as asm_config @@ -33,13 +34,15 @@ def report_stack( # iast stack trace with iast disabled return False - if span is None: + if namespace == STACK_TRACE.IAST and asm_config._iast_use_root_span: span = core.get_root_span() + if span is None: + span = _asm_request_context.get_entry_span() + if span is None or stack_id is None: return False - root_span = span._local_root or span - appsec_traces = root_span.get_struct_tag(STACK_TRACE.TAG) or {} + appsec_traces = span.get_struct_tag(STACK_TRACE.TAG) or {} current_list = appsec_traces.get(namespace, []) total_length = len(current_list) @@ -77,5 +80,5 @@ def report_stack( res["frames"] = frames current_list.append(res) appsec_traces[namespace] = current_list - root_span.set_struct_tag(STACK_TRACE.TAG, appsec_traces) + span.set_struct_tag(STACK_TRACE.TAG, appsec_traces) return True diff --git a/ddtrace/appsec/_iast/__init__.py b/ddtrace/appsec/_iast/__init__.py index 2c91713346e..07664c43e79 100644 --- a/ddtrace/appsec/_iast/__init__.py +++ b/ddtrace/appsec/_iast/__init__.py @@ -27,6 +27,7 @@ def wrapped_function(wrapped, instance, args, kwargs): ) return wrapped(*args, **kwargs) """ + import os import sys import types @@ -107,7 +108,6 @@ def _iast_pytest_activation(): if _iast_propagation_enabled: return os.environ["DD_IAST_ENABLED"] = os.environ.get("DD_IAST_ENABLED") or "1" - os.environ["_DD_IAST_USE_ROOT_SPAN"] = os.environ.get("_DD_IAST_USE_ROOT_SPAN") or "true" os.environ["DD_IAST_REQUEST_SAMPLING"] = os.environ.get("DD_IAST_REQUEST_SAMPLING") or "100.0" os.environ["_DD_APPSEC_DEDUPLICATION_ENABLED"] = os.environ.get("_DD_APPSEC_DEDUPLICATION_ENABLED") or "false" os.environ["DD_IAST_VULNERABILITIES_PER_REQUEST"] = os.environ.get("DD_IAST_VULNERABILITIES_PER_REQUEST") or "1000" diff --git a/ddtrace/appsec/_iast/_iast_request_context.py b/ddtrace/appsec/_iast/_iast_request_context.py index 02c4ad4f217..e8637e44196 100644 --- a/ddtrace/appsec/_iast/_iast_request_context.py +++ b/ddtrace/appsec/_iast/_iast_request_context.py @@ -77,7 +77,7 @@ def _create_and_attach_iast_report_to_span( def _iast_end_request(ctx=None, span=None, *args, **kwargs): try: - move_to_root = base._move_iast_data_to_root_span() + move_to_root = asm_config._iast_use_root_span if move_to_root: req_span = core.get_root_span() else: diff --git a/ddtrace/appsec/_iast/_iast_request_context_base.py b/ddtrace/appsec/_iast/_iast_request_context_base.py index 6a4384df10e..f0055d0554a 100644 --- a/ddtrace/appsec/_iast/_iast_request_context_base.py +++ b/ddtrace/appsec/_iast/_iast_request_context_base.py @@ -1,4 +1,3 @@ -import os from typing import Optional from ddtrace._trace.span import Span @@ -13,7 +12,6 @@ from ddtrace.appsec._iast.sampling.vulnerability_detection import update_global_vulnerability_limit from ddtrace.internal import core from ddtrace.internal.logger import get_logger -from ddtrace.internal.utils.formats import asbool from ddtrace.settings.asm import config as asm_config @@ -80,10 +78,6 @@ def set_iast_request_endpoint(method, route) -> None: log.debug("iast::propagation::context::Trying to set IAST request endpoint but no context is present") -def _move_iast_data_to_root_span(): - return asbool(os.getenv("_DD_IAST_USE_ROOT_SPAN")) - - def _iast_start_request(span=None, *args, **kwargs): try: if asm_config._iast_enabled: diff --git a/ddtrace/appsec/_processor.py b/ddtrace/appsec/_processor.py index 005f9f2ea6d..ff65ea59316 100644 --- a/ddtrace/appsec/_processor.py +++ b/ddtrace/appsec/_processor.py @@ -211,13 +211,12 @@ def on_span_start(self, span: Span) -> None: peer_ip = _asm_request_context.get_ip() headers = _asm_request_context.get_headers() headers_case_sensitive = _asm_request_context.get_headers_case_sensitive() - root_span = span._local_root or span - - root_span.set_metric(APPSEC.ENABLED, 1.0) - root_span.set_tag_str(_RUNTIME_FAMILY, "python") + entry_span = span._service_entry_span + entry_span.set_metric(APPSEC.ENABLED, 1.0) + entry_span.set_tag_str(_RUNTIME_FAMILY, "python") def waf_callable(custom_data=None, **kwargs): - return self._waf_action(root_span, ctx, custom_data, **kwargs) + return self._waf_action(entry_span, ctx, custom_data, **kwargs) _asm_request_context.set_waf_callback(waf_callable) _asm_request_context.add_context_callback(self.metrics._set_waf_request_metrics) @@ -238,7 +237,7 @@ def waf_callable(custom_data=None, **kwargs): def _waf_action( self, - span: Span, + entry_span: Span, ctx: "ddwaf.ddwaf_types.ddwaf_context_capsule", custom_data: Optional[Dict[str, Any]] = None, crop_trace: Optional[str] = None, @@ -304,18 +303,17 @@ def _waf_action( log.debug("appsec::processor::waf::run", exc_info=True) waf_results = Binding_error _asm_request_context.set_waf_info(lambda: self._ddwaf.info) - root_span = span._local_root or span if waf_results.return_code < 0: error_tag = APPSEC.RASP_ERROR if rule_type else APPSEC.WAF_ERROR - previous = root_span.get_tag(error_tag) + previous = entry_span.get_tag(error_tag) if previous is None: - root_span.set_tag_str(error_tag, str(waf_results.return_code)) + entry_span.set_tag_str(error_tag, str(waf_results.return_code)) else: try: int_previous = int(previous) except ValueError: int_previous = -128 - root_span.set_tag_str(error_tag, str(max(int_previous, waf_results.return_code))) + entry_span.set_tag_str(error_tag, str(max(int_previous, waf_results.return_code))) blocked = {} for action, parameters in waf_results.actions.items(): @@ -326,15 +324,17 @@ def _waf_action( blocked[WAF_ACTIONS.TYPE] = "none" elif action == WAF_ACTIONS.STACK_ACTION: stack_trace_id = parameters["stack_id"] - report_stack("exploit detected", span, crop_trace, stack_id=stack_trace_id, namespace=STACK_TRACE.RASP) + report_stack( + "exploit detected", entry_span, crop_trace, stack_id=stack_trace_id, namespace=STACK_TRACE.RASP + ) for rule in waf_results.data: rule[EXPLOIT_PREVENTION.STACK_TRACE_ID] = stack_trace_id # Trace tagging for key, value in waf_results.meta_tags.items(): - root_span.set_tag_str(key, value) + entry_span.set_tag_str(key, value) for key, value in waf_results.metrics.items(): - root_span.set_metric(key, value) + entry_span.set_metric(key, value) if waf_results.data: log.debug("[DDAS-011-00] ASM In-App WAF returned: %s. Timeout %s", waf_results.data, waf_results.timeout) @@ -357,24 +357,24 @@ def _waf_action( if waf_results.data: _asm_request_context.store_waf_results_data(waf_results.data) if blocked: - span.set_tag(APPSEC.BLOCKED, "true") + entry_span.set_tag(APPSEC.BLOCKED, "true") # Partial DDAS-011-00 - span.set_tag_str(APPSEC.EVENT, "true") + entry_span.set_tag_str(APPSEC.EVENT, "true") remote_ip = _asm_request_context.get_waf_address(SPAN_DATA_NAMES.REQUEST_HTTP_IP) if remote_ip: # Note that if the ip collection is disabled by the env var # DD_TRACE_CLIENT_IP_HEADER_DISABLED actor.ip won't be sent - span.set_tag_str("actor.ip", remote_ip) + entry_span.set_tag_str("actor.ip", remote_ip) # Right now, we overwrite any value that could be already there. We need to reconsider when ASM/AppSec's # specs are updated. - if span.get_tag(_ORIGIN_KEY) is None: - span.set_tag_str(_ORIGIN_KEY, APPSEC.ORIGIN_VALUE) + if entry_span.get_tag(_ORIGIN_KEY) is None: + entry_span.set_tag_str(_ORIGIN_KEY, APPSEC.ORIGIN_VALUE) if waf_results.keep and allowed: - _asm_manual_keep(span) + _asm_manual_keep(entry_span) return waf_results diff --git a/ddtrace/appsec/_trace_utils.py b/ddtrace/appsec/_trace_utils.py index bb28bf5aba9..0d8f86d9bc8 100644 --- a/ddtrace/appsec/_trace_utils.py +++ b/ddtrace/appsec/_trace_utils.py @@ -37,7 +37,7 @@ def _asm_manual_keep(span: Span) -> None: span.context._meta[APPSEC.PROPAGATION_HEADER] = "02" -def _handle_metadata(root_span: Span, prefix: str, metadata: dict) -> None: +def _handle_metadata(entry_span: Span, prefix: str, metadata: dict) -> None: MAX_DEPTH = 6 if metadata is None: return @@ -55,7 +55,7 @@ def _handle_metadata(root_span: Span, prefix: str, metadata: dict) -> None: else: if isinstance(data, bool): data = "true" if data else "false" - root_span.set_tag_str(f"{prefix}", str(data)) + entry_span.set_tag_str(f"{prefix}", str(data)) def _track_user_login_common( @@ -69,7 +69,9 @@ def _track_user_login_common( span: Optional[Span] = None, ) -> Optional[Span]: if span is None: - span = core.get_root_span() + span = _asm_request_context.get_entry_span() + if not span and (current_span := core.get_span()): + span = current_span._service_entry_span if span: success_str = "success" if success else "failure" tag_prefix = "%s.%s" % (APPSEC.USER_LOGIN_EVENT_PREFIX, success_str) @@ -237,7 +239,7 @@ def track_user_signup_event( login: Optional[str] = None, login_events_mode: str = LOGIN_EVENTS_MODE.SDK, ) -> None: - span = core.get_root_span() + span = _asm_request_context.get_entry_span() if span: success_str = "true" if success else "false" span.set_tag_str(APPSEC.USER_SIGNUP_EVENT, success_str) @@ -286,7 +288,7 @@ def track_custom_event(tracer: Any, event_name: str, metadata: Dict[str, Any]) - log.warning("Empty metadata given to track_custom_event. Skipping setting tags.") return - span = core.get_root_span() + span = _asm_request_context.get_entry_span() if not span: log.warning( "No root span in the current execution. Skipping track_custom_event tags. " @@ -337,7 +339,7 @@ def block_request() -> None: meaning that if you capture the exception the request blocking could not work. """ if not asm_config._asm_enabled: - log.warning("block_request() is disabled. To use this feature please enable" "Application Security Monitoring") + log.warning("block_request() is disabled. To use this feature please enable, Application Security Monitoring") return _asm_request_context.block_request() @@ -361,16 +363,16 @@ def block_request_if_user_blocked( return if mode == LOGIN_EVENTS_MODE.AUTO: mode = asm_config._user_event_mode - root_span = core.get_root_span() - if root_span: - root_span.set_tag_str(APPSEC.AUTO_LOGIN_EVENTS_COLLECTION_MODE, mode) + entry_span = _asm_request_context.get_entry_span() + if entry_span: + entry_span.set_tag_str(APPSEC.AUTO_LOGIN_EVENTS_COLLECTION_MODE, mode) if userid: if mode == LOGIN_EVENTS_MODE.ANON: userid = _hash_user_id(str(userid)) - root_span.set_tag_str(APPSEC.AUTO_LOGIN_EVENTS_COLLECTION_MODE, mode) + entry_span.set_tag_str(APPSEC.AUTO_LOGIN_EVENTS_COLLECTION_MODE, mode) if mode != LOGIN_EVENTS_MODE.SDK: - root_span.set_tag_str(APPSEC.USER_LOGIN_USERID, str(userid)) - root_span.set_tag_str(user.ID, str(userid)) + entry_span.set_tag_str(APPSEC.USER_LOGIN_USERID, str(userid)) + entry_span.set_tag_str(user.ID, str(userid)) if should_block_user(None, userid, session_id): _asm_request_context.block_request() @@ -460,7 +462,9 @@ def _on_django_process(result_user, session_key, mode, kwargs, pin, info_retriev user_login = user_extra.get("login") res = None if result_user and result_user.is_authenticated: - span = core.get_root_span() + span = _asm_request_context.get_entry_span() + if span is None: + return if mode == LOGIN_EVENTS_MODE.ANON: hash_id = "" if isinstance(user_id, str): @@ -507,7 +511,9 @@ def _on_django_signup_user(django_config, pin, func, instance, args, kwargs, use return user_id, user_extra = get_user_info(info_retriever, django_config) if user: - span = core.get_root_span() + span = _asm_request_context.get_entry_span() + if span is None: + return _asm_manual_keep(span) span.set_tag_str(APPSEC.USER_SIGNUP_EVENT_MODE, str(asm_config._user_event_mode)) span.set_tag_str(APPSEC.USER_SIGNUP_EVENT, "true") diff --git a/ddtrace/appsec/track_user_sdk.py b/ddtrace/appsec/track_user_sdk.py index cf1f55ca904..1724a3667f7 100644 --- a/ddtrace/appsec/track_user_sdk.py +++ b/ddtrace/appsec/track_user_sdk.py @@ -15,7 +15,6 @@ from ddtrace.appsec._asm_request_context import get_blocked as _get_blocked from ddtrace.appsec._constants import WAF_ACTIONS as _WAF_ACTIONS import ddtrace.appsec.trace_utils # noqa: F401 -from ddtrace.internal import core as _core from ddtrace.internal._exceptions import BlockingException @@ -85,7 +84,7 @@ def track_user( This function should be called when a user is authenticated in the application." """ - span = _core.get_root_span() + span = _asm_request_context.get_entry_span() if span is None: return if user_id: diff --git a/ddtrace/settings/asm.py b/ddtrace/settings/asm.py index 4b2628e8413..70b7a3ae759 100644 --- a/ddtrace/settings/asm.py +++ b/ddtrace/settings/asm.py @@ -156,6 +156,7 @@ class ASMConfig(DDConfig): _iast_lazy_taint = DDConfig.var(bool, IAST.LAZY_TAINT, default=False) _iast_deduplication_enabled = DDConfig.var(bool, "DD_IAST_DEDUPLICATION_ENABLED", default=True) _iast_security_controls = DDConfig.var(str, "DD_IAST_SECURITY_CONTROLS_CONFIGURATION", default="") + _iast_use_root_span = DDConfig.var(bool, "_DD_IAST_USE_ROOT_SPAN", default=False) _iast_is_testing = False @@ -205,6 +206,7 @@ class ASMConfig(DDConfig): "_iast_telemetry_report_lvl", "_iast_security_controls", "_iast_is_testing", + "_iast_use_root_span", "_ep_enabled", "_use_metastruct_for_triggers", "_use_metastruct_for_iast", diff --git a/releasenotes/notes/appsec-report-tags-on-entry-span-7b43ae9779d05d17.yaml b/releasenotes/notes/appsec-report-tags-on-entry-span-7b43ae9779d05d17.yaml new file mode 100644 index 00000000000..b08206df4c8 --- /dev/null +++ b/releasenotes/notes/appsec-report-tags-on-entry-span-7b43ae9779d05d17.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + AAP: Fixes an issue where security signals would be incorrectly reported on an inferred proxy + service instead of the current service. diff --git a/tests/appsec/appsec/api_security/test_api_security_manager.py b/tests/appsec/appsec/api_security/test_api_security_manager.py index 908d8f0ddaa..f18560f10e6 100644 --- a/tests/appsec/appsec/api_security/test_api_security_manager.py +++ b/tests/appsec/appsec/api_security/test_api_security_manager.py @@ -36,12 +36,12 @@ def api_manager(self): def mock_environment(self): # Create a mock environment with required attributes env = MagicMock() - root_span = MagicMock(spec=Span) - root_span._meta = {} + entry_span = MagicMock(spec=Span) + entry_span._meta = {} env.span = MagicMock(spec=Span) - env.span._local_root = root_span + env.entry_span = entry_span env.span.context.sampling_priority = None - root_span.context.sampling_priority = None + entry_span.context.sampling_priority = None env.waf_addresses = {} env.blocked = None return env @@ -71,8 +71,8 @@ def test_schema_callback_already_collected(self, api_manager, mock_environment): """Test that _schema_callback exits early when schema data is already collected in the span. Expects that _should_collect_schema is not called. """ - root_span = mock_environment.span._local_root - root_span._meta = {api_manager.COLLECTED[0][1]: "some_value"} + entry_span = mock_environment.entry_span + entry_span._meta = {api_manager.COLLECTED[0][1]: "some_value"} api_manager._schema_callback(mock_environment) api_manager._should_collect_schema.assert_not_called() @@ -82,8 +82,8 @@ def test_schema_callback_sampling_priority_reject(self, api_manager, mock_enviro """Test that _schema_callback doesn't collect schema when sampling priority indicates rejection. Expects that _should_collect_schema is called but call_waf_callback is not called. """ - root_span = mock_environment.span._local_root - root_span.context.sampling_priority = sampling_priority + entry_span = mock_environment.entry_span + entry_span.context.sampling_priority = sampling_priority api_manager._should_collect_schema.return_value = False api_manager._schema_callback(mock_environment) @@ -96,8 +96,8 @@ def test_schema_callback_sampling_priority_keep(self, api_manager, mock_environm """Test that _schema_callback properly processes schemas when sampling priority indicates keep. Expects schema collection to occur and metadata to be added to the root span. """ - root_span = mock_environment.span._local_root - root_span.context.sampling_priority = sampling_priority + entry_span = mock_environment.entry_span + entry_span.context.sampling_priority = sampling_priority mock_waf_result = MagicMock() mock_waf_result.api_security = {"_dd.appsec.s.req.body": {"type": "object"}} @@ -114,8 +114,8 @@ def test_schema_callback_sampling_priority_keep(self, api_manager, mock_environm api_manager._asm_context.call_waf_callback.assert_called_once() api_manager._metrics._report_api_security.assert_called_with(True, 1) - assert len(root_span._meta) == 1 - assert "_dd.appsec.s.req.body" in root_span._meta + assert len(entry_span._meta) == 1 + assert "_dd.appsec.s.req.body" in entry_span._meta @pytest.mark.parametrize("should_collect_return", [True, False, None]) @pytest.mark.parametrize("sampling_priority", [USER_REJECT, AUTO_REJECT, AUTO_KEEP, USER_KEEP]) @@ -138,7 +138,7 @@ def test_schema_callback_apm_tracing_disabled( api_manager._asm_context.call_waf_callback.return_value = mock_waf_result api_manager._should_collect_schema.return_value = should_collect_return - mock_environment.span._local_root.context.sampling_priority = sampling_priority + mock_environment.entry_span.context.sampling_priority = sampling_priority with override_global_config(values=dict(_apm_tracing_enabled=False)): with patch("ddtrace.appsec._api_security.api_manager._asm_manual_keep") as mock_keep: @@ -146,7 +146,7 @@ def test_schema_callback_apm_tracing_disabled( # Verify manual keep was called only if should_collect_schema returns True if should_collect_return: - mock_keep.assert_called_once_with(mock_environment.span._local_root) + mock_keep.assert_called_once_with(mock_environment.entry_span) else: mock_keep.assert_not_called() @@ -195,9 +195,9 @@ def test_schema_callback_with_valid_waf_addresses(self, api_manager, mock_enviro ]: assert call_arg in call_args - root_span = mock_environment.span._local_root + entry_span = mock_environment.entry_span # Verify all schemas are stored in span metadata - assert len(root_span._meta) == 7 + assert len(entry_span._meta) == 7 for meta in [ "_dd.appsec.s.req.body", "_dd.appsec.s.req.headers", @@ -207,7 +207,7 @@ def test_schema_callback_with_valid_waf_addresses(self, api_manager, mock_enviro "_dd.appsec.s.res.headers", "_dd.appsec.s.res.body", ]: - assert meta in root_span._meta + assert meta in entry_span._meta api_manager._metrics._report_api_security.assert_called_with(True, 7) @@ -228,8 +228,8 @@ def test_schema_callback_oversized_schema(self, api_manager, mock_environment): api_manager._schema_callback(mock_environment) mock_log.warning.assert_called_once() - root_span = mock_environment.span._local_root - assert len(root_span._meta) == 0 + entry_span = mock_environment.entry_span + assert len(entry_span._meta) == 0 api_manager._metrics._report_api_security.assert_called_with(True, 0) def test_schema_callback_parse_response_body_disabled(self, api_manager, mock_environment, caplog): @@ -249,5 +249,5 @@ def test_schema_callback_parse_response_body_disabled(self, api_manager, mock_en call_args = api_manager._asm_context.call_waf_callback.call_args[0][0] assert "RESPONSE_BODY" not in call_args - assert len(mock_environment.span._local_root._meta) == 0 + assert len(mock_environment.entry_span._meta) == 0 api_manager._metrics._report_api_security.assert_called_with(True, 0) diff --git a/tests/appsec/appsec/test_appsec_trace_utils.py b/tests/appsec/appsec/test_appsec_trace_utils.py index 5970efeed76..0bb70945e46 100644 --- a/tests/appsec/appsec/test_appsec_trace_utils.py +++ b/tests/appsec/appsec/test_appsec_trace_utils.py @@ -40,7 +40,7 @@ def inject_fixtures(self, tracer, caplog): # noqa: F811 self.tracer = tracer def test_track_user_login_event_success_without_metadata(self): - with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm): + with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm) as span: track_user_login_success_event( self.tracer, "1234", @@ -52,21 +52,21 @@ def test_track_user_login_event_success_without_metadata(self): session_id="test_session_id", ) - root_span = self.tracer.current_root_span() + entry_span = span._service_entry_span failure_prefix = "%s.failure" % APPSEC.USER_LOGIN_EVENT_PREFIX - assert root_span.get_tag("appsec.events.users.login.success.track") == "true" - assert root_span.get_tag("_dd.appsec.events.users.login.success.sdk") == "true" - assert root_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) == LOGIN_EVENTS_MODE.IDENT - assert not root_span.get_tag("%s.track" % failure_prefix) - assert root_span.context.sampling_priority == constants.USER_KEEP + assert entry_span.get_tag("appsec.events.users.login.success.track") == "true" + assert entry_span.get_tag("_dd.appsec.events.users.login.success.sdk") == "true" + assert entry_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) == LOGIN_EVENTS_MODE.IDENT + assert not entry_span.get_tag("%s.track" % failure_prefix) + assert entry_span.context.sampling_priority == constants.USER_KEEP # set_user tags - assert root_span.get_tag(user.ID) == "1234" - assert root_span.get_tag(user.NAME) == "John" - assert root_span.get_tag(user.EMAIL) == "test@test.com" - assert root_span.get_tag(user.SCOPE) == "test_scope" - assert root_span.get_tag(user.ROLE) == "boss" - assert root_span.get_tag(user.SESSION_ID) == "test_session_id" + assert entry_span.get_tag(user.ID) == "1234" + assert entry_span.get_tag(user.NAME) == "John" + assert entry_span.get_tag(user.EMAIL) == "test@test.com" + assert entry_span.get_tag(user.SCOPE) == "test_scope" + assert entry_span.get_tag(user.ROLE) == "boss" + assert entry_span.get_tag(user.SESSION_ID) == "test_session_id" def test_track_user_login_event_success_in_span_without_metadata(self): with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm) as parent_span: @@ -104,7 +104,7 @@ def test_track_user_login_event_success_in_span_without_metadata(self): user_span.finish() def test_track_user_login_event_success_auto_mode_safe(self): - with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm): + with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm) as span: track_user_login_success_event( self.tracer, "1234", @@ -117,14 +117,14 @@ def test_track_user_login_event_success_auto_mode_safe(self): login_events_mode=LOGIN_EVENTS_MODE.ANON, ) - root_span = self.tracer.current_root_span() + entry_span = span._service_entry_span success_prefix = "%s.success" % APPSEC.USER_LOGIN_EVENT_PREFIX_PUBLIC - assert root_span.get_tag("%s.track" % success_prefix) == "true" - assert not root_span.get_tag("_dd.appsec.events.users.login.success.sdk") - assert root_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) == str(LOGIN_EVENTS_MODE.ANON) + assert entry_span.get_tag("%s.track" % success_prefix) == "true" + assert not entry_span.get_tag("_dd.appsec.events.users.login.success.sdk") + assert entry_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) == str(LOGIN_EVENTS_MODE.ANON) def test_track_user_login_event_success_auto_mode_extended(self): - with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm): + with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm) as span: track_user_login_success_event( self.tracer, "1234", @@ -137,30 +137,30 @@ def test_track_user_login_event_success_auto_mode_extended(self): login_events_mode=LOGIN_EVENTS_MODE.IDENT, ) - root_span = self.tracer.current_root_span() + entry_span = span._service_entry_span success_prefix = "%s.success" % APPSEC.USER_LOGIN_EVENT_PREFIX_PUBLIC - assert root_span.get_tag("%s.track" % success_prefix) == "true" - assert not root_span.get_tag("_dd.appsec.events.users.login.success.sdk") - assert root_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) == str(LOGIN_EVENTS_MODE.IDENT) + assert entry_span.get_tag("%s.track" % success_prefix) == "true" + assert not entry_span.get_tag("_dd.appsec.events.users.login.success.sdk") + assert entry_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) == str(LOGIN_EVENTS_MODE.IDENT) def test_track_user_login_event_success_with_metadata(self): with mock_patch.object( ddtrace.internal.telemetry.telemetry_writer, "_namespace", MagicMock() - ) as telemetry_mock, asm_context(tracer=self.tracer, span_name="test_success2", config=config_asm): + ) as telemetry_mock, asm_context(tracer=self.tracer, span_name="test_success2", config=config_asm) as span: track_user_login_success_event(self.tracer, "1234", metadata={"foo": "bar"}) - root_span = self.tracer.current_root_span() - assert root_span.get_tag("appsec.events.users.login.success.track") == "true" - assert root_span.get_tag("_dd.appsec.events.users.login.success.sdk") == "true" - assert root_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) == LOGIN_EVENTS_MODE.IDENT - assert root_span.get_tag("%s.success.foo" % APPSEC.USER_LOGIN_EVENT_PREFIX_PUBLIC) == "bar" - assert root_span.context.sampling_priority == constants.USER_KEEP + entry_span = span._service_entry_span + assert entry_span.get_tag("appsec.events.users.login.success.track") == "true" + assert entry_span.get_tag("_dd.appsec.events.users.login.success.sdk") == "true" + assert entry_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) == LOGIN_EVENTS_MODE.IDENT + assert entry_span.get_tag("%s.success.foo" % APPSEC.USER_LOGIN_EVENT_PREFIX_PUBLIC) == "bar" + assert entry_span.context.sampling_priority == constants.USER_KEEP # set_user tags - assert root_span.get_tag(user.ID) == "1234" - assert not root_span.get_tag(user.NAME) - assert not root_span.get_tag(user.EMAIL) - assert not root_span.get_tag(user.SCOPE) - assert not root_span.get_tag(user.ROLE) - assert not root_span.get_tag(user.SESSION_ID) + assert entry_span.get_tag(user.ID) == "1234" + assert not entry_span.get_tag(user.NAME) + assert not entry_span.get_tag(user.EMAIL) + assert not entry_span.get_tag(user.SCOPE) + assert not entry_span.get_tag(user.ROLE) + assert not entry_span.get_tag(user.SESSION_ID) metrics = get_telemetry_metrics(telemetry_mock) assert ( "count", @@ -171,7 +171,7 @@ def test_track_user_login_event_success_with_metadata(self): ) in metrics def test_track_user_login_event_failure_user_exists(self): - with asm_context(tracer=self.tracer, span_name="test_failure", config=config_asm): + with asm_context(tracer=self.tracer, span_name="test_failure", config=config_asm) as span: track_user_login_failure_event( self.tracer, "1234", @@ -181,46 +181,46 @@ def test_track_user_login_event_failure_user_exists(self): name="John Test", email="john@test.net", ) - root_span = self.tracer.current_root_span() + entry_span = span._service_entry_span success_prefix = "%s.success" % APPSEC.USER_LOGIN_EVENT_PREFIX_PUBLIC failure_prefix = "%s.failure" % APPSEC.USER_LOGIN_EVENT_PREFIX_PUBLIC - assert root_span.get_tag("%s.track" % failure_prefix) == "true" - assert root_span.get_tag("_dd.appsec.events.users.login.failure.sdk") == "true" - assert root_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_FAILURE_MODE) == LOGIN_EVENTS_MODE.IDENT - assert not root_span.get_tag("%s.track" % success_prefix) - assert not root_span.get_tag("_dd.appsec.events.users.login.success.sdk") - assert not root_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) - assert root_span.get_tag("%s.%s" % (failure_prefix, user.ID)) == "1234" - assert root_span.get_tag("%s.%s" % (failure_prefix, user.EXISTS)) == "true" - assert root_span.get_tag("%s.foo" % failure_prefix) == "bar" - assert root_span.get_tag("%s.%s" % (failure_prefix, "login")) == "johntest" - assert root_span.get_tag("%s.%s" % (failure_prefix, "username")) == "John Test" - assert root_span.get_tag("%s.%s" % (failure_prefix, "email")) == "john@test.net" - - assert root_span.context.sampling_priority == constants.USER_KEEP + assert entry_span.get_tag("%s.track" % failure_prefix) == "true" + assert entry_span.get_tag("_dd.appsec.events.users.login.failure.sdk") == "true" + assert entry_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_FAILURE_MODE) == LOGIN_EVENTS_MODE.IDENT + assert not entry_span.get_tag("%s.track" % success_prefix) + assert not entry_span.get_tag("_dd.appsec.events.users.login.success.sdk") + assert not entry_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) + assert entry_span.get_tag("%s.%s" % (failure_prefix, user.ID)) == "1234" + assert entry_span.get_tag("%s.%s" % (failure_prefix, user.EXISTS)) == "true" + assert entry_span.get_tag("%s.foo" % failure_prefix) == "bar" + assert entry_span.get_tag("%s.%s" % (failure_prefix, "login")) == "johntest" + assert entry_span.get_tag("%s.%s" % (failure_prefix, "username")) == "John Test" + assert entry_span.get_tag("%s.%s" % (failure_prefix, "email")) == "john@test.net" + + assert entry_span.context.sampling_priority == constants.USER_KEEP # set_user tags: shouldn't have been called - assert not root_span.get_tag(user.ID) - assert not root_span.get_tag(user.NAME) - assert not root_span.get_tag(user.EMAIL) - assert not root_span.get_tag(user.SCOPE) - assert not root_span.get_tag(user.ROLE) - assert not root_span.get_tag(user.SESSION_ID) + assert not entry_span.get_tag(user.ID) + assert not entry_span.get_tag(user.NAME) + assert not entry_span.get_tag(user.EMAIL) + assert not entry_span.get_tag(user.SCOPE) + assert not entry_span.get_tag(user.ROLE) + assert not entry_span.get_tag(user.SESSION_ID) def test_track_user_login_event_failure_user_doesnt_exists(self): with mock_patch.object( ddtrace.internal.telemetry.telemetry_writer, "_namespace", MagicMock() - ) as telemetry_mock, self.trace("test_failure"): + ) as telemetry_mock, self.trace("test_failure") as span: track_user_login_failure_event( self.tracer, "john", False, metadata={"foo": "bar"}, ) - root_span = self.tracer.current_root_span() + entry_span = span._service_entry_span failure_prefix = "%s.failure" % APPSEC.USER_LOGIN_EVENT_PREFIX_PUBLIC - assert root_span.get_tag("%s.%s" % (failure_prefix, user.EXISTS)) == "false" + assert entry_span.get_tag("%s.%s" % (failure_prefix, user.EXISTS)) == "false" metrics = get_telemetry_metrics(telemetry_mock) assert metrics == [ ("count", "appsec", "sdk.event", 1, (("event_type", "login_failure"), ("sdk_version", "v1"))) @@ -229,24 +229,24 @@ def test_track_user_login_event_failure_user_doesnt_exists(self): def test_track_user_signup_event_exists(self): with mock_patch.object( ddtrace.internal.telemetry.telemetry_writer, "_namespace", MagicMock() - ) as telemetry_mock, self.trace("test_signup_exists"): + ) as telemetry_mock, self.trace("test_signup_exists") as span: track_user_signup_event(self.tracer, "john", True) - root_span = self.tracer.current_root_span() - assert root_span.get_tag(APPSEC.USER_SIGNUP_EVENT) == "true" - assert root_span.get_tag(user.ID) == "john" + entry_span = span._service_entry_span + assert entry_span.get_tag(APPSEC.USER_SIGNUP_EVENT) == "true" + assert entry_span.get_tag(user.ID) == "john" metrics = get_telemetry_metrics(telemetry_mock) assert metrics == [("count", "appsec", "sdk.event", 1, (("event_type", "signup"), ("sdk_version", "v1")))] def test_custom_event(self): with mock_patch.object( ddtrace.internal.telemetry.telemetry_writer, "_namespace", MagicMock() - ) as telemetry_mock, self.trace("test_custom"): + ) as telemetry_mock, self.trace("test_custom") as span: event = "some_event" track_custom_event(self.tracer, event, {"foo": "bar"}) - root_span = self.tracer.current_root_span() + entry_span = span._service_entry_span - assert root_span.get_tag("%s.%s.foo" % (APPSEC.CUSTOM_EVENT_PREFIX, event)) == "bar" - assert root_span.get_tag("%s.%s.track" % (APPSEC.CUSTOM_EVENT_PREFIX, event)) == "true" + assert entry_span.get_tag("%s.%s.foo" % (APPSEC.CUSTOM_EVENT_PREFIX, event)) == "bar" + assert entry_span.get_tag("%s.%s.track" % (APPSEC.CUSTOM_EVENT_PREFIX, event)) == "true" metrics = get_telemetry_metrics(telemetry_mock) assert ("count", "appsec", "sdk.event", 1, (("event_type", "custom"), ("sdk_version", "v1"))) in metrics diff --git a/tests/appsec/appsec/test_processor.py b/tests/appsec/appsec/test_processor.py index 0429f57beda..475186d16e8 100644 --- a/tests/appsec/appsec/test_processor.py +++ b/tests/appsec/appsec/test_processor.py @@ -803,3 +803,18 @@ def test_lambda_unsupported_event(tracer, skip_event): else: # When skip_event is False, the metric should not be set assert span.get_metric(APPSEC.UNSUPPORTED_EVENT_TYPE) is None + + +def test_lambda_inferred_span(tracer): + """ + Ensure that when the service entry span is not the root span, the service entry span is tagged + and not the root span + """ + config = {"_asm_enabled": True, "_asm_processed_span_types": {SpanTypes.SERVERLESS}} + + with asm_context(tracer=tracer, config=config, span_type=SpanTypes.HTTP, span_name="aws-apigateway") as root_span: + with tracer.trace("aws.lambda", service="test_function", span_type=SpanTypes.SERVERLESS) as entry_span: + pass + + assert root_span.get_metric(APPSEC.ENABLED) is None + assert entry_span.get_metric(APPSEC.ENABLED) == 1.0 diff --git a/tests/appsec/contrib_appsec/conftest.py b/tests/appsec/contrib_appsec/conftest.py index 4fe2c8f5e62..9ea76f10aae 100644 --- a/tests/appsec/contrib_appsec/conftest.py +++ b/tests/appsec/contrib_appsec/conftest.py @@ -1,4 +1,5 @@ -import ddtrace.auto # noqa: F401 +import ddtrace.auto +from ddtrace.ext import SpanTypes # noqa: F401 # ensure the tracer is loaded and started first for possible iast patching @@ -41,6 +42,18 @@ def get_root_span(): yield get_root_span +@pytest.fixture +def entry_span(test_spans): + def get_entry_span(): + for span in test_spans.spans: + if span._is_top_level and span.span_type == SpanTypes.WEB: + return _build_tree(test_spans.spans, span) + + return None + + yield get_entry_span + + @pytest.fixture def check_waf_timeout(request): # change timeout to 50 seconds to avoid flaky timeouts @@ -64,6 +77,24 @@ def get(name): yield get +@pytest.fixture +def get_entry_span_tag(entry_span): + def get(name): + return entry_span().get_tag(name) + + yield get + + +@pytest.fixture +def get_metric(root_span): + yield lambda name: root_span().get_metric(name) + + +@pytest.fixture +def get_entry_span_metric(entry_span): + yield lambda name: entry_span().get_metric(name) + + @pytest.fixture def find_resource(test_spans, root_span): # checking both root spans and web spans for the tag @@ -78,11 +109,6 @@ def find(resource_name): yield find -@pytest.fixture -def get_metric(root_span): - yield lambda name: root_span().get_metric(name) - - def no_op(msg: str) -> None: # noqa: ARG001 """Do nothing.""" diff --git a/tests/appsec/contrib_appsec/django_app/app/middlewares.py b/tests/appsec/contrib_appsec/django_app/app/middlewares.py new file mode 100644 index 00000000000..7155eabe71d --- /dev/null +++ b/tests/appsec/contrib_appsec/django_app/app/middlewares.py @@ -0,0 +1,24 @@ +from ddtrace.trace import tracer + + +class ServiceRenamingMiddleware: + """ + If the request carries `X-Rename-Service: true`, rewrite the current + Datadog root span’s service name to “sub-service” and tag it. + """ + + def __init__(self, get_response): + self.get_response = get_response # standard Django hook + + def __call__(self, request): + # ---- before-view logic (runs on the way in) ----------------- + if request.headers.get("X-Rename-Service", "false").lower() == "true": + service_name = "sub-service" + root_span = tracer.current_root_span() + if root_span is not None: + root_span.service = service_name + root_span.set_tag("scope", service_name) + + # ---- call the view / downstream middleware ------------------ + response = self.get_response(request) + return response diff --git a/tests/appsec/contrib_appsec/django_app/settings.py b/tests/appsec/contrib_appsec/django_app/settings.py index 86c5f7df593..c28e5ff126d 100644 --- a/tests/appsec/contrib_appsec/django_app/settings.py +++ b/tests/appsec/contrib_appsec/django_app/settings.py @@ -73,6 +73,7 @@ "tests.contrib.django.middleware.ClsMiddleware", "tests.contrib.django.middleware.fn_middleware", "tests.contrib.django.middleware.EverythingMiddleware", + "tests.appsec.contrib_appsec.django_app.app.middlewares.ServiceRenamingMiddleware", ] INSTALLED_APPS = [ diff --git a/tests/appsec/contrib_appsec/django_app/urls.py b/tests/appsec/contrib_appsec/django_app/urls.py index 2e4de06b7a0..be62b5c23b4 100644 --- a/tests/appsec/contrib_appsec/django_app/urls.py +++ b/tests/appsec/contrib_appsec/django_app/urls.py @@ -91,7 +91,7 @@ def rasp(request, endpoint: str): res.append(f"File: {f.read()}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HttpResponse("<\br>\n".join(res)) elif endpoint == "ssrf": res = ["ssrf endpoint"] @@ -119,7 +119,7 @@ def rasp(request, endpoint: str): res.append(f"Url: {r.text}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HttpResponse("<\\br>\n".join(res)) elif endpoint == "sql_injection": res = ["sql_injection endpoint"] @@ -132,7 +132,7 @@ def rasp(request, endpoint: str): res.append(f"Url: {list(cursor)}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HttpResponse("<\\br>\n".join(res)) elif endpoint == "shell_injection": res = ["shell_injection endpoint"] @@ -141,12 +141,12 @@ def rasp(request, endpoint: str): cmd = query_params[param] try: if param.startswith("cmdsys"): - res.append(f'cmd stdout: {os.system(f"ls {cmd}")}') + res.append(f"cmd stdout: {os.system(f'ls {cmd}')}") else: - res.append(f'cmd stdout: {subprocess.run(f"ls {cmd}", shell=True)}') + res.append(f"cmd stdout: {subprocess.run(f'ls {cmd}', shell=True)}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HttpResponse("<\\br>\n".join(res)) elif endpoint == "command_injection": res = ["command_injection endpoint"] @@ -163,9 +163,9 @@ def rasp(request, endpoint: str): res.append(f"cmd stdout: {subprocess.run(cmd, timeout=0.5)}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HttpResponse("<\\br>\n".join(res)) - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HttpResponse(f"Unknown endpoint: {endpoint}") diff --git a/tests/appsec/contrib_appsec/fastapi_app/app.py b/tests/appsec/contrib_appsec/fastapi_app/app.py index dc714965709..d4982a0ebe6 100644 --- a/tests/appsec/contrib_appsec/fastapi_app/app.py +++ b/tests/appsec/contrib_appsec/fastapi_app/app.py @@ -51,6 +51,17 @@ async def passthrough_middleware(request: Request, call_next): """ return await call_next(request) + @app.middleware("http") + async def rename_service(request: Request, call_next): + if request.headers.get("x-rename-service", "false").lower() == "true": + service_name = "sub-service" + root_span = tracer.current_root_span() + if root_span is not None: + root_span.service = service_name + root_span.set_tag("scope", service_name) + + return await call_next(request) + @app.get("/") @app.post("/") @app.options("/") @@ -149,7 +160,7 @@ async def rasp(endpoint: str, request: Request): res.append(f"File: {f.read()}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HTMLResponse("<\br>\n".join(res)) elif endpoint == "ssrf": res = ["ssrf endpoint"] @@ -177,7 +188,7 @@ async def rasp(endpoint: str, request: Request): res.append(f"Url: {r.text}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HTMLResponse("<\\br>\n".join(res)) elif endpoint == "sql_injection": res = ["sql_injection endpoint"] @@ -190,7 +201,7 @@ async def rasp(endpoint: str, request: Request): res.append(f"Url: {list(cursor)}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HTMLResponse("<\\br>\n".join(res)) elif endpoint == "shell_injection": res = ["shell_injection endpoint"] @@ -199,12 +210,12 @@ async def rasp(endpoint: str, request: Request): cmd = query_params[param] try: if param.startswith("cmdsys"): - res.append(f'cmd stdout: {os.system(f"ls {cmd}")}') + res.append(f"cmd stdout: {os.system(f'ls {cmd}')}") else: - res.append(f'cmd stdout: {subprocess.run(f"ls {cmd}", shell=True, timeout=1)}') + res.append(f"cmd stdout: {subprocess.run(f'ls {cmd}', shell=True, timeout=1)}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HTMLResponse("<\\br>\n".join(res)) elif endpoint == "command_injection": res = ["command_injection endpoint"] @@ -212,7 +223,7 @@ async def rasp(endpoint: str, request: Request): if param.startswith("cmda"): cmd = query_params[param] try: - res.append(f'cmd stdout: {subprocess.run([cmd, "-c", "3", "localhost"], timeout=1)}') + res.append(f"cmd stdout: {subprocess.run([cmd, '-c', '3', 'localhost'], timeout=1)}") except Exception as e: res.append(f"Error: {e}") elif param.startswith("cmds"): @@ -221,9 +232,9 @@ async def rasp(endpoint: str, request: Request): res.append(f"cmd stdout: {subprocess.run(cmd, timeout=1)}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HTMLResponse("<\\br>\n".join(res)) - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return HTMLResponse(f"Unknown endpoint: {endpoint}") @app.get("/login/") diff --git a/tests/appsec/contrib_appsec/flask_app/app.py b/tests/appsec/contrib_appsec/flask_app/app.py index 479dad5fdd0..069efc5f17f 100644 --- a/tests/appsec/contrib_appsec/flask_app/app.py +++ b/tests/appsec/contrib_appsec/flask_app/app.py @@ -80,7 +80,7 @@ def rasp(endpoint: str): res.append(f"File: {f.read()}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) elif endpoint == "ssrf": res = ["ssrf endpoint"] @@ -108,7 +108,7 @@ def rasp(endpoint: str): res.append(f"Url: {r.text}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) elif endpoint == "sql_injection": res = ["sql_injection endpoint"] @@ -121,7 +121,7 @@ def rasp(endpoint: str): res.append(f"Url: {list(cursor)}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) elif endpoint == "shell_injection": res = ["shell_injection endpoint"] @@ -130,12 +130,12 @@ def rasp(endpoint: str): cmd = query_params[param] try: if param.startswith("cmdsys"): - res.append(f'cmd stdout: {os.system(f"ls {cmd}")}') + res.append(f"cmd stdout: {os.system(f'ls {cmd}')}") else: - res.append(f'cmd stdout: {subprocess.run(f"ls {cmd}", shell=True)}') + res.append(f"cmd stdout: {subprocess.run(f'ls {cmd}', shell=True)}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) elif endpoint == "command_injection": res = ["command_injection endpoint"] @@ -143,7 +143,7 @@ def rasp(endpoint: str): if param.startswith("cmda"): cmd = query_params[param] try: - res.append(f'cmd stdout: {subprocess.run([cmd, "-c", "3", "localhost"])}') + res.append(f"cmd stdout: {subprocess.run([cmd, '-c', '3', 'localhost'])}") except Exception as e: res.append(f"Error: {e}") elif param.startswith("cmds"): @@ -152,9 +152,9 @@ def rasp(endpoint: str): res.append(f"cmd stdout: {subprocess.run(cmd)}") except Exception as e: res.append(f"Error: {e}") - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return "<\\br>\n".join(res) - tracer.current_span()._local_root.set_tag("rasp.request.done", endpoint) + tracer.current_span()._service_entry_span.set_tag("rasp.request.done", endpoint) return f"Unknown endpoint: {endpoint}" @@ -250,3 +250,13 @@ def login(user_id: str, login: str) -> None: login(user_id, username) return "OK" return "login failure", 401 + + +@app.before_request +def service_renaming(): + if request.headers.get("x-rename-service", "false") == "true": + service_name = "sub-service" + root_span = tracer.current_root_span() + if root_span is not None: + root_span.service = service_name + root_span.set_tag("scope", service_name) diff --git a/tests/appsec/contrib_appsec/utils.py b/tests/appsec/contrib_appsec/utils.py index ff95d4fde73..073a8254d45 100644 --- a/tests/appsec/contrib_appsec/utils.py +++ b/tests/appsec/contrib_appsec/utils.py @@ -85,27 +85,27 @@ def location(self, response) -> str: def body(self, response) -> str: raise NotImplementedError - def get_stack_trace(self, root_span, namespace): - appsec_traces = root_span().get_struct_tag(asm_constants.STACK_TRACE.TAG) or {} + def get_stack_trace(self, entry_span, namespace): + appsec_traces = entry_span().get_struct_tag(asm_constants.STACK_TRACE.TAG) or {} stacks = appsec_traces.get(namespace, []) return stacks - def check_for_stack_trace(self, root_span): - exploit = self.get_stack_trace(root_span, "exploit") + def check_for_stack_trace(self, entry_span): + exploit = self.get_stack_trace(entry_span, "exploit") stack_ids = sorted(set(t["id"] for t in exploit)) - triggers = get_triggers(root_span()) + triggers = get_triggers(entry_span()) stack_id_in_triggers = sorted(set(t["stack_id"] for t in (triggers or []) if "stack_id" in t)) assert stack_ids == stack_id_in_triggers, f"stack_ids={stack_ids}, stack_id_in_triggers={stack_id_in_triggers}" return exploit - def check_single_rule_triggered(self, rule_id: str, root_span): - triggers = get_triggers(root_span()) + def check_single_rule_triggered(self, rule_id: str, entry_span): + triggers = get_triggers(entry_span()) assert triggers is not None, "no appsec struct in root span" result = [t["rule"]["id"] for t in triggers] assert result == [rule_id], f"result={result}, expected={[rule_id]}" - def check_rules_triggered(self, rule_id: List[str], root_span): - triggers = get_triggers(root_span()) + def check_rules_triggered(self, rule_id: List[str], entry_span): + triggers = get_triggers(entry_span()) assert triggers is not None, "no appsec struct in root span" result = sorted([t["rule"]["id"] for t in triggers]) assert result == rule_id, f"result={result}, expected={rule_id}" @@ -130,7 +130,7 @@ def setup_method(self, method): _addresses_store.clear() @pytest.mark.parametrize("asm_enabled", [True, False]) - def test_healthcheck(self, interface: Interface, get_tag, asm_enabled: bool): + def test_healthcheck(self, interface: Interface, get_entry_span_tag, asm_enabled: bool): # you can disable any test in a framework like that: # if interface.name == "fastapi": # raise pytest.skip("fastapi does not have a healthcheck endpoint") @@ -142,21 +142,19 @@ def test_healthcheck(self, interface: Interface, get_tag, asm_enabled: bool): from ddtrace.settings.asm import config as asm_config assert asm_config._asm_enabled is asm_enabled - assert get_tag("http.status_code") == "200" + assert get_entry_span_tag("http.status_code") == "200" assert self.headers(response)["content-type"] == "text/html; charset=utf-8" - def test_simple_attack(self, interface: Interface, root_span, get_tag): + def test_simple_attack(self, interface: Interface, entry_span, get_entry_span_tag): with override_global_config(dict(_asm_enabled=True)): self.update_tracer(interface) response = interface.client.get("/.git?q=1") assert response.status_code == 404 - triggers = get_triggers(root_span()) + triggers = get_triggers(entry_span()) assert triggers is not None, "no appsec struct in root span" - assert get_tag("http.response.headers.content-length") - # DEV: fastapi may send "requests" instead of "fastapi" - # assert get_tag("component") == interface.name + assert get_entry_span_tag("http.response.headers.content-length") - def test_simple_attack_timeout(self, interface: Interface, root_span, get_metric): + def test_simple_attack_timeout(self, interface: Interface, entry_span, get_entry_span_metric): from unittest.mock import MagicMock from unittest.mock import patch as mock_patch @@ -172,7 +170,7 @@ def test_simple_attack_timeout(self, interface: Interface, root_span, get_metric url = f"/?{query_params}" response = interface.client.get(url, headers={"User-Agent": "Arachni/v1.5.1"}) assert response.status_code == 200 - assert get_metric("_dd.appsec.waf.timeouts") > 0, (root_span()._meta, root_span()._metrics) + assert get_entry_span_metric("_dd.appsec.waf.timeouts") > 0, (entry_span()._meta, entry_span()._metrics) args_list = [ (args[0].value, args[1].value) + args[2:] for args, kwargs in mocked.add_metric.call_args_list @@ -231,16 +229,16 @@ def parse(path: str) -> str: ("user_agent", "priority"), [("Mozilla/5.0", False), ("Arachni/v1.5.1", True), ("dd-test-scanner-log-block", True)], ) - def test_priority(self, interface: Interface, root_span, get_tag, asm_enabled, user_agent, priority): + def test_priority(self, interface: Interface, entry_span, get_entry_span_tag, asm_enabled, user_agent, priority): """Check that we only set manual keep for traces with appsec events.""" with override_global_config(dict(_asm_enabled=asm_enabled)): self.update_tracer(interface) response = interface.client.get("/", headers={"User-Agent": user_agent}) assert response.status_code == (403 if user_agent == "dd-test-scanner-log-block" and asm_enabled else 200) - span_priority = root_span()._span.context.sampling_priority + span_priority = entry_span()._span.context.sampling_priority assert (span_priority == 2) if asm_enabled and priority else (span_priority < 2) - def test_querystrings(self, interface: Interface, root_span): + def test_querystrings(self, interface: Interface, entry_span): with override_global_config(dict(_asm_enabled=True)): self.update_tracer(interface) response = interface.client.get("/?a=1&b&c=d") @@ -252,14 +250,14 @@ def test_querystrings(self, interface: Interface, root_span): {"a": ["1"], "c": ["d"]}, ] - def test_no_querystrings(self, interface: Interface, root_span): + def test_no_querystrings(self, interface: Interface, entry_span): with override_global_config(dict(_asm_enabled=True)): self.update_tracer(interface) response = interface.client.get("/") assert self.status(response) == 200 assert not _addresses_store[0].get("http.request.query") - def test_truncation_tags(self, interface: Interface, get_metric): + def test_truncation_tags(self, interface: Interface, get_entry_span_metric): with override_global_config(dict(_asm_enabled=True)): self.update_tracer(interface) body: Dict[str, Any] = {"val": "x" * 5000} @@ -270,13 +268,13 @@ def test_truncation_tags(self, interface: Interface, get_metric): content_type="application/json", ) assert self.status(response) == 200 - assert get_metric(asm_constants.APPSEC.TRUNCATION_STRING_LENGTH) + assert get_entry_span_metric(asm_constants.APPSEC.TRUNCATION_STRING_LENGTH) # 12030 is due to response encoding - assert int(get_metric(asm_constants.APPSEC.TRUNCATION_STRING_LENGTH)) == 12029 - assert get_metric(asm_constants.APPSEC.TRUNCATION_CONTAINER_SIZE) - assert int(get_metric(asm_constants.APPSEC.TRUNCATION_CONTAINER_SIZE)) == 518 + assert int(get_entry_span_metric(asm_constants.APPSEC.TRUNCATION_STRING_LENGTH)) == 12029 + assert get_entry_span_metric(asm_constants.APPSEC.TRUNCATION_CONTAINER_SIZE) + assert int(get_entry_span_metric(asm_constants.APPSEC.TRUNCATION_CONTAINER_SIZE)) == 518 - def test_truncation_telemetry(self, interface: Interface, get_metric): + def test_truncation_telemetry(self, interface: Interface, get_entry_span_metric): from unittest.mock import ANY from unittest.mock import MagicMock from unittest.mock import patch as mock_patch @@ -331,7 +329,7 @@ def test_truncation_telemetry(self, interface: Interface, get_metric): ("cookies", "attack"), [({"mytestingcookie_key": "mytestingcookie_value"}, False), ({"attack": "1' or '1' = '1'"}, True)], ) - def test_request_cookies(self, interface: Interface, root_span, asm_enabled, cookies, attack): + def test_request_cookies(self, interface: Interface, entry_span, asm_enabled, cookies, attack): with override_global_config(dict(_asm_enabled=asm_enabled, _asm_static_rule_file=rules.RULES_GOOD_PATH)): self.update_tracer(interface) response = interface.client.get("/", cookies=cookies) @@ -344,7 +342,7 @@ def test_request_cookies(self, interface: Interface, root_span, asm_enabled, coo assert cookies_parsed == cookies, f"cookies={cookies_parsed}, expected={cookies}" else: assert not _addresses_store - triggers = get_triggers(root_span()) + triggers = get_triggers(entry_span()) if asm_enabled and attack: assert triggers is not None, "no appsec struct in root span" assert len(triggers) == 1 @@ -369,7 +367,7 @@ def test_request_cookies(self, interface: Interface, root_span, asm_enabled, coo def test_request_body( self, interface: Interface, - root_span, + entry_span, asm_enabled, encode_payload, content_type, @@ -391,7 +389,7 @@ def test_request_body( else: assert not body # DEV: Flask send {} for text/plain with asm - triggers = get_triggers(root_span()) + triggers = get_triggers(entry_span()) if asm_enabled and attack and content_type != "text/plain": assert triggers is not None, "no appsec struct in root span" @@ -408,7 +406,7 @@ def test_request_body( ("text/xml"), ], ) - def test_request_body_bad(self, caplog, interface: Interface, root_span, get_tag, content_type): + def test_request_body_bad(self, caplog, interface: Interface, entry_span, get_entry_span_tag, content_type): # Ensure no crash when body is not parsable import logging @@ -421,7 +419,7 @@ def test_request_body_bad(self, caplog, interface: Interface, root_span, get_tag assert response.status_code == 200 @pytest.mark.parametrize("asm_enabled", [True, False]) - def test_request_path_params(self, interface: Interface, root_span, asm_enabled): + def test_request_path_params(self, interface: Interface, entry_span, asm_enabled): with override_global_config(dict(_asm_enabled=asm_enabled)): self.update_tracer(interface) response = interface.client.get("/asm/137/abc/") @@ -433,14 +431,14 @@ def test_request_path_params(self, interface: Interface, root_span, asm_enabled) else: assert path_params is None - def test_useragent(self, interface: Interface, root_span, get_tag): + def test_useragent(self, interface: Interface, entry_span, get_entry_span_tag): from ddtrace.ext import http with override_global_config(dict(_asm_enabled=True)): self.update_tracer(interface) response = interface.client.get("/", headers={"user-agent": "test/1.2.3"}) assert self.status(response) == 200 - assert get_tag(http.USER_AGENT) == "test/1.2.3" + assert get_entry_span_tag(http.USER_AGENT) == "test/1.2.3" @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize( @@ -453,16 +451,18 @@ def test_useragent(self, interface: Interface, root_span, get_tag): ({"x-client-ip": "192.168.1.10,192.168.1.20"}, "192.168.1.10"), ], ) - def test_client_ip_asm_enabled_reported(self, interface: Interface, get_tag, asm_enabled, headers, expected): + def test_client_ip_asm_enabled_reported( + self, interface: Interface, get_entry_span_tag, asm_enabled, headers, expected + ): from ddtrace.ext import http with override_global_config(dict(_asm_enabled=asm_enabled)): self.update_tracer(interface) interface.client.get("/", headers=headers) if asm_enabled: - assert get_tag(http.CLIENT_IP) == expected # only works on Django for now + assert get_entry_span_tag(http.CLIENT_IP) == expected # only works on Django for now else: - assert get_tag(http.CLIENT_IP) is None + assert get_entry_span_tag(http.CLIENT_IP) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize( @@ -475,7 +475,7 @@ def test_client_ip_asm_enabled_reported(self, interface: Interface, get_tag, asm ], ) def test_client_ip_header_set_by_env_var( - self, interface: Interface, get_tag, root_span, asm_enabled, env_var, headers, expected + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, env_var, headers, expected ): from ddtrace.ext import http @@ -484,11 +484,11 @@ def test_client_ip_header_set_by_env_var( response = interface.client.get("/", headers=headers) assert self.status(response) == 200 if asm_enabled: - assert get_tag(http.CLIENT_IP) == expected or ( - expected is None and get_tag(http.CLIENT_IP) == "127.0.0.1" + assert get_entry_span_tag(http.CLIENT_IP) == expected or ( + expected is None and get_entry_span_tag(http.CLIENT_IP) == "127.0.0.1" ) else: - assert get_tag(http.CLIENT_IP) is None + assert get_entry_span_tag(http.CLIENT_IP) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize( @@ -505,7 +505,7 @@ def test_client_ip_header_set_by_env_var( ], ) def test_request_ipblock( - self, interface: Interface, get_tag, root_span, asm_enabled, headers, blocked, body, content_type + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, headers, blocked, body, content_type ): from ddtrace.ext import http @@ -515,13 +515,14 @@ def test_request_ipblock( if blocked and asm_enabled: assert self.status(response) == 403 assert self.body(response) == getattr(constants, body, None) - assert get_tag("actor.ip") == rules._IP.BLOCKED - assert get_tag(http.STATUS_CODE) == "403" - assert get_tag(http.URL) == "http://localhost:8000/" - assert get_tag(http.METHOD) == "GET" - self.check_single_rule_triggered("blk-001-001", root_span) + assert get_entry_span_tag("actor.ip") == rules._IP.BLOCKED + assert get_entry_span_tag(http.STATUS_CODE) == "403" + assert get_entry_span_tag(http.URL) == "http://localhost:8000/" + assert get_entry_span_tag(http.METHOD) == "GET" + self.check_single_rule_triggered("blk-001-001", entry_span) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") == content_type + get_entry_span_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + == content_type ) assert self.headers(response)["content-type"] == content_type else: @@ -546,7 +547,16 @@ def test_request_ipblock( ], ) def test_request_ipmonitor( - self, interface: Interface, get_tag, root_span, asm_enabled, headers, monitored, bypassed, query, blocked + self, + interface: Interface, + get_entry_span_tag, + entry_span, + asm_enabled, + headers, + monitored, + bypassed, + query, + blocked, ): from ddtrace.ext import http @@ -556,19 +566,23 @@ def test_request_ipmonitor( code = 403 if not bypassed and not monitored and asm_enabled and blocked else 200 rule = "tst-421-001" if blocked else "tst-421-002" assert self.status(response) == code, f"status={self.status(response)}, expected={code}" - assert get_tag(http.STATUS_CODE) == str(code), f"status_code={get_tag(http.STATUS_CODE)}, expected={code}" + assert get_entry_span_tag(http.STATUS_CODE) == str( + code + ), f"status_code={get_entry_span_tag(http.STATUS_CODE)}, expected={code}" if asm_enabled and not bypassed: - assert get_tag(http.URL) == f"http://localhost:8000/{query}" - assert get_tag(http.METHOD) == "GET", f"method={get_tag(http.METHOD)}, expected=GET" + assert get_entry_span_tag(http.URL) == f"http://localhost:8000/{query}" assert ( - get_tag("actor.ip") == headers["X-Real-Ip"] - ), f"actor.ip={get_tag('actor.ip')}, expected={headers['X-Real-Ip']}" + get_entry_span_tag(http.METHOD) == "GET" + ), f"method={get_entry_span_tag(http.METHOD)}, expected=GET" + assert ( + get_entry_span_tag("actor.ip") == headers["X-Real-Ip"] + ), f"actor.ip={get_entry_span_tag('actor.ip')}, expected={headers['X-Real-Ip']}" if monitored: - self.check_rules_triggered(["blk-001-010", rule], root_span) + self.check_rules_triggered(["blk-001-010", rule], entry_span) else: - self.check_rules_triggered([rule], root_span) + self.check_rules_triggered([rule], entry_span) else: - assert get_triggers(root_span()) is None, f"asm struct in root span {get_triggers(root_span())}" + assert get_triggers(entry_span()) is None, f"asm struct in root span {get_triggers(entry_span())}" SUSPICIOUS_IP = "34.65.27.85" @@ -588,7 +602,7 @@ def test_request_ipmonitor( ], ) def test_request_suspicious_attacker_blocking( - self, interface: Interface, get_tag, root_span, asm_enabled, ip, agent, event, status + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, ip, agent, event, status ): from ddtrace.ext import http @@ -606,19 +620,21 @@ def test_request_suspicious_attacker_blocking( if event and ip == self.SUSPICIOUS_IP: status = 402 assert self.status(response) == status, f"status={self.status(response)}, expected={status}" - assert get_tag(http.STATUS_CODE) == str(status), f"status_code={self.status(response)}, expected={status}" + assert get_entry_span_tag(http.STATUS_CODE) == str( + status + ), f"status_code={self.status(response)}, expected={status}" if event: self.check_single_rule_triggered( - "ua0-600-56x" if agent == "dd-test-scanner-log-block" else "ua0-600-12x", root_span + "ua0-600-56x" if agent == "dd-test-scanner-log-block" else "ua0-600-12x", entry_span ) else: - assert get_triggers(root_span()) is None + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) @pytest.mark.parametrize(("method", "kwargs"), [("get", {}), ("post", {"data": {"key": "value"}}), ("options", {})]) def test_request_suspicious_request_block_match_method( - self, interface: Interface, get_tag, root_span, asm_enabled, metastruct, method, kwargs + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, metastruct, method, kwargs ): # GET must be blocked from ddtrace.ext import http @@ -632,28 +648,28 @@ def test_request_suspicious_request_block_match_method( ): self.update_tracer(interface) response = getattr(interface.client, method)("/", **kwargs) - assert get_tag(http.URL) == "http://localhost:8000/" - assert get_tag(http.METHOD) == method.upper() + assert get_entry_span_tag(http.URL) == "http://localhost:8000/" + assert get_entry_span_tag(http.METHOD) == method.upper() if asm_enabled and method == "get": assert self.status(response) == 403 - assert get_tag(http.STATUS_CODE) == "403" + assert get_entry_span_tag(http.STATUS_CODE) == "403" assert self.body(response) == constants.BLOCKED_RESPONSE_JSON - self.check_single_rule_triggered("tst-037-006", root_span) + self.check_single_rule_triggered("tst-037-006", entry_span) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + get_entry_span_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") == "application/json" ) assert self.headers(response)["content-type"] == "application/json" else: assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" - assert get_triggers(root_span()) is None + assert get_entry_span_tag(http.STATUS_CODE) == "200" + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) @pytest.mark.parametrize(("uri", "blocked"), [("/.git", True), ("/legit", False)]) def test_request_suspicious_request_block_match_uri( - self, interface: Interface, get_tag, root_span, asm_enabled, metastruct, uri, blocked + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, metastruct, uri, blocked ): # GET must be blocked from ddtrace.ext import http @@ -665,28 +681,28 @@ def test_request_suspicious_request_block_match_uri( ): self.update_tracer(interface) response = interface.client.get(uri) - assert get_tag(http.URL) == f"http://localhost:8000{uri}" - assert get_tag(http.METHOD) == "GET" + assert get_entry_span_tag(http.URL) == f"http://localhost:8000{uri}" + assert get_entry_span_tag(http.METHOD) == "GET" if asm_enabled and blocked: assert self.status(response) == 403 - assert get_tag(http.STATUS_CODE) == "403" + assert get_entry_span_tag(http.STATUS_CODE) == "403" assert self.body(response) == constants.BLOCKED_RESPONSE_JSON - self.check_single_rule_triggered("tst-037-002", root_span) + self.check_single_rule_triggered("tst-037-002", entry_span) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + get_entry_span_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") == "application/json" ) assert self.headers(response)["content-type"] == "application/json" else: assert self.status(response) == 404 - assert get_tag(http.STATUS_CODE) == "404" - assert get_triggers(root_span()) is None + assert get_entry_span_tag(http.STATUS_CODE) == "404" + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) @pytest.mark.parametrize("uri", ["/waf/../"]) def test_request_suspicious_request_block_match_uri_lfi( - self, interface: Interface, get_tag, root_span, asm_enabled, metastruct, uri + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, metastruct, uri ): if interface.name in ("fastapi",): raise pytest.skip(f"TODO: fix {interface.name}") @@ -696,12 +712,12 @@ def test_request_suspicious_request_block_match_uri_lfi( with override_global_config(dict(_asm_enabled=asm_enabled, _use_metastruct_for_triggers=metastruct)): self.update_tracer(interface) interface.client.get(uri) - # assert get_tag(http.URL) == f"http://localhost:8000{uri}" - assert get_tag(http.METHOD) == "GET" + # assert get_entry_span_tag(http.URL) == f"http://localhost:8000{uri}" + assert get_entry_span_tag(http.METHOD) == "GET" if asm_enabled: - self.check_single_rule_triggered("crs-930-110", root_span) + self.check_single_rule_triggered("crs-930-110", entry_span) else: - assert get_triggers(root_span()) is None + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) @@ -715,7 +731,7 @@ def test_request_suspicious_request_block_match_uri_lfi( ], ) def test_request_suspicious_request_block_match_path_params( - self, interface: Interface, get_tag, root_span, asm_enabled, metastruct, path, blocked + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, metastruct, path, blocked ): from ddtrace.ext import http @@ -728,22 +744,22 @@ def test_request_suspicious_request_block_match_path_params( self.update_tracer(interface) response = interface.client.get(uri) # DEV Warning: encoded URL will behave differently - assert get_tag(http.URL) == "http://localhost:8000" + uri - assert get_tag(http.METHOD) == "GET" + assert get_entry_span_tag(http.URL) == "http://localhost:8000" + uri + assert get_entry_span_tag(http.METHOD) == "GET" if asm_enabled and blocked: assert self.status(response) == 403 - assert get_tag(http.STATUS_CODE) == "403" + assert get_entry_span_tag(http.STATUS_CODE) == "403" assert self.body(response) == constants.BLOCKED_RESPONSE_JSON - self.check_single_rule_triggered("tst-037-007", root_span) + self.check_single_rule_triggered("tst-037-007", entry_span) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + get_entry_span_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") == "application/json" ) assert self.headers(response)["content-type"] == "application/json" else: assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" - assert get_triggers(root_span()) is None + assert get_entry_span_tag(http.STATUS_CODE) == "200" + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) @@ -756,7 +772,7 @@ def test_request_suspicious_request_block_match_path_params( ], ) def test_request_suspicious_request_block_match_query_params( - self, interface: Interface, get_tag, root_span, asm_enabled, metastruct, query, blocked + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, metastruct, query, blocked ): if interface.name in ("django",) and query == "?toto=xtrace&toto=ytrace": raise pytest.skip(f"{interface.name} does not support multiple query params with same name") @@ -772,22 +788,22 @@ def test_request_suspicious_request_block_match_query_params( self.update_tracer(interface) response = interface.client.get(uri) # DEV Warning: encoded URL will behave differently - assert get_tag(http.URL) == "http://localhost:8000" + uri - assert get_tag(http.METHOD) == "GET" + assert get_entry_span_tag(http.URL) == "http://localhost:8000" + uri + assert get_entry_span_tag(http.METHOD) == "GET" if asm_enabled and blocked: assert self.status(response) == 403 - assert get_tag(http.STATUS_CODE) == "403" + assert get_entry_span_tag(http.STATUS_CODE) == "403" assert self.body(response) == constants.BLOCKED_RESPONSE_JSON - self.check_single_rule_triggered("tst-037-001", root_span) + self.check_single_rule_triggered("tst-037-001", entry_span) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + get_entry_span_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") == "application/json" ) assert self.headers(response)["content-type"] == "application/json" else: assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" - assert get_triggers(root_span()) is None + assert get_entry_span_tag(http.STATUS_CODE) == "200" + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) @@ -799,7 +815,7 @@ def test_request_suspicious_request_block_match_query_params( ], ) def test_request_suspicious_request_block_match_request_headers( - self, interface: Interface, get_tag, root_span, asm_enabled, metastruct, headers, blocked + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, metastruct, headers, blocked ): from ddtrace.ext import http @@ -811,22 +827,22 @@ def test_request_suspicious_request_block_match_request_headers( self.update_tracer(interface) response = interface.client.get("/", headers=headers) # DEV Warning: encoded URL will behave differently - assert get_tag(http.URL) == "http://localhost:8000/" - assert get_tag(http.METHOD) == "GET" + assert get_entry_span_tag(http.URL) == "http://localhost:8000/" + assert get_entry_span_tag(http.METHOD) == "GET" if asm_enabled and blocked: assert self.status(response) == 403 - assert get_tag(http.STATUS_CODE) == "403" + assert get_entry_span_tag(http.STATUS_CODE) == "403" assert self.body(response) == constants.BLOCKED_RESPONSE_JSON - self.check_single_rule_triggered("tst-037-004", root_span) + self.check_single_rule_triggered("tst-037-004", entry_span) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + get_entry_span_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") == "application/json" ) assert self.headers(response)["content-type"] == "application/json" else: assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" - assert get_triggers(root_span()) is None + assert get_entry_span_tag(http.STATUS_CODE) == "200" + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) @@ -838,7 +854,7 @@ def test_request_suspicious_request_block_match_request_headers( ], ) def test_request_suspicious_request_block_match_request_cookies( - self, interface: Interface, get_tag, root_span, asm_enabled, metastruct, cookies, blocked + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, metastruct, cookies, blocked ): from ddtrace.ext import http @@ -850,22 +866,22 @@ def test_request_suspicious_request_block_match_request_cookies( self.update_tracer(interface) response = interface.client.get("/", cookies=cookies) # DEV Warning: encoded URL will behave differently - assert get_tag(http.URL) == "http://localhost:8000/" - assert get_tag(http.METHOD) == "GET" + assert get_entry_span_tag(http.URL) == "http://localhost:8000/" + assert get_entry_span_tag(http.METHOD) == "GET" if asm_enabled and blocked: assert self.status(response) == 403 - assert get_tag(http.STATUS_CODE) == "403" + assert get_entry_span_tag(http.STATUS_CODE) == "403" assert self.body(response) == constants.BLOCKED_RESPONSE_JSON - self.check_single_rule_triggered("tst-037-008", root_span) + self.check_single_rule_triggered("tst-037-008", entry_span) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + get_entry_span_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") == "application/json" ) assert self.headers(response)["content-type"] == "application/json" else: assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" - assert get_triggers(root_span()) is None + assert get_entry_span_tag(http.STATUS_CODE) == "200" + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) @@ -879,7 +895,7 @@ def test_request_suspicious_request_block_match_request_cookies( ], ) def test_request_suspicious_request_block_match_response_status( - self, interface: Interface, get_tag, root_span, asm_enabled, metastruct, uri, status, blocked + self, interface: Interface, get_entry_span_tag, entry_span, asm_enabled, metastruct, uri, status, blocked ): from ddtrace.ext import http @@ -893,22 +909,22 @@ def test_request_suspicious_request_block_match_response_status( self.update_tracer(interface) response = interface.client.get(uri) # DEV Warning: encoded URL will behave differently - assert get_tag(http.URL) == "http://localhost:8000" + uri - assert get_tag(http.METHOD) == "GET" + assert get_entry_span_tag(http.URL) == "http://localhost:8000" + uri + assert get_entry_span_tag(http.METHOD) == "GET" if asm_enabled and blocked: assert self.status(response) == 403 - assert get_tag(http.STATUS_CODE) == "403" + assert get_entry_span_tag(http.STATUS_CODE) == "403" assert self.body(response) == constants.BLOCKED_RESPONSE_JSON - self.check_single_rule_triggered(blocked, root_span) + self.check_single_rule_triggered(blocked, entry_span) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + get_entry_span_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") == "application/json" ) assert self.headers(response)["content-type"] == "application/json" else: assert self.status(response) == status - assert get_tag(http.STATUS_CODE) == str(status) - assert get_triggers(root_span()) is None + assert get_entry_span_tag(http.STATUS_CODE) == str(status) + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) @@ -921,8 +937,18 @@ def test_request_suspicious_request_block_match_response_status( ("/asm/1/a", {"header_name": "NoWorryBeHappy"}, None), ], ) + @pytest.mark.parametrize("rename_service", [True, False]) def test_request_suspicious_request_block_match_response_headers( - self, interface: Interface, get_tag, asm_enabled, metastruct, root_span, uri, headers, blocked + self, + interface: Interface, + get_entry_span_tag, + asm_enabled, + metastruct, + entry_span, + uri, + headers, + blocked, + rename_service, ): from ddtrace.ext import http @@ -934,17 +960,17 @@ def test_request_suspicious_request_block_match_response_headers( self.update_tracer(interface) if headers: uri += "?headers=" + quote(",".join(f"{k}={v}" for k, v in headers.items())) - response = interface.client.get(uri) + response = interface.client.get(uri, headers={"x-rename-service": "true" if rename_service else "false"}) # DEV Warning: encoded URL will behave differently - assert get_tag(http.URL) == "http://localhost:8000" + uri - assert get_tag(http.METHOD) == "GET" + assert get_entry_span_tag(http.URL) == "http://localhost:8000" + uri + assert get_entry_span_tag(http.METHOD) == "GET" if asm_enabled and blocked: assert self.status(response) == 403 - assert get_tag(http.STATUS_CODE) == "403" + assert get_entry_span_tag(http.STATUS_CODE) == "403" assert self.body(response) == constants.BLOCKED_RESPONSE_JSON - self.check_single_rule_triggered(blocked, root_span) + self.check_single_rule_triggered(blocked, entry_span) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + get_entry_span_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") == "application/json" ) assert self.headers(response)["content-type"] == "application/json" @@ -952,8 +978,8 @@ def test_request_suspicious_request_block_match_response_headers( assert k not in self.headers(response) else: assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" - assert get_triggers(root_span()) is None + assert get_entry_span_tag(http.STATUS_CODE) == "200" + assert get_triggers(entry_span()) is None assert "content-length" in self.headers(response) assert int(self.headers(response)["content-length"]) == len(self.body(response).encode()) @@ -995,7 +1021,7 @@ def test_request_suspicious_request_block_match_response_headers( ids=["json", "text_json", "json_large", "xml", "form", "form_multipart", "text", "no_attack"], ) def test_request_suspicious_request_block_match_request_body( - self, interface: Interface, get_tag, asm_enabled, metastruct, root_span, body, content_type, blocked + self, interface: Interface, get_entry_span_tag, asm_enabled, metastruct, entry_span, body, content_type, blocked ): from ddtrace.ext import http @@ -1007,22 +1033,22 @@ def test_request_suspicious_request_block_match_request_body( self.update_tracer(interface) response = interface.client.post("/asm/", data=body, content_type=content_type) # DEV Warning: encoded URL will behave differently - assert get_tag(http.URL) == "http://localhost:8000/asm/" - assert get_tag(http.METHOD) == "POST" + assert get_entry_span_tag(http.URL) == "http://localhost:8000/asm/" + assert get_entry_span_tag(http.METHOD) == "POST" if asm_enabled and blocked: assert self.status(response) == 403 - assert get_tag(http.STATUS_CODE) == "403" + assert get_entry_span_tag(http.STATUS_CODE) == "403" assert self.body(response) == constants.BLOCKED_RESPONSE_JSON - self.check_single_rule_triggered(blocked, root_span) + self.check_single_rule_triggered(blocked, entry_span) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + get_entry_span_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") == "application/json" ) assert self.headers(response)["content-type"] == "application/json" else: assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" - assert get_triggers(root_span()) is None + assert get_entry_span_tag(http.STATUS_CODE) == "200" + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) @@ -1052,10 +1078,10 @@ def test_request_suspicious_request_block_match_request_body( def test_request_suspicious_request_block_custom_actions( self, interface: Interface, - get_tag, + get_entry_span_tag, asm_enabled, metastruct, - root_span, + entry_span, query, status, rule_id, @@ -1086,19 +1112,21 @@ def test_request_suspicious_request_block_custom_actions( self.update_tracer(interface) response = interface.client.get(uri, headers=headers) # DEV Warning: encoded URL will behave differently - assert get_tag(http.URL) == "http://localhost:8000" + uri - assert get_tag(http.METHOD) == "GET" + assert get_entry_span_tag(http.URL) == "http://localhost:8000" + uri + assert get_entry_span_tag(http.METHOD) == "GET" if asm_enabled and action: assert self.status(response) == status - assert get_tag(http.STATUS_CODE) == str(status) - self.check_single_rule_triggered(rule_id, root_span) + assert get_entry_span_tag(http.STATUS_CODE) == str(status) + self.check_single_rule_triggered(rule_id, entry_span) if action == "blocked": content_type = ( "text/html" if "html" in query or (("auto" in query) and use_html) else "application/json" ) assert ( - get_tag(asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type") + get_entry_span_tag( + asm_constants.SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES + ".content-type" + ) == content_type ) assert self.headers(response)["content-type"] == content_type @@ -1111,8 +1139,8 @@ def test_request_suspicious_request_block_custom_actions( assert self.location(response) == "https://www.datadoghq.com" else: assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" - assert get_triggers(root_span()) is None + assert get_entry_span_tag(http.STATUS_CODE) == "200" + assert get_triggers(entry_span()) is None finally: # remove cache to avoid using custom templates in other tests http_cache._HTML_BLOCKED_TEMPLATE_CACHE = None @@ -1122,8 +1150,8 @@ def test_request_suspicious_request_block_custom_actions( def test_nested_appsec_events( self, interface: Interface, - get_tag, - root_span, + get_entry_span_tag, + entry_span, asm_enabled, ): from ddtrace.ext import http @@ -1132,14 +1160,14 @@ def test_nested_appsec_events( self.update_tracer(interface) response = interface.client.get("/config.php", headers={"user-agent": "Arachni/v1.5.1"}) # DEV Warning: encoded URL will behave differently - assert get_tag(http.URL) == "http://localhost:8000/config.php" - assert get_tag(http.METHOD) == "GET" + assert get_entry_span_tag(http.URL) == "http://localhost:8000/config.php" + assert get_entry_span_tag(http.METHOD) == "GET" assert self.status(response) == 404 - assert get_tag(http.STATUS_CODE) == "404" + assert get_entry_span_tag(http.STATUS_CODE) == "404" if asm_enabled: - self.check_rules_triggered(["nfd-000-001", "ua0-600-12x"], root_span) + self.check_rules_triggered(["nfd-000-001", "ua0-600-12x"], entry_span) else: - assert get_triggers(root_span()) is None + assert get_triggers(entry_span()) is None @pytest.mark.parametrize("apisec_enabled", [True, False]) @pytest.mark.parametrize("apm_tracing_enabled", [True, False]) @@ -1211,8 +1239,8 @@ def test_nested_appsec_events( def test_api_security_schemas( self, interface: Interface, - get_tag, - root_span, + get_entry_span_tag, + entry_span, apisec_enabled, apm_tracing_enabled, name, @@ -1247,12 +1275,12 @@ def test_api_security_schemas( assert asm_config._api_security_enabled == apisec_enabled assert self.status(response) == (403 if blocked else 200) - assert get_tag(http.STATUS_CODE) == ("403" if blocked else "200") + assert get_entry_span_tag(http.STATUS_CODE) == ("403" if blocked else "200") if event: - assert get_triggers(root_span()) is not None + assert get_triggers(entry_span()) is not None else: - assert get_triggers(root_span()) is None - value = get_tag(name) + assert get_triggers(entry_span()) is None + value = get_entry_span_tag(name) if apisec_enabled and not (name.startswith("_dd.appsec.s.res") and blocked): assert value, name api = json.loads(gzip.decompress(base64.b64decode(value)).decode()) @@ -1277,8 +1305,8 @@ def test_api_security_schemas( ) in telemetry_calls if not apm_tracing_enabled: - span_sampling_priority = root_span()._span.context.sampling_priority - sampling_decision = root_span().get_tag(constants.SAMPLING_DECISION_TRACE_TAG_KEY) + span_sampling_priority = entry_span()._span.context.sampling_priority + sampling_decision = get_entry_span_tag(constants.SAMPLING_DECISION_TRACE_TAG_KEY) assert ( span_sampling_priority == constants.USER_KEEP ), f"Expected 2 (USER_KEEP), got {span_sampling_priority}" @@ -1299,7 +1327,9 @@ def test_api_security_schemas( ({"SSN": "123-45-6789"}, [{"SSN": [8, {"category": "pii", "type": "us_ssn"}]}]), ], ) - def test_api_security_scanners(self, interface: Interface, get_tag, apisec_enabled, payload, expected_value): + def test_api_security_scanners( + self, interface: Interface, get_entry_span_tag, apisec_enabled, payload, expected_value + ): import base64 import gzip @@ -1313,10 +1343,10 @@ def test_api_security_scanners(self, interface: Interface, get_tag, apisec_enabl content_type="application/json", ) assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" + assert get_entry_span_tag(http.STATUS_CODE) == "200" assert asm_config._api_security_enabled == apisec_enabled - value = get_tag("_dd.appsec.s.req.body") + value = get_entry_span_tag("_dd.appsec.s.req.body") if apisec_enabled: assert value api = json.loads(gzip.decompress(base64.b64decode(value)).decode()) @@ -1327,7 +1357,7 @@ def test_api_security_scanners(self, interface: Interface, get_tag, apisec_enabl @pytest.mark.parametrize("apisec_enabled", [True, False]) @pytest.mark.parametrize("priority", ["keep", "drop"]) @pytest.mark.parametrize("delay", [0.0, 120.0]) - def test_api_security_sampling(self, interface: Interface, get_tag, apisec_enabled, priority, delay): + def test_api_security_sampling(self, interface: Interface, get_entry_span_tag, apisec_enabled, priority, delay): from ddtrace.ext import http payload = {"mastercard": "5123456789123456"} @@ -1341,10 +1371,10 @@ def test_api_security_sampling(self, interface: Interface, get_tag, apisec_enabl content_type="application/json", ) assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" + assert get_entry_span_tag(http.STATUS_CODE) == "200" assert asm_config._api_security_enabled == apisec_enabled - value = get_tag("_dd.appsec.s.req.body") + value = get_entry_span_tag("_dd.appsec.s.req.body") if apisec_enabled and priority == "keep": assert value else: @@ -1357,10 +1387,10 @@ def test_api_security_sampling(self, interface: Interface, get_tag, apisec_enabl content_type="application/json", ) assert self.status(response) == 200 - assert get_tag(http.STATUS_CODE) == "200" + assert get_entry_span_tag(http.STATUS_CODE) == "200" assert asm_config._api_security_enabled == apisec_enabled - value = get_tag("_dd.appsec.s.req.body") + value = get_entry_span_tag("_dd.appsec.s.req.body") if apisec_enabled and priority == "keep" and delay == 0.0: assert value else: @@ -1392,7 +1422,7 @@ def test_multiple_service_name(self, interface): raise AssertionError("extra service not found") @pytest.mark.parametrize("asm_enabled", [True, False]) - def test_asm_enabled_headers(self, asm_enabled, interface, get_tag, root_span): + def test_asm_enabled_headers(self, asm_enabled, interface, get_entry_span_tag, entry_span): with override_global_config(dict(_asm_enabled=asm_enabled)): self.update_tracer(interface) response = interface.client.get( @@ -1402,13 +1432,13 @@ def test_asm_enabled_headers(self, asm_enabled, interface, get_tag, root_span): assert response.status_code == 200 assert self.status(response) == 200 if asm_enabled: - assert get_tag("http.request.headers.accept") == "testheaders/a1b2c3" - assert get_tag("http.request.headers.user-agent") == "UnitTestAgent" - assert get_tag("http.request.headers.content-type") == "test/x0y9z8" + assert get_entry_span_tag("http.request.headers.accept") == "testheaders/a1b2c3" + assert get_entry_span_tag("http.request.headers.user-agent") == "UnitTestAgent" + assert get_entry_span_tag("http.request.headers.content-type") == "test/x0y9z8" else: - assert get_tag("http.request.headers.accept") is None - assert get_tag("http.request.headers.user-agent") is None - assert get_tag("http.request.headers.content-type") is None + assert get_entry_span_tag("http.request.headers.accept") is None + assert get_entry_span_tag("http.request.headers.user-agent") is None + assert get_entry_span_tag("http.request.headers.content-type") is None @pytest.mark.parametrize( "header", @@ -1425,7 +1455,9 @@ def test_asm_enabled_headers(self, asm_enabled, interface, get_tag, root_span): ) @pytest.mark.parametrize("asm_enabled", [True, False]) # RFC: https://docs.google.com/document/d/1xf-s6PtSr6heZxmO_QLUtcFzY_X_rT94lRXNq6-Ghws/edit - def test_asm_waf_integration_identify_requests(self, asm_enabled, header, interface, get_tag, root_span): + def test_asm_waf_integration_identify_requests( + self, asm_enabled, header, interface, get_entry_span_tag, entry_span + ): import random import string @@ -1440,9 +1472,9 @@ def test_asm_waf_integration_identify_requests(self, asm_enabled, header, interf assert self.status(response) == 200 meta_tagname = "http.request.headers." + header.lower() if asm_enabled: - assert get_tag(meta_tagname) == random_value + assert get_entry_span_tag(meta_tagname) == random_value else: - assert get_tag(meta_tagname) is None + assert get_entry_span_tag(meta_tagname) is None def test_global_callback_list_length(self, interface): from ddtrace.appsec import _asm_request_context @@ -1468,10 +1500,10 @@ def test_global_callback_list_length(self, interface): def test_stream_response( self, interface: Interface, - get_tag, + get_entry_span_tag, asm_enabled, metastruct, - root_span, + entry_span, ): if interface.name != "fastapi": raise pytest.skip("only fastapi tests have support for stream response") @@ -1531,9 +1563,9 @@ def test_stream_response( def test_exploit_prevention( self, interface, - root_span, - get_tag, - get_metric, + entry_span, + get_entry_span_tag, + get_entry_span_metric, asm_enabled, ep_enabled, endpoint, @@ -1572,16 +1604,16 @@ def validate_top_function(trace): response = interface.client.get(f"/rasp/{endpoint}/?{parameters}") code = status_expected if asm_enabled and ep_enabled else 200 assert self.status(response) == code, (self.status(response), code) - assert get_tag(http.STATUS_CODE) == str(code), (get_tag(http.STATUS_CODE), code) + assert get_entry_span_tag(http.STATUS_CODE) == str(code), (get_entry_span_tag(http.STATUS_CODE), code) if code == 200: assert self.body(response).startswith(f"{endpoint} endpoint") telemetry_calls = { (c.value, f"{ns.value}.{nm}", t): v for (c, ns, nm, v, t), _ in mocked.add_metric.call_args_list } if asm_enabled and ep_enabled and action_level > 0: - self.check_rules_triggered([rule] * (1 if action_level == 2 else 2), root_span) - assert self.check_for_stack_trace(root_span) - for trace in self.check_for_stack_trace(root_span): + self.check_rules_triggered([rule] * (1 if action_level == 2 else 2), entry_span) + assert self.check_for_stack_trace(entry_span) + for trace in self.check_for_stack_trace(entry_span): assert "frames" in trace assert validate_top_function( trace @@ -1614,20 +1646,22 @@ def validate_top_function(trace): # there may have been multiple evaluations of other rules too assert expected_tags in evals, (expected_tags, evals) if action_level == 2: - assert get_tag("rasp.request.done") is None, get_tag("rasp.request.done") + assert get_entry_span_tag("rasp.request.done") is None, get_entry_span_tag("rasp.request.done") else: - assert get_tag("rasp.request.done") == endpoint, get_tag("rasp.request.done") - assert get_metric(APPSEC.RASP_DURATION) is not None - assert get_metric(APPSEC.RASP_DURATION_EXT) is not None - assert get_metric(APPSEC.RASP_RULE_EVAL) is not None - assert float(get_metric(APPSEC.RASP_DURATION_EXT)) >= float(get_metric(APPSEC.RASP_DURATION)) - assert int(get_metric(APPSEC.RASP_RULE_EVAL)) > 0 + assert get_entry_span_tag("rasp.request.done") == endpoint, get_entry_span_tag("rasp.request.done") + assert get_entry_span_metric(APPSEC.RASP_DURATION) is not None + assert get_entry_span_metric(APPSEC.RASP_DURATION_EXT) is not None + assert get_entry_span_metric(APPSEC.RASP_RULE_EVAL) is not None + assert float(get_entry_span_metric(APPSEC.RASP_DURATION_EXT)) >= float( + get_entry_span_metric(APPSEC.RASP_DURATION) + ) + assert int(get_entry_span_metric(APPSEC.RASP_RULE_EVAL)) > 0 else: for _, n, _ in telemetry_calls: assert "rasp" not in n - assert get_triggers(root_span()) is None - assert self.check_for_stack_trace(root_span) == [] - assert get_tag("rasp.request.done") == endpoint, get_tag("rasp.request.done") + assert get_triggers(entry_span()) is None + assert self.check_for_stack_trace(entry_span) == [] + assert get_entry_span_tag("rasp.request.done") == endpoint, get_entry_span_tag("rasp.request.done") @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("auto_events_enabled", [True, False]) @@ -1644,8 +1678,8 @@ def validate_top_function(trace): def test_auto_user_events( self, interface, - root_span, - get_tag, + entry_span, + get_entry_span_tag, asm_enabled, auto_events_enabled, local_mode, @@ -1669,53 +1703,56 @@ def test_auto_user_events( self.update_tracer(interface) response = interface.client.get(f"/login/?username={user}&password={password}") assert self.status(response) == status_code - assert get_tag("http.status_code") == str(status_code) + assert get_entry_span_tag("http.status_code") == str(status_code) username = user if mode == "identification" else _hash_user_id(user) user_id_hash = user_id if mode == "identification" else _hash_user_id(user_id) if asm_enabled and auto_events_enabled and mode != "disabled": if status_code == 401: - assert get_tag("appsec.events.users.login.failure.track") == "true" - assert get_tag("_dd.appsec.events.users.login.failure.auto.mode") == mode - assert get_tag("appsec.events.users.login.failure.usr.id") == ( + assert get_entry_span_tag("appsec.events.users.login.failure.track") == "true" + assert get_entry_span_tag("_dd.appsec.events.users.login.failure.auto.mode") == mode + assert get_entry_span_tag("appsec.events.users.login.failure.usr.id") == ( user_id_hash if user_id else username ) - assert get_tag("appsec.events.users.login.failure.usr.exists") == str(user == "testuuid").lower() + assert ( + get_entry_span_tag("appsec.events.users.login.failure.usr.exists") + == str(user == "testuuid").lower() + ) # check for manual instrumentation tag in manual instrumented frameworks if interface.name in ["flask", "fastapi"]: - assert get_tag("_dd.appsec.events.users.login.failure.sdk") == "true" + assert get_entry_span_tag("_dd.appsec.events.users.login.failure.sdk") == "true" else: - assert get_tag("_dd.appsec.events.users.login.success.sdk") is None + assert get_entry_span_tag("_dd.appsec.events.users.login.success.sdk") is None if mode == "identification": - assert get_tag("_dd.appsec.usr.login") == user + assert get_entry_span_tag("_dd.appsec.usr.login") == user elif mode == "anonymization": - assert get_tag("_dd.appsec.usr.login") == _hash_user_id(user) + assert get_entry_span_tag("_dd.appsec.usr.login") == _hash_user_id(user) else: - assert get_tag("appsec.events.users.login.success.track") == "true" - assert get_tag("usr.id") == user_id_hash - assert get_tag("_dd.appsec.usr.id") == user_id_hash + assert get_entry_span_tag("appsec.events.users.login.success.track") == "true" + assert get_entry_span_tag("usr.id") == user_id_hash + assert get_entry_span_tag("_dd.appsec.usr.id") == user_id_hash if mode == "identification": - assert get_tag("_dd.appsec.usr.login") == user + assert get_entry_span_tag("_dd.appsec.usr.login") == user # check for manual instrumentation tag in manual instrumented frameworks if interface.name in ["flask", "fastapi"]: - assert get_tag("_dd.appsec.events.users.login.success.sdk") == "true" + assert get_entry_span_tag("_dd.appsec.events.users.login.success.sdk") == "true" else: - assert get_tag("_dd.appsec.events.users.login.success.sdk") is None + assert get_entry_span_tag("_dd.appsec.events.users.login.success.sdk") is None else: - assert get_tag("usr.id") is None - assert not any(tag.startswith("appsec.events.users.login") for tag in root_span()._meta) - assert not any(tag.startswith("_dd_appsec.events.users.login") for tag in root_span()._meta) + assert get_entry_span_tag("usr.id") is None + assert not any(tag.startswith("appsec.events.users.login") for tag in entry_span()._meta) + assert not any(tag.startswith("_dd_appsec.events.users.login") for tag in entry_span()._meta) # check for fingerprints when user events if asm_enabled: - assert get_tag(asm_constants.FINGERPRINTING.HEADER) - assert get_tag(asm_constants.FINGERPRINTING.NETWORK) - assert get_tag(asm_constants.FINGERPRINTING.ENDPOINT) - assert get_tag(asm_constants.FINGERPRINTING.SESSION) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.HEADER) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.NETWORK) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.ENDPOINT) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.SESSION) else: - # assert get_tag(asm_constants.FINGERPRINTING.HEADER) is None - assert get_tag(asm_constants.FINGERPRINTING.NETWORK) is None - assert get_tag(asm_constants.FINGERPRINTING.ENDPOINT) is None - assert get_tag(asm_constants.FINGERPRINTING.SESSION) is None + # assert get_entry_span_tag(asm_constants.FINGERPRINTING.HEADER) is None + assert get_entry_span_tag(asm_constants.FINGERPRINTING.NETWORK) is None + assert get_entry_span_tag(asm_constants.FINGERPRINTING.ENDPOINT) is None + assert get_entry_span_tag(asm_constants.FINGERPRINTING.SESSION) is None @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("auto_events_enabled", [True, False]) @@ -1732,8 +1769,8 @@ def test_auto_user_events( def test_auto_user_events_sdk_v2( self, interface, - root_span, - get_tag, + entry_span, + get_entry_span_tag, asm_enabled, auto_events_enabled, local_mode, @@ -1784,61 +1821,64 @@ def test_auto_user_events_sdk_v2( ) response = interface.client.get(f"/login_sdk/?username={username}&password={password}&metadata={metadata}") assert self.status(response) == status_code - assert get_tag("http.status_code") == str(status_code) + assert get_entry_span_tag("http.status_code") == str(status_code) telemetry_calls = { (c.value, f"{ns.value}.{nm}", t): v for (c, ns, nm, v, t), _ in telemetry_mock.add_metric.call_args_list } if status_code == 401: - assert get_tag("appsec.events.users.login.failure.track") == "true" + assert get_entry_span_tag("appsec.events.users.login.failure.track") == "true" if user_id: - assert get_tag("appsec.events.users.login.failure.usr.id") == user_id - assert get_tag("appsec.events.users.login.failure.usr.exists") == str(username == "testuuid").lower() - assert get_tag("_dd.appsec.events.users.login.failure.sdk") == "true" + assert get_entry_span_tag("appsec.events.users.login.failure.usr.id") == user_id + assert ( + get_entry_span_tag("appsec.events.users.login.failure.usr.exists") + == str(username == "testuuid").lower() + ) + assert get_entry_span_tag("_dd.appsec.events.users.login.failure.sdk") == "true" assert any( t[:2] == ("count", "appsec.sdk.event") and ("event_type", "login_failure") == t[2][0] for t in telemetry_calls ), telemetry_calls else: - assert get_tag("appsec.events.users.login.success.track") == "true" - assert get_tag("usr.id") == user_id - assert get_tag("usr.id") == user_id, (user_id, get_tag("usr.id")) - assert any(tag.startswith("appsec.events.users.login") for tag in root_span()._meta) - assert get_tag("_dd.appsec.events.users.login.success.sdk") == "true" + assert get_entry_span_tag("appsec.events.users.login.success.track") == "true" + assert get_entry_span_tag("usr.id") == user_id + assert get_entry_span_tag("usr.id") == user_id, (user_id, get_entry_span_tag("usr.id")) + assert any(tag.startswith("appsec.events.users.login") for tag in entry_span()._meta) + assert get_entry_span_tag("_dd.appsec.events.users.login.success.sdk") == "true" assert any( t[:2] == ("count", "appsec.sdk.event") and ("event_type", "login_success") == t[2][0] for t in telemetry_calls ), telemetry_calls # no auto instrumentation - assert not any(tag.startswith("_dd_appsec.events.users.login") for tag in root_span()._meta) + assert not any(tag.startswith("_dd_appsec.events.users.login") for tag in entry_span()._meta) # check for fingerprints when user events if asm_enabled: - assert get_tag(asm_constants.FINGERPRINTING.HEADER) - assert get_tag(asm_constants.FINGERPRINTING.NETWORK) - assert get_tag(asm_constants.FINGERPRINTING.ENDPOINT) - assert get_tag(asm_constants.FINGERPRINTING.SESSION) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.HEADER) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.NETWORK) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.ENDPOINT) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.SESSION) else: - # assert get_tag(asm_constants.FINGERPRINTING.HEADER) is None - assert get_tag(asm_constants.FINGERPRINTING.NETWORK) is None - assert get_tag(asm_constants.FINGERPRINTING.ENDPOINT) is None - assert get_tag(asm_constants.FINGERPRINTING.SESSION) is None + # assert get_entry_span_tag(asm_constants.FINGERPRINTING.HEADER) is None + assert get_entry_span_tag(asm_constants.FINGERPRINTING.NETWORK) is None + assert get_entry_span_tag(asm_constants.FINGERPRINTING.ENDPOINT) is None + assert get_entry_span_tag(asm_constants.FINGERPRINTING.SESSION) is None # metadata success = "success" if status_code == 200 else "failure" - assert get_tag(f"appsec.events.users.login.{success}.a") == "a", root_span()._meta - assert get_tag(f"appsec.events.users.login.{success}.load_a.b") == "true", root_span()._meta - assert get_tag(f"appsec.events.users.login.{success}.load_a.load_b.c") == "3", root_span()._meta + assert get_entry_span_tag(f"appsec.events.users.login.{success}.a") == "a", entry_span()._meta + assert get_entry_span_tag(f"appsec.events.users.login.{success}.load_a.b") == "true", entry_span()._meta + assert get_entry_span_tag(f"appsec.events.users.login.{success}.load_a.load_b.c") == "3", entry_span()._meta assert ( - get_tag(f"appsec.events.users.login.{success}.load_a.load_b.load_c.load_d.e") == "1.32" - ), root_span()._meta + get_entry_span_tag(f"appsec.events.users.login.{success}.load_a.load_b.load_c.load_d.e") == "1.32" + ), entry_span()._meta assert ( - get_tag(f"appsec.events.users.login.{success}.load_a.load_b.load_c.load_d.load_e.f") is None - ), root_span()._meta + get_entry_span_tag(f"appsec.events.users.login.{success}.load_a.load_b.load_c.load_d.load_e.f") is None + ), entry_span()._meta @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("user_agent", ["dd-test-scanner-log-block", "UnitTestAgent"]) - def test_fingerprinting(self, interface, root_span, get_tag, asm_enabled, user_agent): + def test_fingerprinting(self, interface, entry_span, get_entry_span_tag, asm_enabled, user_agent): with override_global_config(dict(_asm_enabled=asm_enabled, _asm_static_rule_file=None)): self.update_tracer(interface) response = interface.client.post( @@ -1846,23 +1886,29 @@ def test_fingerprinting(self, interface, root_span, get_tag, asm_enabled, user_a ) code = 403 if asm_enabled and user_agent == "dd-test-scanner-log-block" else 200 assert self.status(response) == code - assert get_tag("http.status_code") == str(code) + assert get_entry_span_tag("http.status_code") == str(code) # check for fingerprints when security events if asm_enabled: - assert get_tag(asm_constants.FINGERPRINTING.HEADER) - assert get_tag(asm_constants.FINGERPRINTING.NETWORK) - assert get_tag(asm_constants.FINGERPRINTING.ENDPOINT) - assert get_tag(asm_constants.FINGERPRINTING.SESSION) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.HEADER) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.NETWORK) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.ENDPOINT) + assert get_entry_span_tag(asm_constants.FINGERPRINTING.SESSION) else: - assert get_tag(asm_constants.FINGERPRINTING.HEADER) is None - assert get_tag(asm_constants.FINGERPRINTING.NETWORK) is None - assert get_tag(asm_constants.FINGERPRINTING.ENDPOINT) is None - assert get_tag(asm_constants.FINGERPRINTING.SESSION) is None + assert get_entry_span_tag(asm_constants.FINGERPRINTING.HEADER) is None + assert get_entry_span_tag(asm_constants.FINGERPRINTING.NETWORK) is None + assert get_entry_span_tag(asm_constants.FINGERPRINTING.ENDPOINT) is None + assert get_entry_span_tag(asm_constants.FINGERPRINTING.SESSION) is None @pytest.mark.parametrize("exploit_prevention_enabled", [True, False]) @pytest.mark.parametrize("api_security_enabled", [True, False]) def test_trace_tagging( - self, interface, root_span, get_tag, get_metric, exploit_prevention_enabled, api_security_enabled + self, + interface, + entry_span, + get_entry_span_tag, + get_entry_span_metric, + exploit_prevention_enabled, + api_security_enabled, ): with override_global_config( dict( @@ -1876,26 +1922,27 @@ def test_trace_tagging( random_value = "oweh1jfoi4wejflk7sdgf" response = interface.client.get(f"/?test_tag=tag_this_trace_{random_value}") assert self.status(response) == 200 - assert get_tag("http.status_code") == "200" + assert get_entry_span_tag("http.status_code") == "200" # test for trace tagging with fixed value - assert get_tag("dd.appsec.custom_tag") == "tagged_trace" + assert get_entry_span_tag("dd.appsec.custom_tag") == "tagged_trace" # test for metric tagging with fixed value - assert get_metric("dd.appsec.custom_metric") == 37 + assert get_entry_span_metric("dd.appsec.custom_metric") == 37 # test for trace tagging with dynamic value - assert get_tag("dd.appsec.custom_tag_value") == f"tag_this_trace_{random_value}" + assert get_entry_span_tag("dd.appsec.custom_tag_value") == f"tag_this_trace_{random_value}" # test for sampling priority changes. Appsec should not change the sampling priority (keep=false) - span_sampling_priority = root_span()._span.context.sampling_priority - sampling_decision = root_span().get_tag(constants.SAMPLING_DECISION_TRACE_TAG_KEY) + span_sampling_priority = entry_span()._span.context.sampling_priority + sampling_decision = get_entry_span_tag(constants.SAMPLING_DECISION_TRACE_TAG_KEY) assert span_sampling_priority < 2 or sampling_decision != f"-{constants.SamplingMechanism.APPSEC}" + @pytest.mark.parametrize("rename_service", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) - def test_iast(self, interface, root_span, get_tag, metastruct): + def test_iast(self, interface, root_span, get_tag, metastruct, rename_service): from ddtrace.ext import http - with override_global_config(dict(_use_metastruct_for_iast=metastruct)): + with override_global_config(dict(_use_metastruct_for_iast=metastruct, _iast_use_root_span=True)): url = "/rasp/command_injection/?cmds=." self.update_tracer(interface) - response = interface.client.get(url) + response = interface.client.get(url, headers={"x-rename-service": str(rename_service).lower()}) assert self.status(response) == 200 assert get_tag(http.STATUS_CODE) == "200" assert self.body(response).startswith("command_injection endpoint") diff --git a/tests/appsec/iast/aspects/test_add_aspect.py b/tests/appsec/iast/aspects/test_add_aspect.py index 25907df87a9..6a5bd2213fa 100644 --- a/tests/appsec/iast/aspects/test_add_aspect.py +++ b/tests/appsec/iast/aspects/test_add_aspect.py @@ -18,7 +18,6 @@ from tests.appsec.iast.iast_utils import _end_iast_context_and_oce from tests.appsec.iast.iast_utils import _start_iast_context_and_oce from tests.appsec.iast.iast_utils import iast_hypothesis_test -from tests.utils import override_env from tests.utils import override_global_config @@ -344,9 +343,7 @@ def test_propagate_ranges_with_no_context(caplog): ) reset_context() - with override_env({"_DD_IAST_USE_ROOT_SPAN": "false"}), override_global_config( - dict(_iast_debug=True) - ), caplog.at_level(logging.DEBUG): + with override_global_config(dict(_iast_debug=True)), caplog.at_level(logging.DEBUG): result_2 = add_aspect(result, "another_string") create_context() diff --git a/tests/appsec/iast/conftest.py b/tests/appsec/iast/conftest.py index b71aa2d1fd8..79542890bea 100644 --- a/tests/appsec/iast/conftest.py +++ b/tests/appsec/iast/conftest.py @@ -111,9 +111,7 @@ def check_native_code_exception_in_each_python_aspect_test(request, caplog): if "skip_iast_check_logs" in request.keywords: yield else: - with override_env({"_DD_IAST_USE_ROOT_SPAN": "false"}), override_global_config( - dict(_iast_debug=True) - ), caplog.at_level(logging.DEBUG): + with override_global_config(dict(_iast_debug=True)), caplog.at_level(logging.DEBUG): yield for record in caplog.get_records("call"): diff --git a/tests/appsec/iast/taint_tracking/test_native_taint_range.py b/tests/appsec/iast/taint_tracking/test_native_taint_range.py index dcb136a840c..a234315d5b9 100644 --- a/tests/appsec/iast/taint_tracking/test_native_taint_range.py +++ b/tests/appsec/iast/taint_tracking/test_native_taint_range.py @@ -33,7 +33,6 @@ from ddtrace.appsec._iast._taint_tracking.aspects import format_aspect from ddtrace.appsec._iast._taint_tracking.aspects import join_aspect from tests.appsec.iast.iast_utils import IAST_VALID_LOG -from tests.utils import override_env from tests.utils import override_global_config @@ -573,9 +572,7 @@ def test_race_conditions_reset_contexts_threads(caplog, telemetry_writer): """we want to validate context is working correctly among multiple request and no race condition creating and destroying contexts """ - with override_env({"_DD_IAST_USE_ROOT_SPAN": "false"}), override_global_config( - dict(_iast_debug=True) - ), caplog.at_level(logging.DEBUG): + with override_global_config(dict(_iast_debug=True)), caplog.at_level(logging.DEBUG): pool = ThreadPool(processes=3) results_async = [pool.apply_async(reset_contexts_loop) for _ in range(20)] _ = [res.get() for res in results_async] diff --git a/tests/appsec/iast/taint_tracking/test_taint_tracking.py b/tests/appsec/iast/taint_tracking/test_taint_tracking.py index 8e0325330c9..1d77a226991 100644 --- a/tests/appsec/iast/taint_tracking/test_taint_tracking.py +++ b/tests/appsec/iast/taint_tracking/test_taint_tracking.py @@ -18,7 +18,6 @@ from ddtrace.appsec._iast.reporter import Source from tests.appsec.iast.iast_utils import iast_hypothesis_test from tests.appsec.iast.iast_utils import non_empty_text -from tests.utils import override_env from tests.utils import override_global_config @@ -70,9 +69,7 @@ def test_taint_object_with_no_context_should_be_noop(): @pytest.mark.skip_iast_check_logs def test_propagate_ranges_with_no_context(caplog): reset_context() - with override_env({"_DD_IAST_USE_ROOT_SPAN": "false"}), override_global_config( - dict(_iast_debug=True) - ), caplog.at_level(logging.DEBUG): + with override_global_config(dict(_iast_debug=True)), caplog.at_level(logging.DEBUG): string_input = taint_pyobject( pyobject="abcde", source_name="abcde", source_value="abcde", source_origin=OriginType.PARAMETER ) diff --git a/tests/appsec/integrations/fastapi_tests/test_fastapi_appsec_iast.py b/tests/appsec/integrations/fastapi_tests/test_fastapi_appsec_iast.py index ca00fa7318a..09261f1bf0e 100644 --- a/tests/appsec/integrations/fastapi_tests/test_fastapi_appsec_iast.py +++ b/tests/appsec/integrations/fastapi_tests/test_fastapi_appsec_iast.py @@ -36,7 +36,6 @@ from tests.appsec.iast.iast_utils import get_line_and_hash from tests.appsec.iast.iast_utils import load_iast_report from tests.appsec.iast.taint_sinks.test_stacktrace_leak import _load_text_stacktrace -from tests.utils import override_env from tests.utils import override_global_config @@ -96,9 +95,7 @@ def check_native_code_exception_in_each_fastapi_test(request, caplog, telemetry_ yield else: caplog.set_level(logging.DEBUG) - with override_env({"_DD_IAST_USE_ROOT_SPAN": "false"}), override_global_config( - dict(_iast_debug=True) - ), caplog.at_level(logging.DEBUG): + with override_global_config(dict(_iast_debug=True)), caplog.at_level(logging.DEBUG): yield log_messages = [record.msg for record in caplog.get_records("call")] diff --git a/tests/telemetry/test_writer.py b/tests/telemetry/test_writer.py index 83f416e29cf..b888d67a6c2 100644 --- a/tests/telemetry/test_writer.py +++ b/tests/telemetry/test_writer.py @@ -533,6 +533,7 @@ def test_app_started_event_configuration_override(test_agent_session, run_python {"name": "DD_VERSION", "origin": "default", "value": None}, {"name": "_DD_APPSEC_DEDUPLICATION_ENABLED", "origin": "default", "value": True}, {"name": "_DD_IAST_LAZY_TAINT", "origin": "default", "value": False}, + {"name": "_DD_IAST_USE_ROOT_SPAN", "origin": "default", "value": False}, {"name": "_DD_TRACE_WRITER_LOG_ERROR_PAYLOADS", "origin": "default", "value": False}, {"name": "_DD_TRACE_WRITER_NATIVE", "origin": "default", "value": False}, {"name": "instrumentation_source", "origin": "code", "value": "manual"}, diff --git a/tests/tracer/test_span.py b/tests/tracer/test_span.py index 0b18307c93d..a89b0434c02 100644 --- a/tests/tracer/test_span.py +++ b/tests/tracer/test_span.py @@ -614,6 +614,23 @@ def test_span_record_exception_with_invalid_attributes(self, span_log): span_log.warning.assert_has_calls(expected_calls, any_order=True) assert span_log.warning.call_count == 4 + def test_service_entry_span(self): + parent = self.start_span("parent", service="service1") + child1 = self.start_span("child1", service="service1", child_of=parent) + child2 = self.start_span("child2", service="service2", child_of=parent) + + assert parent._service_entry_span is parent + assert child1._service_entry_span is parent + assert child2._service_entry_span is child2 + + # Renaming the service does not change the service entry span + child1.service = "service3" + assert child1._service_entry_span is parent + + # Service entry span only works for the immediate parent + grandchild = self.start_span("grandchild", service="service1", child_of=child2) + assert grandchild._service_entry_span is grandchild + @pytest.mark.parametrize( "value,assertion",