diff --git a/testcases.py b/testcases.py index bda56903..ff279471 100644 --- a/testcases.py +++ b/testcases.py @@ -1259,6 +1259,32 @@ def _path(p: List) -> Tuple[str, int, str, int]: (TestCasePortRebinding._addr(p, "dst"), int(getattr(p["udp"], "dstport"))), ) + @staticmethod + def _is_probing_packet(p: List) -> bool: + q = p["quic"] + + if not hasattr(q, "frame_type"): + return False + + for f in getattr(q, "frame_type").all_fields: + if f.hex_value not in [0x00, 0x1A, 0x1B]: + return False + + return True + + @staticmethod + def _is_ack_only_packet(p: List) -> bool: + q = p["quic"] + + if not hasattr(q, "frame_type"): + return False + + for f in getattr(q, "frame_type").all_fields: + if f.hex_value not in [0x02, 0x03]: + return False + + return True + def check(self) -> TestResult: super().check() if not self._keylog_file(): @@ -1273,30 +1299,32 @@ def check(self) -> TestResult: self._server_trace()._get_direction_filter(Direction.FROM_SERVER) + " quic" ) - cur = None - last = None paths = set() challenges = set() + path_challenges = set() + for p in tr_server: + path_challenges.add(self._path(p)) + break + for p in tr_server: cur = self._path(p) - if last is None: - last = cur - continue + paths.add(cur) - if last != cur and cur not in paths: - paths.add(last) - last = cur - # Packet on new path, should have a PATH_CHALLENGE frame - if hasattr(p["quic"], "path_challenge.data") is False: - logging.info( - "First server packet on new path %s did not contain a PATH_CHALLENGE frame", - cur, - ) - logging.info(p["quic"]) - return TestResult.FAILED - else: - challenges.add(getattr(p["quic"], "path_challenge.data")) - paths.add(cur) + if cur not in path_challenges and hasattr(p["quic"], "path_challenge.data"): + challenges.add(getattr(p["quic"], "path_challenge.data")) + path_challenges.add(cur) + + if ( + not self._is_ack_only_packet(p) + and not self._is_probing_packet(p) + and cur not in path_challenges + ): + logging.info( + "First server non-probing packet on new path %s before observing a PATH_CHALLENGE frame", + cur, + ) + logging.info(p["quic"]) + return TestResult.FAILED logging.info("Server saw these paths used: %s", paths) if len(paths) <= 1: