diff --git a/docker-compose.yml b/docker-compose.yml index 8646582a..57d53148 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,13 +1,26 @@ services: sim: image: martenseemann/quic-network-simulator - container_name: sim hostname: sim stdin_open: true tty: true environment: - WAITFORSERVER=$WAITFORSERVER - SCENARIO=$SCENARIO + - CLIENT_V4_ADDR=${CLIENT_V4_ADDR} + - CLIENT_V4_GATEWAY=${CLIENT_V4_NET}.2 + - CLIENT_V6_ADDR=${CLIENT_V6_ADDR} + - CLIENT_V6_GATEWAY=${CLIENT_V6_NET}::2 + - SERVER_V4_ADDR=${SERVER_V4_ADDR} + - SERVER_V4_GATEWAY=${SERVER_V4_NET}.2 + - SERVER_V6_ADDR=${SERVER_V6_ADDR} + - SERVER_V6_GATEWAY=${SERVER_V6_NET}::2 + - SUBNET_V4=${SUBNET_V4} + - SUBNET_V6=${SUBNET_V6} + - V4_PREFIX=${V4_PREFIX} + - V6_PREFIX=${V6_PREFIX} + - LEFTNET_NAME=eth0 + - RIGHTNET_NAME=eth1 cap_add: - NET_ADMIN - NET_RAW @@ -15,19 +28,18 @@ services: - "57832" networks: leftnet: - ipv4_address: 193.167.0.2 - ipv6_address: fd00:cafe:cafe:0::2 + ipv4_address: ${CLIENT_V4_NET}.2 + ipv6_address: ${CLIENT_V6_NET}::2 interface_name: eth0 rightnet: - ipv4_address: 193.167.100.2 - ipv6_address: fd00:cafe:cafe:100::2 + ipv4_address: ${SERVER_V4_NET}.2 + ipv6_address: ${SERVER_V6_NET}::2 interface_name: eth1 extra_hosts: - - "server:193.167.100.100" + - "server:${SERVER_V4_ADDR}" server: image: $SERVER - container_name: server hostname: server stdin_open: true tty: true @@ -41,6 +53,11 @@ services: - SSLKEYLOGFILE=/logs/keys.log - QLOGDIR=/logs/qlog/ - TESTCASE=$TESTCASE_SERVER + - SUBNET_V4_PREFIX=${SUBNET_V4_PREFIX} + - SUBNET_V4_SUBNET=${SUBNET_V4_SUBNET} + - SUBNET_V4=${SUBNET_V4} + - SUBNET_V6_PREFIX=${SUBNET_V6_PREFIX} + - SUBNET_V6=${SUBNET_V6} depends_on: - sim cap_add: @@ -49,16 +66,15 @@ services: memlock: 67108864 networks: rightnet: - ipv4_address: 193.167.100.100 - ipv6_address: fd00:cafe:cafe:100::100 + ipv4_address: ${SERVER_V4_ADDR} + ipv6_address: ${SERVER_V6_ADDR} interface_name: eth0 extra_hosts: - - "server4:193.167.100.100" - - "server6:fd00:cafe:cafe:100::100" + - "server4:${SERVER_V4_ADDR}" + - "server6:${SERVER_V6_ADDR}" client: image: $CLIENT - container_name: client hostname: client stdin_open: true tty: true @@ -73,6 +89,11 @@ services: - QLOGDIR=/logs/qlog/ - TESTCASE=$TESTCASE_CLIENT - REQUESTS=$REQUESTS + - SUBNET_V4_PREFIX=${SUBNET_V4_PREFIX} + - SUBNET_V4_SUBNET=${SUBNET_V4_SUBNET} + - SUBNET_V4=${SUBNET_V4} + - SUBNET_V6_PREFIX=${SUBNET_V6_PREFIX} + - SUBNET_V6=${SUBNET_V6} depends_on: - sim cap_add: @@ -81,18 +102,17 @@ services: memlock: 67108864 networks: leftnet: - ipv4_address: 193.167.0.100 - ipv6_address: fd00:cafe:cafe:0::100 + ipv4_address: ${CLIENT_V4_ADDR} + ipv6_address: ${CLIENT_V6_ADDR} interface_name: eth0 extra_hosts: - - "server4:193.167.100.100" - - "server6:fd00:cafe:cafe:100::100" - - "server46:193.167.100.100" - - "server46:fd00:cafe:cafe:100::100" + - "server4:${SERVER_V4_ADDR}" + - "server6:${SERVER_V6_ADDR}" + - "server46:${SERVER_V4_ADDR}" + - "server46:${SERVER_V6_ADDR}" iperf_server: image: martenseemann/quic-interop-iperf-endpoint - container_name: iperf_server stdin_open: true tty: true environment: @@ -105,17 +125,16 @@ services: - NET_ADMIN networks: rightnet: - ipv4_address: 193.167.100.110 - ipv6_address: fd00:cafe:cafe:100::110 + ipv4_address: ${SERVER_V4_NET}.110 + ipv6_address: ${SERVER_V6_NET}::110 extra_hosts: - - "client4:193.167.0.90" - - "client6:fd00:cafe:cafe:0::100" - - "client46:193.167.0.90" - - "client46:fd00:cafe:cafe:0::100" + - "client4:${CLIENT_V4_NET}.90" + - "client6:${CLIENT_V6_NET}::100" + - "client46:${CLIENT_V4_NET}.90" + - "client46:${CLIENT_V6_NET}::100" iperf_client: image: martenseemann/quic-interop-iperf-endpoint - container_name: iperf_client stdin_open: true tty: true environment: @@ -127,13 +146,13 @@ services: - NET_ADMIN networks: leftnet: - ipv4_address: 193.167.0.90 - ipv6_address: fd00:cafe:cafe:0::90 + ipv4_address: ${CLIENT_V4_NET}.90 + ipv6_address: ${CLIENT_V6_NET}::90 extra_hosts: - - "server4:193.167.100.110" - - "server6:fd00:cafe:cafe:100::110" - - "server46:193.167.100.110" - - "server46:fd00:cafe:cafe:100::110" + - "server4:${SERVER_V4_NET}.110" + - "server6:${SERVER_V6_NET}::110" + - "server46:${SERVER_V4_NET}.110" + - "server46:${SERVER_V6_NET}::110" networks: leftnet: @@ -143,8 +162,8 @@ networks: enable_ipv6: true ipam: config: - - subnet: 193.167.0.0/24 - - subnet: fd00:cafe:cafe:0::/64 + - subnet: ${CLIENT_V4_NET}.0/${V4_PREFIX} + - subnet: ${CLIENT_V6_NET}::/${V6_PREFIX} rightnet: driver: bridge driver_opts: @@ -152,6 +171,5 @@ networks: enable_ipv6: true ipam: config: - - subnet: 193.167.100.0/24 - - subnet: fd00:cafe:cafe:100::/64 - + - subnet: ${SERVER_V4_NET}.0/${V4_PREFIX} + - subnet: ${SERVER_V6_NET}::/${V6_PREFIX} diff --git a/interop.py b/interop.py index f39ae14f..7ec9c4d4 100644 --- a/interop.py +++ b/interop.py @@ -1,5 +1,7 @@ +import io import json import logging +import multiprocessing import os from random_slugs import generate_slug import re @@ -8,8 +10,11 @@ import subprocess import sys import tempfile +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime -from typing import Callable, List, Tuple +from typing import Callable, List, Tuple, Optional import prettytable from termcolor import colored @@ -20,8 +25,8 @@ class MeasurementResult: - result = TestResult - details = str + result: TestResult + details: str class LogFileFormatter(logging.Formatter): @@ -31,20 +36,41 @@ def format(self, record): return re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]").sub("", msg) +class LogRecordCapturingHandler(logging.Handler): + """Handler that captures log records with their levels for later replay.""" + + def __init__(self): + super().__init__() + self.records = [] + + def emit(self, record): + # Store just the level and the formatted message + self.records.append( + { + "level": record.levelno, + "msg": self.format(record), + } + ) + + class InteropRunner: - _start_time = 0 - test_results = {} + _start_time: datetime = datetime.now() + test_results = {} # dict[str, dict[str, dict[testcases.TestCase, TestResult]]] measurement_results = {} - compliant = {} - _implementations = {} - _client_server_pairs = [] - _tests = [] - _measurements = [] + compliant: dict[str, bool] = {} + _implementations: dict[str, dict[str, str]] = {} + _client_server_pairs: list[tuple[str, str]] = [] + _tests: list[testcases.TestCase] = [] + _measurements: list[testcases.Measurement] = [] _output = "" _markdown = False _log_dir = "" _save_files = False - _no_auto_unsupported = [] + _no_auto_unsupported: list[str] = [] + # Shared class variables for subnet allocation across all instances + _subnet_allocator_lock = threading.Lock() + _allocated_subnets: set[int] = set() + _next_subnet_index = 0 def __init__( self, @@ -57,7 +83,8 @@ def __init__( debug: bool, save_files=False, log_dir="", - no_auto_unsupported=[], + parallel=None, + no_auto_unsupported=None, ): logger = logging.getLogger() logger.setLevel(logging.DEBUG) @@ -67,7 +94,6 @@ def __init__( else: console.setLevel(logging.INFO) logger.addHandler(console) - self._start_time = datetime.now() self._tests = tests self._measurements = measurements self._client_server_pairs = client_server_pairs @@ -76,7 +102,22 @@ def __init__( self._markdown = markdown self._log_dir = log_dir self._save_files = save_files - self._no_auto_unsupported = no_auto_unsupported + if no_auto_unsupported is None: + self._no_auto_unsupported = [] + else: + self._no_auto_unsupported = no_auto_unsupported + + total_cores = multiprocessing.cpu_count() + if parallel is None or parallel <= 0: + self._parallel = total_cores + else: + self._parallel = parallel + logging.info( + "Running with %d parallel tests (system has %d cores)", + self._parallel, + total_cores, + ) + if len(self._log_dir) == 0: self._log_dir = "logs_{:%Y-%m-%dT%H:%M:%S}".format(self._start_time) if os.path.exists(self._log_dir): @@ -97,6 +138,29 @@ def _is_unsupported(self, lines: List[str]) -> bool: "exit status 127" in str(line) for line in lines ) + def _docker_compose( + self, + action: str, + project_name: str, + env: Optional[dict[str, str]] = None, + containers: str = "", + timeout: Optional[int] = None, + check: bool = True, + ) -> subprocess.CompletedProcess: + cmd = ( + (" ".join(f"{k}={v} " for k, v in env.items()) if env else "") + + f"docker compose --project-name {project_name} --env-file empty.env {action} {containers}" + ) + return subprocess.run( + cmd, + # env=env, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=timeout, + check=check, + ) + def _check_impl_is_compliant(self, name: str) -> bool: """check if an implementation return UNSUPPORTED for unknown test cases""" if name in self.compliant: @@ -105,69 +169,53 @@ def _check_impl_is_compliant(self, name: str) -> bool: ) return self.compliant[name] + (subnet_index, subnet_params) = self._allocate_subnet() + project_name = f"compliance_{name}_{subnet_index}" client_log_dir = tempfile.TemporaryDirectory(dir="/tmp", prefix="logs_client_") + server_log_dir = tempfile.TemporaryDirectory(dir="/tmp", prefix="logs_server_") www_dir = tempfile.TemporaryDirectory(dir="/tmp", prefix="compliance_www_") certs_dir = tempfile.TemporaryDirectory(dir="/tmp", prefix="compliance_certs_") downloads_dir = tempfile.TemporaryDirectory( dir="/tmp", prefix="compliance_downloads_" ) + params = { + "CERTS": certs_dir.name, + "TESTCASE_CLIENT": generate_slug(), + "TESTCASE_SERVER": generate_slug(), + "CLIENT_LOGS": client_log_dir.name, + "SERVER_LOGS": server_log_dir.name, + "WWW": www_dir.name, + "DOWNLOADS": downloads_dir.name, + "SCENARIO": '"simple-p2p --delay=15ms --bandwidth=10Mbps --queue=25"', + "CLIENT": self._implementations[name]["image"], + "SERVER": self._implementations[name]["image"], + } testcases.generate_cert_chain(certs_dir.name) - # check that the client is capable of returning UNSUPPORTED - logging.debug("Checking compliance of %s client", name) - cmd = ( - "CERTS=" + certs_dir.name + " " - "TESTCASE_CLIENT=" + generate_slug() + " " - "SERVER_LOGS=/dev/null " - "CLIENT_LOGS=" + client_log_dir.name + " " - "WWW=" + www_dir.name + " " - "DOWNLOADS=" + downloads_dir.name + " " - 'SCENARIO="simple-p2p --delay=15ms --bandwidth=10Mbps --queue=25" ' - "CLIENT=" + self._implementations[name]["image"] + " " - "SERVER=" - + self._implementations[name]["image"] - + " " # only needed so docker compose doesn't complain - "docker compose --env-file empty.env up --timeout 0 --abort-on-container-exit -V sim client" - ) - output = subprocess.run( - cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT - ) - if not self._is_unsupported(output.stdout.splitlines()): - logging.error("%s client not compliant.", name) - logging.debug("%s", output.stdout.decode("utf-8", errors="replace")) - self.compliant[name] = False - return False - logging.debug("%s client compliant.", name) - - # check that the server is capable of returning UNSUPPORTED - logging.debug("Checking compliance of %s server", name) - server_log_dir = tempfile.TemporaryDirectory(dir="/tmp", prefix="logs_server_") - cmd = ( - "CERTS=" + certs_dir.name + " " - "TESTCASE_SERVER=" + generate_slug() + " " - "SERVER_LOGS=" + server_log_dir.name + " " - "CLIENT_LOGS=/dev/null " - "WWW=" + www_dir.name + " " - "DOWNLOADS=" + downloads_dir.name + " " - "CLIENT=" - + self._implementations[name]["image"] - + " " # only needed so docker compose doesn't complain - "SERVER=" + self._implementations[name]["image"] + " " - "docker compose --env-file empty.env up -V server" - ) - output = subprocess.run( - cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT - ) - if not self._is_unsupported(output.stdout.splitlines()): - logging.error("%s server not compliant.", name) - logging.debug("%s", output.stdout.decode("utf-8", errors="replace")) - self.compliant[name] = False - return False - logging.debug("%s server compliant.", name) - - # remember compliance test outcome + # check that client and server are capable of returning UNSUPPORTED self.compliant[name] = True + for role, containers, opt in [ + ("client", "sim client", "--timeout 0 --abort-on-container-exit"), + ("server", "server", ""), + ]: + logging.debug("Checking compliance of %s %s", name, role) + output = self._docker_compose( + f"up {opt} --renew-anon-volumes", + project_name, + subnet_params | params, + containers, + check=False, + ) + if not self._is_unsupported(output.stdout.splitlines()): + logging.error("%s %s not compliant.", name, role) + logging.debug(output.stdout.decode("utf-8", errors="replace")) + self.compliant[name] = False + break + logging.debug("%s %s compliant.", name, role) + + self._docker_compose("down", project_name) + self._release_subnet(subnet_index) return True def _postprocess_results(self): @@ -179,7 +227,7 @@ def _postprocess_results(self): for c in set(clients) - set(self._no_auto_unsupported): for t in self._tests: if all(self.test_results[s][c][t] in questionable for s in servers): - print( + logging.info( f"Client {c} failed or did not support test {t.name()} " + 'against all servers, marking the entire test as "unsupported"' ) @@ -190,7 +238,7 @@ def _postprocess_results(self): for s in set(servers) - set(self._no_auto_unsupported): for t in self._tests: if all(self.test_results[s][c][t] in questionable for c in clients): - print( + logging.info( f"Server {s} failed or did not support test {t.name()} " + 'against all clients, marking the entire test as "unsupported"' ) @@ -332,40 +380,42 @@ def _export_results(self): ) out["measurements"].append(measurements) - f = open(self._output, "w") + f = open(self._output, "w", encoding="utf-8") json.dump(out, f) f.close() - def _copy_logs(self, container: str, dir: tempfile.TemporaryDirectory): + def _copy_logs( + self, container: str, log_dir: tempfile.TemporaryDirectory, project_name: str + ): + # Match container names based on project name + # e.g., for project "interop_test" and container "sim", matches "interop_test-sim-1" cmd = ( - "docker cp \"$(docker ps -a --format '{{.ID}} {{.Names}}' | awk '/^.* " + "docker cp \"$(docker ps -a --format '{{.ID}} {{.Names}}' | awk '/" + + project_name + + "-" + container - + "$/ {print $1}')\":/logs/. " - + dir.name + + "(-[0-9]+)?$/ {print $1}' | head -1)\":/logs/. " + + log_dir.name ) r = subprocess.run( cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + check=True, ) if r.returncode != 0: logging.info( - "Copying logs from %s failed: %s", + "Copying logs from %s (project: %s) failed: %s", container, + project_name, r.stdout.decode("utf-8", errors="replace"), ) - def _run_testcase( - self, server: str, client: str, test: Callable[[], testcases.TestCase] - ) -> TestResult: - return self._run_test(server, client, None, test)[0] - def _run_test( self, server: str, client: str, - log_dir_prefix: None, test: Callable[[], testcases.TestCase], ) -> Tuple[TestResult, float]: start_time = datetime.now() @@ -380,10 +430,16 @@ def _run_test( log_handler.setFormatter(formatter) logging.getLogger().addHandler(log_handler) + (subnet_index, subnet_params) = self._allocate_subnet() + testcase = test( sim_log_dir=sim_log_dir, client_keylog_file=client_log_dir.name + "/keys.log", server_keylog_file=server_log_dir.name + "/keys.log", + client_v4=subnet_params["CLIENT_V4_ADDR"], + client_v6=subnet_params["CLIENT_V6_ADDR"], + server_v4=subnet_params["SERVER_V4_ADDR"], + server_v6=subnet_params["SERVER_V6_ADDR"], ) print( "Server: " @@ -396,72 +452,65 @@ def _run_test( reqs = " ".join([testcase.urlprefix() + p for p in testcase.get_paths()]) logging.debug("Requests: %s", reqs) - params = ( - "WAITFORSERVER=server:443 " - "CERTS=" + testcase.certs_dir() + " " - "TESTCASE_SERVER=" + testcase.testname(Perspective.SERVER) + " " - "TESTCASE_CLIENT=" + testcase.testname(Perspective.CLIENT) + " " - "WWW=" + testcase.www_dir() + " " - "DOWNLOADS=" + testcase.download_dir() + " " - "SERVER_LOGS=" + server_log_dir.name + " " - "CLIENT_LOGS=" + client_log_dir.name + " " - 'SCENARIO="{}" ' - "CLIENT=" + self._implementations[client]["image"] + " " - "SERVER=" + self._implementations[server]["image"] + " " - 'REQUESTS="' + reqs + '" ' - ).format(testcase.scenario()) - params += " ".join(testcase.additional_envs()) + project_name = f"interop_{server}_{client}_{testcase.name()}_{subnet_index}" containers = "sim client server " + " ".join(testcase.additional_containers()) - cmd = ( - params - + " docker compose --env-file empty.env up --abort-on-container-exit --timeout 1 " - + containers - ) - logging.debug("Command: %s", cmd) - status = TestResult.FAILED - output = "" + output = None expired = False try: - r = subprocess.run( - cmd, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - timeout=testcase.timeout(), + r = self._docker_compose( + "up --abort-on-container-exit --timeout 1", + project_name, + subnet_params + | { + "WAITFORSERVER": "server:443", + "CERTS": testcase.certs_dir(), + "TESTCASE_SERVER": testcase.testname(Perspective.SERVER), + "TESTCASE_CLIENT": testcase.testname(Perspective.CLIENT), + "WWW": testcase.www_dir(), + "DOWNLOADS": testcase.download_dir(), + "SERVER_LOGS": server_log_dir.name, + "CLIENT_LOGS": client_log_dir.name, + "SCENARIO": "'" + testcase.scenario() + "'", + "CLIENT": self._implementations[client]["image"], + "SERVER": self._implementations[server]["image"], + "REQUESTS": "'" + reqs + " ".join(testcase.additional_envs()) + "'", + }, + containers, + testcase.timeout(), ) output = r.stdout except subprocess.TimeoutExpired as ex: + logging.error("Test timed out after %ds", testcase.timeout()) output = ex.stdout expired = True + except subprocess.CalledProcessError as ex: + logging.error("Test failed with error: %s", ex) + output = ex.stdout - logging.debug("%s", output.decode("utf-8", errors="replace")) + if output is not None: + logging.debug(output.decode("utf-8", errors="replace")) if expired: logging.debug("Test failed: took longer than %ds.", testcase.timeout()) - r = subprocess.run( - "docker compose --env-file empty.env stop " + containers, - shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - timeout=60, - ) - logging.debug("%s", r.stdout.decode("utf-8", errors="replace")) + self._docker_compose("stop", project_name, None, containers, timeout=60) # copy the pcaps from the simulator - self._copy_logs("sim", sim_log_dir) - self._copy_logs("client", client_log_dir) - self._copy_logs("server", server_log_dir) + self._copy_logs("sim", sim_log_dir, project_name) + self._copy_logs("client", client_log_dir, project_name) + self._copy_logs("server", server_log_dir, project_name) - if not expired: - lines = output.splitlines() + if not expired and output is not None: + lines = output.decode("utf-8", errors="replace").splitlines() if self._is_unsupported(lines): status = TestResult.UNSUPPORTED - elif any("client exited with code 0" in str(line) for line in lines): + elif any( + re.search(r"client.*exited with code 0", str(line)) for line in lines + ): try: status = testcase.check() except FileNotFoundError as e: - logging.error(f"testcase.check() threw FileNotFoundError: {e}") + logging.error("testcase.check() threw FileNotFoundError: %s", e) status = TestResult.FAILED # save logs @@ -469,8 +518,6 @@ def _run_test( log_handler.close() if status == TestResult.FAILED or status == TestResult.SUCCEEDED: log_dir = self._log_dir + "/" + server + "_" + client + "/" + str(testcase) - if log_dir_prefix: - log_dir += "/" + log_dir_prefix shutil.copytree(server_log_dir.name, log_dir + "/server") shutil.copytree(client_log_dir.name, log_dir + "/client") shutil.copytree(sim_log_dir.name, log_dir + "/sim") @@ -480,8 +527,11 @@ def _run_test( try: shutil.copytree(testcase.download_dir(), log_dir + "/downloads") except Exception as exception: + # This logging will now go to console since we restored handlers logging.info("Could not copy downloaded files: %s", exception) + self._docker_compose("down", project_name) + self._release_subnet(subnet_index) testcase.cleanup() server_log_dir.cleanup() client_log_dir.cleanup() @@ -505,8 +555,8 @@ def _run_measurement( self, server: str, client: str, test: Callable[[], testcases.Measurement] ) -> MeasurementResult: values = [] - for i in range(0, test.repetitions()): - result, value = self._run_test(server, client, "%d" % (i + 1), test) + for _ in range(0, test.repetitions()): + result, value = self._run_test(server, client, test) if result != TestResult.SUCCEEDED: res = MeasurementResult() res.result = result @@ -522,6 +572,45 @@ def _run_measurement( ) return res + def _allocate_subnet(self): + """Allocate a unique subnet range for a test""" + with self._subnet_allocator_lock: + # Find next available subnet index + while InteropRunner._next_subnet_index in InteropRunner._allocated_subnets: + InteropRunner._next_subnet_index += 1 + + subnet_index = InteropRunner._next_subnet_index + InteropRunner._allocated_subnets.add(subnet_index) + InteropRunner._next_subnet_index += 1 + + subnet_v4 = f"10.{subnet_index}" + subnet_v6 = f"fd00:cafe:{subnet_index:04x}" + + params = { + "SUBNET_V4_PREFIX": "16", + "SUBNET_V4": subnet_v4, + "SUBNET_V4_SUBNET": ".0.0", + "V4_PREFIX": "24", + "CLIENT_V4_NET": f"{subnet_v4}.10", + "CLIENT_V4_ADDR": f"{subnet_v4}.10.10", + "SERVER_V4_NET": f"{subnet_v4}.222", + "SERVER_V4_ADDR": f"{subnet_v4}.222.222", + "SUBNET_V6_PREFIX": "48", + "SUBNET_V6": subnet_v6, + "V6_PREFIX": "64", + "CLIENT_V6_NET": f"{subnet_v6}:10", + "CLIENT_V6_ADDR": f"{subnet_v6}:10::10", + "SERVER_V6_NET": f"{subnet_v6}:222", + "SERVER_V6_ADDR": f"{subnet_v6}:222::222", + } + + return (subnet_index, params) + + def _release_subnet(self, subnet_index): + """Release a subnet range after test completion""" + with self._subnet_allocator_lock: + InteropRunner._allocated_subnets.discard(subnet_index) + def run(self): """run the interop test suite and output the table""" @@ -534,25 +623,89 @@ def run(self): client, self._implementations[client]["image"], ) - if not ( - self._check_impl_is_compliant(server) - and self._check_impl_is_compliant(client) - ): + + # Set up a handler to capture log records with their levels for this client/server pair + capture_handler = LogRecordCapturingHandler() + + # Find and remove console handlers, saving them for later restoration + root_logger = logging.getLogger() + console_handlers = [] + for handler in root_logger.handlers[ + : + ]: # Use slice to avoid modifying list during iteration + if ( + isinstance(handler, logging.StreamHandler) + and handler.stream == sys.stderr + ): + console_handlers.append(handler) + # Copy the console handler's level and formatter to the capture handler + capture_handler.setLevel(handler.level) + capture_handler.setFormatter(handler.formatter) + root_logger.removeHandler(handler) + + # Add the capture handler to capture logs with levels + root_logger.addHandler(capture_handler) + + # Check compliance (now captured) + compliant = self._check_impl_is_compliant( + server + ) and self._check_impl_is_compliant(client) + + if not compliant: logging.info("Not compliant, skipping") + # Restore console handlers before continuing + root_logger.removeHandler(capture_handler) + for handler in console_handlers: + root_logger.addHandler(handler) continue # run the test cases - for testcase in self._tests: - status = self._run_testcase(server, client, testcase) - self.test_results[server][client][testcase] = status - if status == TestResult.FAILED: - nr_failed += 1 + with ThreadPoolExecutor(max_workers=self._parallel) as executor: + # Submit all tests to the executor + futures = {} + for testcase in self._tests: + future = executor.submit(self._run_test, server, client, testcase) + futures[future] = testcase + # Small delay to prevent thundering herd on Docker daemon + time.sleep(0.2) + + # Collect results as they complete + for future in as_completed(futures): + testcase = futures[future] + try: + status, _ = future.result() + self.test_results[server][client][testcase] = status + if status == TestResult.FAILED: + nr_failed += 1 + print(f"Completed: {testcase.name()} - {status}") + + except Exception as e: + self.test_results[server][client][testcase] = TestResult.FAILED + nr_failed += 1 + print(f"Test {testcase.name()} failed with exception: {e}") # run the measurements for measurement in self._measurements: res = self._run_measurement(server, client, measurement) self.measurement_results[server][client][measurement] = res + # Restore console handlers and replay captured logs at their original levels + root_logger.removeHandler(capture_handler) + for handler in console_handlers: + root_logger.addHandler(handler) + + # Replay captured log records at their original levels + for record in capture_handler.records: + # Only output if the record level meets the threshold + for handler in root_logger.handlers: + if ( + isinstance(handler, logging.StreamHandler) + and handler.stream == sys.stderr + ): + if record["level"] >= handler.level: + print(record["msg"], file=sys.stderr) + break # Only print once per record + self._postprocess_results() self._print_results() self._export_results() diff --git a/run.py b/run.py index 462f8df4..d91d0cef 100755 --- a/run.py +++ b/run.py @@ -83,6 +83,13 @@ def get_args(): "--no-auto-unsupported", help="implementations for which auto-marking as unsupported when all tests fail should be skipped", ) + parser.add_argument( + "--parallel", + type=int, + default=None, + help="Number of tests to run in parallel. Use -1 for all CPU cores, " + "or specify a number. Default: half of available cores", + ) return parser.parse_args() replace_arg = get_args().replace @@ -169,6 +176,7 @@ def get_tests_and_measurements( debug=get_args().debug, log_dir=get_args().log_dir, save_files=get_args().save_files, + parallel=get_args().parallel, no_auto_unsupported=( no_auto_unsupported if get_args().no_auto_unsupported is None diff --git a/testcases.py b/testcases.py index 328e9227..37beb51c 100644 --- a/testcases.py +++ b/testcases.py @@ -17,10 +17,9 @@ Direction, PacketType, TraceAnalyzer, - get_direction, get_packet_type, ) -from typing import List, Tuple +from typing import List, Tuple, Optional from Crypto.Cipher import AES @@ -63,7 +62,7 @@ def generate_cert_chain(directory: str, length: int = 1): class TestCase(abc.ABC): - _files = [] + _files: List[str] = [] _www_dir = None _client_keylog_file = None _server_keylog_file = None @@ -72,17 +71,29 @@ class TestCase(abc.ABC): _cert_dir = None _cached_server_trace = None _cached_client_trace = None + _client_v4 = None + _client_v6 = None + _server_v4 = None + _server_v6 = None def __init__( self, sim_log_dir: tempfile.TemporaryDirectory, client_keylog_file: str, server_keylog_file: str, + client_v4: str, + client_v6: str, + server_v4: str, + server_v6: str, ): self._server_keylog_file = server_keylog_file self._client_keylog_file = client_keylog_file self._files = [] self._sim_log_dir = sim_log_dir + self._client_v4 = client_v4 + self._client_v6 = client_v6 + self._server_v4 = server_v4 + self._server_v6 = server_v6 @abc.abstractmethod def name(self): @@ -151,7 +162,7 @@ def _is_valid_keylog(self, filename) -> bool: return False return True - def _keylog_file(self) -> str: + def _keylog_file(self) -> Optional[str]: if self._is_valid_keylog(self._client_keylog_file): logging.debug("Using the client's key log file.") return self._client_keylog_file @@ -159,6 +170,7 @@ def _keylog_file(self) -> str: logging.debug("Using the server's key log file.") return self._server_keylog_file logging.debug("No key log file found.") + return None def _inject_keylog_if_possible(self, trace: str): """ @@ -184,14 +196,28 @@ def _client_trace(self): if self._cached_client_trace is None: trace = self._sim_log_dir.name + "/trace_node_left.pcap" self._inject_keylog_if_possible(trace) - self._cached_client_trace = TraceAnalyzer(trace, self._keylog_file()) + self._cached_client_trace = TraceAnalyzer( + trace, + self._client_v4, + self._client_v6, + self._server_v4, + self._server_v6, + self._keylog_file(), + ) return self._cached_client_trace def _server_trace(self): if self._cached_server_trace is None: trace = self._sim_log_dir.name + "/trace_node_right.pcap" self._inject_keylog_if_possible(trace) - self._cached_server_trace = TraceAnalyzer(trace, self._keylog_file()) + self._cached_server_trace = TraceAnalyzer( + trace, + self._client_v4, + self._client_v6, + self._server_v4, + self._server_v6, + self._keylog_file(), + ) return self._cached_server_trace def _generate_random_file(self, size: int, filename: str = None) -> str: @@ -303,10 +329,9 @@ def get_paths(self): pass @abc.abstractmethod - def check(self) -> TestResult: + def check(self): self._client_trace() self._server_trace() - pass class Measurement(TestCase): @@ -835,7 +860,7 @@ def check(self) -> TestResult: res = TestResult.FAILED log_output = [] for p in self._server_trace().get_raw_packets(): - direction = get_direction(p) + direction = self._server_trace().get_direction(p) packet_type = get_packet_type(p) if packet_type == PacketType.VERSIONNEGOTIATION: logging.info("Didn't expect a Version Negotiation packet.") @@ -1303,7 +1328,8 @@ def check(self) -> TestResult: return TestResult.FAILED else: challenges.add(getattr(p["quic"], "path_challenge.data")) - paths.add(cur) + if cur is not None: + paths.add(cur) logging.info("Server saw these paths used: %s", paths) if len(paths) <= 1: diff --git a/trace.py b/trace.py index e7dbdeb9..5cdc890d 100644 --- a/trace.py +++ b/trace.py @@ -5,11 +5,6 @@ import pyshark -IP4_CLIENT = "193.167.0.100" -IP4_SERVER = "193.167.100.100" -IP6_CLIENT = "fd00:cafe:cafe:0::100" -IP6_SERVER = "fd00:cafe:cafe:100::100" - QUIC_V2 = hex(0x6B3343CF) @@ -47,20 +42,6 @@ class PacketType(Enum): } -def get_direction(p) -> Direction: - if (hasattr(p, "ip") and p.ip.src == IP4_CLIENT) or ( - hasattr(p, "ipv6") and p.ipv6.src == IP6_CLIENT - ): - return Direction.FROM_CLIENT - - if (hasattr(p, "ip") and p.ip.src == IP4_SERVER) or ( - hasattr(p, "ipv6") and p.ipv6.src == IP6_SERVER - ): - return Direction.FROM_SERVER - - return Direction.INVALID - - def get_packet_type(p) -> PacketType: if p.quic.header_form == "0": return PacketType.ONERTT @@ -79,20 +60,59 @@ def get_packet_type(p) -> PacketType: class TraceAnalyzer: _filename = "" - - def __init__(self, filename: str, keylog_file: Optional[str] = None): + _client_v4 = "" + _client_v6 = "" + _server_v4 = "" + _server_v6 = "" + + def __init__( + self, + filename: str, + client_v4: str, + client_v6: str, + server_v4: str, + server_v6: str, + keylog_file: Optional[str] = None, + ): self._filename = filename + self._client_v4 = client_v4 + self._client_v6 = client_v6 + self._server_v4 = server_v4 + self._server_v6 = server_v6 self._keylog_file = keylog_file + def get_direction(self, p) -> Direction: + if (hasattr(p, "ip") and p.ip.src == self._client_v4) or ( + hasattr(p, "ipv6") and p.ipv6.src == self._client_v6 + ): + return Direction.FROM_CLIENT + + if (hasattr(p, "ip") and p.ip.src == self._server_v4) or ( + hasattr(p, "ipv6") and p.ipv6.src == self._server_v6 + ): + return Direction.FROM_SERVER + + return Direction.INVALID + def _get_direction_filter(self, d: Direction) -> str: f = "(quic && !icmp) && " if d == Direction.FROM_CLIENT: return ( - f + "(ip.src==" + IP4_CLIENT + " || ipv6.src==" + IP6_CLIENT + ") && " + f + + "(ip.src==" + + self._client_v4 + + " || ipv6.src==" + + self._client_v6 + + ") && " ) elif d == Direction.FROM_SERVER: return ( - f + "(ip.src==" + IP4_SERVER + " || ipv6.src==" + IP6_SERVER + ") && " + f + + "(ip.src==" + + self._server_v4 + + " || ipv6.src==" + + self._server_v6 + + ") && " ) else: return f @@ -145,7 +165,7 @@ def get_1rtt_sniff_times( ) -> Tuple[List, datetime.datetime, datetime.datetime]: """Get all QUIC packets, one or both directions, and first and last sniff times.""" packets = [] - first, last = 0, 0 + first, last = datetime.datetime.min, datetime.datetime.min for packet in self._get_packets( self._get_direction_filter(direction) + "quic.header_form==0" ): @@ -155,7 +175,7 @@ def get_1rtt_sniff_times( and not hasattr(layer, "long_packet_type") and not hasattr(layer, "long_packet_type_v2") ): - if first == 0: + if first == datetime.datetime.min: first = packet.sniff_time last = packet.sniff_time packets.append(layer)