Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions aikido_zen/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None):
self.headers: Headers = Headers()
self.cookies = dict()
self.query = dict()
self.protection_forced_off = None
self.should_skip_attack_scan = None

# Parse WSGI/ASGI/... request :
self.method = self.remote_address = self.url = None
Expand Down Expand Up @@ -139,5 +139,5 @@ def get_route_metadata(self):
def get_user_agent(self):
return self.headers.get_header("USER_AGENT")

def set_force_protection_off(self, value: bool):
self.protection_forced_off = value
def set_should_skip_attack_scan(self, value: bool):
self.should_skip_attack_scan = value
18 changes: 9 additions & 9 deletions aikido_zen/context/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_wsgi_context_1():
"outgoing_req_redirects": [],
"executed_middleware": False,
"route_params": [],
"protection_forced_off": None,
"should_skip_attack_scan": None,
}
assert context.get_user_agent() is None

Expand Down Expand Up @@ -104,7 +104,7 @@ def test_wsgi_context_2():
"outgoing_req_redirects": [],
"executed_middleware": False,
"route_params": [],
"protection_forced_off": None,
"should_skip_attack_scan": None,
}
assert context.get_user_agent() == "Mozilla/5.0"

Expand Down Expand Up @@ -288,11 +288,11 @@ def test_set_valid_json_with_special_characters_bytes():
assert context.body == {"key": "value with special characters !@#$%^&*()"}


def test_set_protection_forced_off():
def test_set_should_skip_attack_scan():
context = Context(req=basic_wsgi_req, body=None, source="flask")
context.set_force_protection_off(True)
assert context.protection_forced_off is True
context.set_force_protection_off(False)
assert context.protection_forced_off is False
context.set_force_protection_off(None)
assert context.protection_forced_off is None
context.set_should_skip_attack_scan(True)
assert context.should_skip_attack_scan is True
context.set_should_skip_attack_scan(False)
assert context.should_skip_attack_scan is False
context.set_should_skip_attack_scan(None)
assert context.should_skip_attack_scan is None
30 changes: 0 additions & 30 deletions aikido_zen/helpers/is_protection_forced_off_cached.py

This file was deleted.

38 changes: 38 additions & 0 deletions aikido_zen/helpers/should_skip_attack_scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from aikido_zen.thread.thread_cache import get_cache
from aikido_zen.helpers.protection_forced_off import protection_forced_off
from aikido_zen.helpers.logging import logger
from aikido_zen.context import Context


def should_skip_attack_scan(context: Context) -> bool:
"""
Check if protection is forced off or IP bypassed using cache stored in the context.
This function assumes that the thread cache has already been retrieved
and uses it to determine if protection is forced off for the given context.
"""
if not context:
return False

if context.should_skip_attack_scan is not None:
# Retrieving from cache, we don't want to constantly go through
# all the endpoints for every single vulnerability check.
return context.should_skip_attack_scan

thread_cache = get_cache()
if not thread_cache:
return False

should_skip = False
# We check for a boolean protectionForcedOff on the matching endpoints, allows users to disable scans on certain routes.
if protection_forced_off(
context.get_route_metadata(), thread_cache.get_endpoints()
):
should_skip = True
# We check for Bypassed IPs : Allows users to let their DAST not be blocked by Zen
if thread_cache.is_bypassed_ip(context.remote_address):
should_skip = True

context.set_should_skip_attack_scan(should_skip)
context.set_as_current_context()

return should_skip
2 changes: 1 addition & 1 deletion aikido_zen/sinks/tests/clickhouse_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, body):
self.source = "express"
self.route = "/"
self.parsed_userinput = {}
self.protection_forced_off = False
self.should_skip_attack_scan = False


@pytest.fixture(autouse=True)
Expand Down
33 changes: 14 additions & 19 deletions aikido_zen/vulnerabilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from aikido_zen.helpers.logging import logger
from aikido_zen.helpers.get_clean_stacktrace import get_clean_stacktrace
from aikido_zen.helpers.blocking_enabled import is_blocking_enabled
from aikido_zen.helpers.is_protection_forced_off_cached import (
is_protection_forced_off_cached,
from aikido_zen.helpers.should_skip_attack_scan import (
should_skip_attack_scan,
)
from aikido_zen.thread.thread_cache import get_cache
from .sql_injection.context_contains_sql_injection import context_contains_sql_injection
Expand All @@ -37,25 +37,16 @@ def run_vulnerability_scan(kind, op, args):
raises error if blocking is enabled, communicates it with connection_manager
"""
context = get_current_context()

if is_protection_forced_off_cached(context):
return

comms = comm.get_comms()
thread_cache = get_cache()
if not context and kind != "ssrf":
# Make a special exception for SSRF, which checks itself if context is set.
# This is because some scans/tests for SSRF do not require a context to be set.
# Make a special exception for SSRF:
# For stored ssrf we don't need a context
return

if not thread_cache and kind != "ssrf":
# Make a special exception for SSRF, which checks itself if thread cache is set.
# This is because some scans/tests for SSRF do not require a thread cache to be set.
if should_skip_attack_scan(context) and kind != "ssrf":
# Make a special exception for SSRF:
# For stored ssrf we don't want to check bypassed IPs or protection forced off.
return
if thread_cache and context:
if thread_cache.is_bypassed_ip(context.remote_address):
# This IP is on the bypass list, not scanning
return

comms = comm.get_comms()

error_type = AikidoException # Default error
error_args = tuple()
Expand Down Expand Up @@ -87,6 +78,7 @@ def run_vulnerability_scan(kind, op, args):
injection_results = inspect_getaddrinfo_result(dns_results, hostname, port)
error_type = AikidoSSRF

thread_cache = get_cache()
if thread_cache and port > 0:
thread_cache.hostnames.add(hostname, port)
else:
Expand All @@ -101,7 +93,10 @@ def run_vulnerability_scan(kind, op, args):

blocked = is_blocking_enabled()
operation = injection_results["operation"]
thread_cache.stats.on_detected_attack(blocked, operation)

thread_cache = get_cache()
if thread_cache:
thread_cache.stats.on_detected_attack(blocked, operation)

stack = get_clean_stacktrace()

Expand Down
15 changes: 10 additions & 5 deletions aikido_zen/vulnerabilities/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,13 @@ def test_ssrf_vulnerability_scan_bypassed_ip(get_context):
run_vulnerability_scan(kind="ssrf", op="test", args=(dns_results, hostname, port))
assert get_cache().stats.get_record()["requests"]["attacksDetected"]["total"] == 0

# Verify that hostnames.add was not called due to bypassed IP
assert get_cache().hostnames.as_array() == []
assert get_cache().hostnames.as_array() == [
{
"hits": 1,
"hostname": "example.com",
"port": 80,
},
]


def test_ssrf_vulnerability_scan_protection_gets_forced_off(get_context):
Expand All @@ -209,9 +214,9 @@ def test_ssrf_vulnerability_scan_protection_gets_forced_off(get_context):
dns_results = MagicMock()
hostname = "example.com"
port = 80
assert get_context.protection_forced_off is None
assert get_context.should_skip_attack_scan is None
run_vulnerability_scan(kind="ssrf", op="test", args=(dns_results, hostname, port))
assert get_context.protection_forced_off is False
assert get_context.should_skip_attack_scan is True # Bypassed IP


def test_sql_injection_with_protection_forced_off(caplog, get_context, monkeypatch):
Expand All @@ -227,7 +232,7 @@ def test_sql_injection_with_protection_forced_off(caplog, get_context, monkeypat
op="test_op",
args=("INSERT * INTO VALUES ('doggoss2', TRUE);", "mysql"),
)
get_context.set_force_protection_off(True)
get_context.set_should_skip_attack_scan(True)
run_vulnerability_scan(
kind="sql_injection",
op="test_op",
Expand Down
7 changes: 2 additions & 5 deletions aikido_zen/vulnerabilities/ssrf/inspect_getaddrinfo_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .find_hostname_in_context import find_hostname_in_context
from .extract_ip_array_from_results import extract_ip_array_from_results
from .is_redirect_to_private_ip import is_redirect_to_private_ip
from aikido_zen.helpers.should_skip_attack_scan import should_skip_attack_scan


# gets called when the result of the DNS resolution has come in
Expand All @@ -27,11 +28,7 @@ def inspect_getaddrinfo_result(dns_results, hostname, port):
return

context = get_current_context()
if not context:
return # Context should be set to check user input.
if get_cache() and get_cache().is_bypassed_ip(context.remote_address):
# We check for bypassed ip's here since it is not checked for us
# in run_vulnerability_scan due to the exception for SSRF (see above code)
if not context or should_skip_attack_scan(context):
return

# attack_findings is an object containing source, pathToPayload and payload.
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/wrk_benchmark/flask_mysql_uwsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
"http://localhost:8088/benchmark_io",
"http://localhost:8089/benchmark_io",
"a route that makes multiple I/O calls",
percentage_limit=35
percentage_limit=25
)
Loading