diff --git a/aikido_zen/context/__init__.py b/aikido_zen/context/__init__.py index ed26eba06..6fc6d4edd 100644 --- a/aikido_zen/context/__init__.py +++ b/aikido_zen/context/__init__.py @@ -54,7 +54,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None): self.headers: Headers = Headers() self.cookies = dict() self.query = dict() - self.protection_forced_off = None + self.should_skip_attack_scan = None # Parse WSGI/ASGI/... request : self.method = self.remote_address = self.url = None @@ -139,5 +139,5 @@ def get_route_metadata(self): def get_user_agent(self): return self.headers.get_header("USER_AGENT") - def set_force_protection_off(self, value: bool): - self.protection_forced_off = value + def set_should_skip_attack_scan(self, value: bool): + self.should_skip_attack_scan = value diff --git a/aikido_zen/context/init_test.py b/aikido_zen/context/init_test.py index dff7e3632..e871eb3dc 100644 --- a/aikido_zen/context/init_test.py +++ b/aikido_zen/context/init_test.py @@ -72,7 +72,7 @@ def test_wsgi_context_1(): "outgoing_req_redirects": [], "executed_middleware": False, "route_params": [], - "protection_forced_off": None, + "should_skip_attack_scan": None, } assert context.get_user_agent() is None @@ -104,7 +104,7 @@ def test_wsgi_context_2(): "outgoing_req_redirects": [], "executed_middleware": False, "route_params": [], - "protection_forced_off": None, + "should_skip_attack_scan": None, } assert context.get_user_agent() == "Mozilla/5.0" @@ -288,11 +288,11 @@ def test_set_valid_json_with_special_characters_bytes(): assert context.body == {"key": "value with special characters !@#$%^&*()"} -def test_set_protection_forced_off(): +def test_set_should_skip_attack_scan(): context = Context(req=basic_wsgi_req, body=None, source="flask") - context.set_force_protection_off(True) - assert context.protection_forced_off is True - context.set_force_protection_off(False) - assert context.protection_forced_off is False - context.set_force_protection_off(None) - assert context.protection_forced_off is None + context.set_should_skip_attack_scan(True) + assert context.should_skip_attack_scan is True + context.set_should_skip_attack_scan(False) + assert context.should_skip_attack_scan is False + context.set_should_skip_attack_scan(None) + assert context.should_skip_attack_scan is None diff --git a/aikido_zen/helpers/is_protection_forced_off_cached.py b/aikido_zen/helpers/is_protection_forced_off_cached.py deleted file mode 100644 index a6b908b4c..000000000 --- a/aikido_zen/helpers/is_protection_forced_off_cached.py +++ /dev/null @@ -1,30 +0,0 @@ -from aikido_zen.thread.thread_cache import get_cache -from aikido_zen.helpers.protection_forced_off import protection_forced_off -from aikido_zen.context import Context - - -def is_protection_forced_off_cached(context: Context) -> bool: - """ - Check if protection is forced off using cached endpoints. - This function assumes that the thread cache has already been retrieved - and uses it to determine if protection is forced off for the given context. - """ - if not context: - return False - - if context.protection_forced_off is not None: - # Retrieving from cache, we don't want to constantly go through - # all the endpoints for every single vulnerability check. - return context.protection_forced_off - - thread_cache = get_cache() - if not thread_cache: - return False - - is_forced_off = protection_forced_off( - context.get_route_metadata(), thread_cache.get_endpoints() - ) - context.set_force_protection_off(is_forced_off) - context.set_as_current_context() - - return is_forced_off diff --git a/aikido_zen/helpers/should_skip_attack_scan.py b/aikido_zen/helpers/should_skip_attack_scan.py new file mode 100644 index 000000000..207aeea2f --- /dev/null +++ b/aikido_zen/helpers/should_skip_attack_scan.py @@ -0,0 +1,38 @@ +from aikido_zen.thread.thread_cache import get_cache +from aikido_zen.helpers.protection_forced_off import protection_forced_off +from aikido_zen.helpers.logging import logger +from aikido_zen.context import Context + + +def should_skip_attack_scan(context: Context) -> bool: + """ + Check if protection is forced off or IP bypassed using cache stored in the context. + This function assumes that the thread cache has already been retrieved + and uses it to determine if protection is forced off for the given context. + """ + if not context: + return False + + if context.should_skip_attack_scan is not None: + # Retrieving from cache, we don't want to constantly go through + # all the endpoints for every single vulnerability check. + return context.should_skip_attack_scan + + thread_cache = get_cache() + if not thread_cache: + return False + + should_skip = False + # We check for a boolean protectionForcedOff on the matching endpoints, allows users to disable scans on certain routes. + if protection_forced_off( + context.get_route_metadata(), thread_cache.get_endpoints() + ): + should_skip = True + # We check for Bypassed IPs : Allows users to let their DAST not be blocked by Zen + if thread_cache.is_bypassed_ip(context.remote_address): + should_skip = True + + context.set_should_skip_attack_scan(should_skip) + context.set_as_current_context() + + return should_skip diff --git a/aikido_zen/sinks/tests/clickhouse_driver_test.py b/aikido_zen/sinks/tests/clickhouse_driver_test.py index bedeabeeb..cf3255954 100644 --- a/aikido_zen/sinks/tests/clickhouse_driver_test.py +++ b/aikido_zen/sinks/tests/clickhouse_driver_test.py @@ -17,7 +17,7 @@ def __init__(self, body): self.source = "express" self.route = "/" self.parsed_userinput = {} - self.protection_forced_off = False + self.should_skip_attack_scan = False @pytest.fixture(autouse=True) diff --git a/aikido_zen/vulnerabilities/__init__.py b/aikido_zen/vulnerabilities/__init__.py index bf37a49f7..17dc8ce53 100644 --- a/aikido_zen/vulnerabilities/__init__.py +++ b/aikido_zen/vulnerabilities/__init__.py @@ -16,8 +16,8 @@ from aikido_zen.helpers.logging import logger from aikido_zen.helpers.get_clean_stacktrace import get_clean_stacktrace from aikido_zen.helpers.blocking_enabled import is_blocking_enabled -from aikido_zen.helpers.is_protection_forced_off_cached import ( - is_protection_forced_off_cached, +from aikido_zen.helpers.should_skip_attack_scan import ( + should_skip_attack_scan, ) from aikido_zen.thread.thread_cache import get_cache from .sql_injection.context_contains_sql_injection import context_contains_sql_injection @@ -37,25 +37,16 @@ def run_vulnerability_scan(kind, op, args): raises error if blocking is enabled, communicates it with connection_manager """ context = get_current_context() - - if is_protection_forced_off_cached(context): - return - - comms = comm.get_comms() - thread_cache = get_cache() if not context and kind != "ssrf": - # Make a special exception for SSRF, which checks itself if context is set. - # This is because some scans/tests for SSRF do not require a context to be set. + # Make a special exception for SSRF: + # For stored ssrf we don't need a context return - - if not thread_cache and kind != "ssrf": - # Make a special exception for SSRF, which checks itself if thread cache is set. - # This is because some scans/tests for SSRF do not require a thread cache to be set. + if should_skip_attack_scan(context) and kind != "ssrf": + # Make a special exception for SSRF: + # For stored ssrf we don't want to check bypassed IPs or protection forced off. return - if thread_cache and context: - if thread_cache.is_bypassed_ip(context.remote_address): - # This IP is on the bypass list, not scanning - return + + comms = comm.get_comms() error_type = AikidoException # Default error error_args = tuple() @@ -87,6 +78,7 @@ def run_vulnerability_scan(kind, op, args): injection_results = inspect_getaddrinfo_result(dns_results, hostname, port) error_type = AikidoSSRF + thread_cache = get_cache() if thread_cache and port > 0: thread_cache.hostnames.add(hostname, port) else: @@ -101,7 +93,10 @@ def run_vulnerability_scan(kind, op, args): blocked = is_blocking_enabled() operation = injection_results["operation"] - thread_cache.stats.on_detected_attack(blocked, operation) + + thread_cache = get_cache() + if thread_cache: + thread_cache.stats.on_detected_attack(blocked, operation) stack = get_clean_stacktrace() diff --git a/aikido_zen/vulnerabilities/init_test.py b/aikido_zen/vulnerabilities/init_test.py index da7b13e1c..06d1ec8ac 100644 --- a/aikido_zen/vulnerabilities/init_test.py +++ b/aikido_zen/vulnerabilities/init_test.py @@ -198,8 +198,13 @@ def test_ssrf_vulnerability_scan_bypassed_ip(get_context): run_vulnerability_scan(kind="ssrf", op="test", args=(dns_results, hostname, port)) assert get_cache().stats.get_record()["requests"]["attacksDetected"]["total"] == 0 - # Verify that hostnames.add was not called due to bypassed IP - assert get_cache().hostnames.as_array() == [] + assert get_cache().hostnames.as_array() == [ + { + "hits": 1, + "hostname": "example.com", + "port": 80, + }, + ] def test_ssrf_vulnerability_scan_protection_gets_forced_off(get_context): @@ -209,9 +214,9 @@ def test_ssrf_vulnerability_scan_protection_gets_forced_off(get_context): dns_results = MagicMock() hostname = "example.com" port = 80 - assert get_context.protection_forced_off is None + assert get_context.should_skip_attack_scan is None run_vulnerability_scan(kind="ssrf", op="test", args=(dns_results, hostname, port)) - assert get_context.protection_forced_off is False + assert get_context.should_skip_attack_scan is True # Bypassed IP def test_sql_injection_with_protection_forced_off(caplog, get_context, monkeypatch): @@ -227,7 +232,7 @@ def test_sql_injection_with_protection_forced_off(caplog, get_context, monkeypat op="test_op", args=("INSERT * INTO VALUES ('doggoss2', TRUE);", "mysql"), ) - get_context.set_force_protection_off(True) + get_context.set_should_skip_attack_scan(True) run_vulnerability_scan( kind="sql_injection", op="test_op", diff --git a/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py b/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py index a263a96a1..fd941024f 100644 --- a/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py +++ b/aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py @@ -13,6 +13,7 @@ from .find_hostname_in_context import find_hostname_in_context from .extract_ip_array_from_results import extract_ip_array_from_results from .is_redirect_to_private_ip import is_redirect_to_private_ip +from aikido_zen.helpers.should_skip_attack_scan import should_skip_attack_scan # gets called when the result of the DNS resolution has come in @@ -27,11 +28,7 @@ def inspect_getaddrinfo_result(dns_results, hostname, port): return context = get_current_context() - if not context: - return # Context should be set to check user input. - if get_cache() and get_cache().is_bypassed_ip(context.remote_address): - # We check for bypassed ip's here since it is not checked for us - # in run_vulnerability_scan due to the exception for SSRF (see above code) + if not context or should_skip_attack_scan(context): return # attack_findings is an object containing source, pathToPayload and payload. diff --git a/benchmarks/wrk_benchmark/flask_mysql_uwsgi.py b/benchmarks/wrk_benchmark/flask_mysql_uwsgi.py index e728b7596..87958552c 100644 --- a/benchmarks/wrk_benchmark/flask_mysql_uwsgi.py +++ b/benchmarks/wrk_benchmark/flask_mysql_uwsgi.py @@ -14,5 +14,5 @@ "http://localhost:8088/benchmark_io", "http://localhost:8089/benchmark_io", "a route that makes multiple I/O calls", - percentage_limit=35 + percentage_limit=25 )