Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion pytest_socket.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ipaddress
import socket

import pytest
Expand Down Expand Up @@ -179,7 +180,7 @@ def socket_allow_hosts(allowed=None, allow_unix_socket=False):

def guarded_connect(inst, *args):
host = host_from_connect_args(args)
if host in allowed or (_is_unix_socket(inst.family) and allow_unix_socket):
if is_valid_host(host, allowed) or (_is_unix_socket(inst.family) and allow_unix_socket):
return _true_connect(inst, *args)

raise SocketConnectBlockedError(allowed, host)
Expand All @@ -191,3 +192,17 @@ def _remove_restrictions():
"""restore socket.socket.* to allow access to the Internet. useful in testing."""
socket.socket = _true_socket
socket.socket.connect = _true_connect


def is_valid_host(host, allowed):
if not host:
return
ips = [ip for ip in allowed if "/" not in ip]
if host in ips:
return True
networks = [ipaddress.ip_network(mask) for mask in allowed if "/" in mask]
ip = ipaddress.ip_address(host)
for net in networks:
if ip in net:
return True
return False
31 changes: 31 additions & 0 deletions tests/test_restrict_hosts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
from urllib.parse import urlparse

import pytest

Expand Down Expand Up @@ -256,3 +257,33 @@ def test_fail_2():
result.assert_outcomes(1, 0, 2)
assert_host_blocked(result, "2.2.2.2")
assert_host_blocked(result, httpbin.host)


def test_cidr_allow(testdir, httpbin):
test_url = urlparse(httpbin.url)
testdir.makepyfile(
"""
import pytest
import socket
@pytest.mark.allow_hosts('127.0.0.0/8')
def test_pass():
socket.socket().connect(('{0}', {1}))
@pytest.mark.allow_hosts('127.0.0.0/16')
def test_pass_2():
socket.socket().connect(('{0}', {1}))
def test_fail():
socket.socket().connect(('2.2.2.2', {1}))
def test_fail_2():
socket.socket().connect(('192.168.1.10', {1}))
@pytest.mark.allow_hosts('172.20.0.0/16')
def test_fail_3():
socket.socket().connect(('{0}', {1}))
""".format(
test_url.hostname, test_url.port
)
)
result = testdir.runpytest("--verbose", "--allow-hosts=1.2.3.4")
result.assert_outcomes(2, 0, 3)
assert_host_blocked(result, "2.2.2.2")
assert_host_blocked(result, "192.168.1.10")
assert_host_blocked(result, test_url.hostname)