diff --git a/src/scenic/core/serialization.py b/src/scenic/core/serialization.py index a7c52367a..580ffb32b 100644 --- a/src/scenic/core/serialization.py +++ b/src/scenic/core/serialization.py @@ -5,6 +5,7 @@ `Scenario.simulationToBytes`, and `Scene.dumpAsScenicCode`. """ +import hashlib import io import math import pickle @@ -14,6 +15,33 @@ from scenic.core.distributions import Samplable, needsSampling from scenic.core.utils import DefaultIdentityDict + +def deterministicHash(mapping, *, digest_size=8): + """Compute a deterministic hash for a mapping of options. + + Keys are sorted (by their string representation) and encoded with explicit + separators so that different key/value combinations do not collide under + simple concatenation. + + Only int/float/str (and bool, as a subclass of int) values are encoded directly; any other value is replaced + by a generic placeholder byte, to avoid nondeterminism from reprs containing + memory addresses or other run-specific data. + """ + hasher = hashlib.blake2b(digest_size=digest_size) + # Sort by stringified key so we can handle non-string keys deterministically. + for key in sorted(mapping.keys(), key=str): + hasher.update(b"\0K") + hasher.update(str(key).encode()) + hasher.update(b"\0V") + value = mapping[key] + if isinstance(value, (int, float, str)): + hasher.update(str(value).encode()) + else: + # Unsupported types just contribute a placeholder byte. + hasher.update(b"\0") + return hasher.digest() + + ## JSON @@ -126,7 +154,7 @@ def sceneFormatVersion(cls): Must be incremented if the `writeScene` method or any of its helper methods (e.g. `writeValue`) change, or if a new codec is added. """ - return 2 + return 3 @classmethod def replayFormatVersion(cls): diff --git a/src/scenic/domains/driving/roads.py b/src/scenic/domains/driving/roads.py index c3cb85b9e..636026401 100644 --- a/src/scenic/domains/driving/roads.py +++ b/src/scenic/domains/driving/roads.py @@ -37,6 +37,7 @@ import scenic.core.geometry as geometry from scenic.core.object_types import Point from scenic.core.regions import PolygonalRegion, PolylineRegion +from scenic.core.serialization import deterministicHash import scenic.core.type_support as type_support import scenic.core.utils as utils from scenic.core.vectors import Orientation, Vector, VectorField @@ -1054,21 +1055,34 @@ def fromFile(cls, path, useCache: bool = True, writeCache: bool = True, **kwargs data = f.read() digest = hashlib.blake2b(data).digest() + # Hash the map options as well so changing them invalidates the cache. + optionsDigest = deterministicHash(kwargs, digest_size=8) + # By default, use the pickled version if it exists and is not outdated pickledPath = path.with_suffix(cls.pickledExt) if useCache and pickledPath.exists(): try: - return cls.fromPickle(pickledPath, originalDigest=digest) + return cls.fromPickle( + pickledPath, + originalDigest=digest, + optionsDigest=optionsDigest, + ) except pickle.UnpicklingError: verbosePrint("Unable to load cached network (old format or corrupted).") except cls.DigestMismatchError: - verbosePrint("Cached network does not match original file; ignoring it.") + verbosePrint( + "Cached network does not match original file or map options; ignoring it." + ) # Not using the pickled version; parse the original file based on its extension network = handlers[ext](path, **kwargs) if writeCache: verbosePrint(f"Caching road network in {cls.pickledExt} file.") - network.dumpPickle(path.with_suffix(cls.pickledExt), digest) + network.dumpPickle( + path.with_suffix(cls.pickledExt), + digest, + optionsDigest=optionsDigest, + ) return network @classmethod @@ -1112,11 +1126,12 @@ def fromOpenDrive( return network @classmethod - def fromPickle(cls, path, originalDigest=None): + def fromPickle(cls, path, originalDigest=None, optionsDigest=None): startTime = time.time() verbosePrint("Loading cached version of road network...") with open(path, "rb") as f: + # Version field versionField = f.read(4) if len(versionField) != 4: raise pickle.UnpicklingError(f"{cls.pickledExt} file is corrupted") @@ -1126,6 +1141,8 @@ def fromPickle(cls, path, originalDigest=None): f"{cls.pickledExt} file is too old; " "regenerate it from the original map" ) + + # Digest of the original map file digest = f.read(64) if len(digest) != 64: raise pickle.UnpicklingError(f"{cls.pickledExt} file is corrupted") @@ -1134,6 +1151,18 @@ def fromPickle(cls, path, originalDigest=None): f"{cls.pickledExt} file does not correspond to the original map; " " regenerate it" ) + + # Digest of the map options used to generate this cache + cachedOptionsDigest = f.read(8) + if len(cachedOptionsDigest) != 8: + raise pickle.UnpicklingError(f"{cls.pickledExt} file is corrupted") + if optionsDigest and optionsDigest != cachedOptionsDigest: + raise cls.DigestMismatchError( + f"{cls.pickledExt} file does not correspond to the current " + "map options; regenerate it" + ) + + # Remaining bytes are the compressed pickle of the Network with gzip.open(f) as gf: try: network = pickle.load(gf) # invokes __setstate__ below @@ -1167,15 +1196,19 @@ def reconnect(thing): for maneuver in elem.maneuvers: reconnect(maneuver) - def dumpPickle(self, path, digest): + def dumpPickle(self, path, digest, optionsDigest): path = pathlib.Path(path) if not path.suffix: path = path.with_suffix(self.pickledExt) version = struct.pack(" bytes: + """Return the optionsDigest bytes from a cached .snet header. + + Header layout: + - 4 bytes: version + - 64 bytes: originalDigest + - 8 bytes: optionsDigest + """ + with open(pickled_path, "rb") as f: + header = f.read(76) # 4 + 64 + 8 + return header[68:76] # skip version (4) + originalDigest (64) + + +def test_dump_pickle_from_pickle(tmp_path, network): + """dumpPickle/fromPickle should rebuild the Network when digests match""" + digest = b"x" * 64 # fake original map digest + options_digest = b"y" * 8 # fake map options digest + path = tmp_path / "net.snet" + + # Write the cache using the new format. + network.dumpPickle(path, digest, options_digest) + + # Read it back with matching digests. + loaded = Network.fromPickle( + path, + originalDigest=digest, + optionsDigest=options_digest, + ) + + # Sanity checks + assert isinstance(loaded, Network) + assert loaded.elements.keys() == network.elements.keys() + + +def test_from_pickle_digest_mismatch(tmp_path, network): + """fromPickle should reject files whose digests don't match.""" + digest = b"x" * 64 + options_digest = b"y" * 8 + path = tmp_path / "net.snet" + + network.dumpPickle(path, digest, options_digest) + + wrong_digest = b"z" * 64 + wrong_options = b"w" * 8 + + # Wrong originalDigest -> DigestMismatchError. + with pytest.raises(Network.DigestMismatchError): + Network.fromPickle( + path, + originalDigest=wrong_digest, + optionsDigest=options_digest, + ) + + # Wrong optionsDigest -> DigestMismatchError. + with pytest.raises(Network.DigestMismatchError): + Network.fromPickle( + path, + originalDigest=digest, + optionsDigest=wrong_options, + ) + + # No digests supplied -> allowed (standalone .snet use case). + Network.fromPickle(path) + + +def test_cache_regenerated_when_options_change(cached_maps): + """Changing map options should invalidate and rewrite the cached .snet file.""" + # Get the temp copy of Town01 from cached_maps. + xodr_loc = cached_maps[str(mapFolder / "CARLA" / "Town01.xodr")] + xodr_path = Path(str(xodr_loc)) + pickled_path = xodr_path.with_suffix(Network.pickledExt) + + # First load: cache with one tolerance value + Network.fromFile( + xodr_path, + useCache=True, + writeCache=True, + tolerance=0.05, + ) + options_digest1 = _read_options_digest(pickled_path) + + # Second load: same map, different tolerance. + Network.fromFile( + xodr_path, + useCache=True, + writeCache=True, + tolerance=0.123, + ) + options_digest2 = _read_options_digest(pickled_path) + + # If the options changed, the cache header should also change. + assert options_digest1 != options_digest2