diff --git a/pytest_socket.py b/pytest_socket.py index 1f7160b..270e02b 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -1,3 +1,4 @@ +import ipaddress import socket import pytest @@ -179,7 +180,7 @@ def socket_allow_hosts(allowed=None, allow_unix_socket=False): def guarded_connect(inst, *args): host = host_from_connect_args(args) - if host in allowed or (_is_unix_socket(inst.family) and allow_unix_socket): + if is_valid_host(host, allowed) or (_is_unix_socket(inst.family) and allow_unix_socket): return _true_connect(inst, *args) raise SocketConnectBlockedError(allowed, host) @@ -191,3 +192,17 @@ def _remove_restrictions(): """restore socket.socket.* to allow access to the Internet. useful in testing.""" socket.socket = _true_socket socket.socket.connect = _true_connect + + +def is_valid_host(host, allowed): + if not host: + return + ips = [ip for ip in allowed if "/" not in ip] + if host in ips: + return True + networks = [ipaddress.ip_network(mask) for mask in allowed if "/" in mask] + ip = ipaddress.ip_address(host) + for net in networks: + if ip in net: + return True + return False diff --git a/tests/test_restrict_hosts.py b/tests/test_restrict_hosts.py index e1b721b..5efb258 100644 --- a/tests/test_restrict_hosts.py +++ b/tests/test_restrict_hosts.py @@ -1,4 +1,5 @@ import inspect +from urllib.parse import urlparse import pytest @@ -256,3 +257,33 @@ def test_fail_2(): result.assert_outcomes(1, 0, 2) assert_host_blocked(result, "2.2.2.2") assert_host_blocked(result, httpbin.host) + + +def test_cidr_allow(testdir, httpbin): + test_url = urlparse(httpbin.url) + testdir.makepyfile( + """ + import pytest + import socket + @pytest.mark.allow_hosts('127.0.0.0/8') + def test_pass(): + socket.socket().connect(('{0}', {1})) + @pytest.mark.allow_hosts('127.0.0.0/16') + def test_pass_2(): + socket.socket().connect(('{0}', {1})) + def test_fail(): + socket.socket().connect(('2.2.2.2', {1})) + def test_fail_2(): + socket.socket().connect(('192.168.1.10', {1})) + @pytest.mark.allow_hosts('172.20.0.0/16') + def test_fail_3(): + socket.socket().connect(('{0}', {1})) + """.format( + test_url.hostname, test_url.port + ) + ) + result = testdir.runpytest("--verbose", "--allow-hosts=1.2.3.4") + result.assert_outcomes(2, 0, 3) + assert_host_blocked(result, "2.2.2.2") + assert_host_blocked(result, "192.168.1.10") + assert_host_blocked(result, test_url.hostname)