Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/scenic/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
`Scenario.simulationToBytes`, and `Scene.dumpAsScenicCode`.
"""

import hashlib
import io
import math
import pickle
Expand All @@ -14,6 +15,33 @@
from scenic.core.distributions import Samplable, needsSampling
from scenic.core.utils import DefaultIdentityDict


def deterministicHash(mapping, *, digest_size=8):
    """Hash a mapping of options in a run-independent, deterministic way.

    Entries are processed in sorted order of their stringified keys, and each
    key and value is framed with explicit separator bytes, so that distinct
    key/value combinations cannot collide through simple concatenation.

    Values of type int/float/str (including bool, an int subclass) are encoded
    from their string form. Any other value contributes only a fixed
    placeholder byte, since a general repr may embed memory addresses or other
    run-specific data.

    Args:
        mapping: Mapping whose keys and values determine the hash.
        digest_size: Length in bytes of the returned digest.

    Returns:
        A ``bytes`` digest of the given length.
    """
    pieces = []
    # Sorting by str(key) handles non-string keys deterministically.
    for key in sorted(mapping, key=str):
        value = mapping[key]
        if isinstance(value, (int, float, str)):
            encoded = str(value).encode()
        else:
            # Unsupported types contribute a fixed placeholder byte.
            encoded = b"\0"
        pieces.append(b"\0K" + str(key).encode() + b"\0V" + encoded)
    return hashlib.blake2b(b"".join(pieces), digest_size=digest_size).digest()


## JSON


Expand Down Expand Up @@ -126,7 +154,7 @@ def sceneFormatVersion(cls):
Must be incremented if the `writeScene` method or any of its helper
methods (e.g. `writeValue`) change, or if a new codec is added.
"""
return 2
return 3

@classmethod
def replayFormatVersion(cls):
Expand Down
45 changes: 39 additions & 6 deletions src/scenic/domains/driving/roads.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import scenic.core.geometry as geometry
from scenic.core.object_types import Point
from scenic.core.regions import PolygonalRegion, PolylineRegion
from scenic.core.serialization import deterministicHash
import scenic.core.type_support as type_support
import scenic.core.utils as utils
from scenic.core.vectors import Orientation, Vector, VectorField
Expand Down Expand Up @@ -1054,21 +1055,34 @@ def fromFile(cls, path, useCache: bool = True, writeCache: bool = True, **kwargs
data = f.read()
digest = hashlib.blake2b(data).digest()

# Hash the map options as well so changing them invalidates the cache.
optionsDigest = deterministicHash(kwargs, digest_size=8)

# By default, use the pickled version if it exists and is not outdated
pickledPath = path.with_suffix(cls.pickledExt)
if useCache and pickledPath.exists():
try:
return cls.fromPickle(pickledPath, originalDigest=digest)
return cls.fromPickle(
pickledPath,
originalDigest=digest,
optionsDigest=optionsDigest,
)
except pickle.UnpicklingError:
verbosePrint("Unable to load cached network (old format or corrupted).")
except cls.DigestMismatchError:
verbosePrint("Cached network does not match original file; ignoring it.")
verbosePrint(
"Cached network does not match original file or map options; ignoring it."
)

# Not using the pickled version; parse the original file based on its extension
network = handlers[ext](path, **kwargs)
if writeCache:
verbosePrint(f"Caching road network in {cls.pickledExt} file.")
network.dumpPickle(path.with_suffix(cls.pickledExt), digest)
network.dumpPickle(
path.with_suffix(cls.pickledExt),
digest,
optionsDigest=optionsDigest,
)
return network

@classmethod
Expand Down Expand Up @@ -1112,11 +1126,12 @@ def fromOpenDrive(
return network

@classmethod
def fromPickle(cls, path, originalDigest=None):
def fromPickle(cls, path, originalDigest=None, optionsDigest=None):
startTime = time.time()
verbosePrint("Loading cached version of road network...")

with open(path, "rb") as f:
# Version field
versionField = f.read(4)
if len(versionField) != 4:
raise pickle.UnpicklingError(f"{cls.pickledExt} file is corrupted")
Expand All @@ -1126,6 +1141,8 @@ def fromPickle(cls, path, originalDigest=None):
f"{cls.pickledExt} file is too old; "
"regenerate it from the original map"
)

# Digest of the original map file
digest = f.read(64)
if len(digest) != 64:
raise pickle.UnpicklingError(f"{cls.pickledExt} file is corrupted")
Expand All @@ -1134,6 +1151,18 @@ def fromPickle(cls, path, originalDigest=None):
f"{cls.pickledExt} file does not correspond to the original map; "
" regenerate it"
)

# Digest of the map options used to generate this cache
cachedOptionsDigest = f.read(8)
if len(cachedOptionsDigest) != 8:
raise pickle.UnpicklingError(f"{cls.pickledExt} file is corrupted")
if optionsDigest and optionsDigest != cachedOptionsDigest:
raise cls.DigestMismatchError(
f"{cls.pickledExt} file does not correspond to the current "
"map options; regenerate it"
)

# Remaining bytes are the compressed pickle of the Network
with gzip.open(f) as gf:
try:
network = pickle.load(gf) # invokes __setstate__ below
Expand Down Expand Up @@ -1167,15 +1196,19 @@ def reconnect(thing):
for maneuver in elem.maneuvers:
reconnect(maneuver)

def dumpPickle(self, path, digest, optionsDigest=b"\x00" * 8):
    """Cache this network as a pickle file at ``path``.

    File layout:
      - 4 bytes: format version (uncompressed in case we change compression
        schemes later)
      - 64 bytes: digest of the original map file (uncompressed for quick lookup)
      - 8 bytes: digest of the map options used when parsing the map
      - remainder: gzip-compressed pickle of the Network

    Args:
        path: Destination path; if it has no suffix, ``pickledExt`` is added.
        digest: 64-byte digest of the original map file.
        optionsDigest: 8-byte digest of the map options; defaults to zero bytes
            for backward compatibility with callers of the old two-argument
            signature.

    Raises:
        ValueError: If either digest has the wrong length, which would corrupt
            the fixed-size header that ``fromPickle`` expects.
    """
    if len(digest) != 64:
        raise ValueError(f"map digest must be 64 bytes, not {len(digest)}")
    if len(optionsDigest) != 8:
        raise ValueError(f"options digest must be 8 bytes, not {len(optionsDigest)}")

    path = pathlib.Path(path)
    if not path.suffix:
        path = path.with_suffix(self.pickledExt)
    version = struct.pack("<I", self._currentFormatVersion())
    data = pickle.dumps(self)  # pickle before touching the file

    with open(path, "wb") as f:
        f.write(version)
        f.write(digest)  # digest of original map file
        f.write(optionsDigest)  # digest of map options
        # The rest of the file is a gzip-compressed pickle of the Network.
        with gzip.open(f, "wb") as gf:
            gf.write(data)

Expand Down
30 changes: 13 additions & 17 deletions src/scenic/syntax/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from scenic.core.errors import InvalidScenarioError, PythonCompileError
from scenic.core.lazy_eval import needsLazyEvaluation
import scenic.core.pruning as pruning
from scenic.core.serialization import deterministicHash
from scenic.core.utils import cached_property
from scenic.syntax.compiler import compileScenicAST
from scenic.syntax.parser import parse_string
Expand All @@ -53,7 +54,7 @@
class CompileOptions:
"""Internal class for capturing options used when compiling a scenario."""

# N.B. update `hash` below when adding a new field
# N.B. update `_hashMapping` below when adding a new field

#: Whether or not the scenario uses `2D compatibility mode`.
mode2D: bool = False
Expand All @@ -64,25 +65,20 @@ class CompileOptions:
#: Selected modular scenario, if any.
scenario: Optional[str] = None

def _hashMapping(self):
mapping = {"mode2D": self.mode2D}
if self.modelOverride:
mapping["modelOverride"] = self.modelOverride
if self.scenario:
mapping["scenario"] = self.scenario
for k, v in self.paramOverrides.items():
mapping[f"param:{k}"] = v
return mapping

@cached_property
def hash(self):
    """Deterministic 4-byte digest of these options.

    Saved in serialized scenes so that deserialization can detect when the
    scenario was compiled with different options. We can't use the built-in
    ``hash`` because it is not deterministic across runs (e.g. string hashing
    is randomized).
    """
    return deterministicHash(self._hashMapping(), digest_size=4)


def scenarioFromString(
Expand Down
27 changes: 26 additions & 1 deletion tests/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy
import pytest

from scenic.core.serialization import SerializationError, Serializer
from scenic.core.serialization import SerializationError, Serializer, deterministicHash
from scenic.core.simulators import DivergenceError, DummySimulator
from tests.utils import (
areEquivalent,
Expand Down Expand Up @@ -260,6 +260,18 @@ def test_scene_inconsistent_mode(self):
with pytest.raises(SerializationError):
sc2.sceneFromBytes(data)

def test_scene_inconsistent_params(self):
    # Compile the same Scenic program twice with different global parameter
    # overrides; a scene serialized under one compilation must be rejected
    # when deserialized under the other, since the options hash differs.
    code = """
ego = new Object
param x = 1
"""
    sc1 = compileScenic(code, params={"x": 1})
    sc2 = compileScenic(code, params={"x": 2})
    scene1 = sampleScene(sc1)
    data = sc1.sceneToBytes(scene1)
    with pytest.raises(SerializationError):
        sc2.sceneFromBytes(data)

def test_scene_behavior(self):
scenario = compileScenic(
"""
Expand Down Expand Up @@ -482,3 +494,16 @@ def test_combined_serialization(self):
data = scenario.simulationToBytes(sim1)
sim2 = scenario.simulationFromBytes(data, simulator, maxSteps=1)
assert getEgoActionsFrom(sim1) == getEgoActionsFrom(sim2)


def test_deterministic_hash_non_scalar_values():
    """Non-scalar values must hash stably, independent of object identity.

    Two distinct instances of the same opaque type should produce identical
    digests, since unsupported values contribute only a placeholder byte
    rather than a repr that might include a memory address.
    """

    class Opaque:
        pass

    first = deterministicHash({"a": Opaque()})
    second = deterministicHash({"a": Opaque()})  # different instance
    assert first == second
97 changes: 97 additions & 0 deletions tests/domains/driving/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,100 @@ def test_sidewalk(network):
pt = sw.uniformPointInner()
assert network.sidewalkAt(pt) is sw
assert network.elementAt(pt) is sw


# --- Tests for cached network pickles ---


def _read_options_digest(pickled_path: Path) -> bytes:
"""Return the optionsDigest bytes from a cached .snet header.

Header layout:
- 4 bytes: version
- 64 bytes: originalDigest
- 8 bytes: optionsDigest
"""
with open(pickled_path, "rb") as f:
header = f.read(76) # 4 + 64 + 8
return header[68:76] # skip version (4) + originalDigest (64)


def test_dump_pickle_from_pickle(tmp_path, network):
    """dumpPickle/fromPickle should rebuild the Network when digests match"""
    fake_map_digest = b"x" * 64  # fake original map digest
    fake_options_digest = b"y" * 8  # fake map options digest
    cache_path = tmp_path / "net.snet"

    # Write the cache using the new format, then read it back with
    # matching digests.
    network.dumpPickle(cache_path, fake_map_digest, fake_options_digest)
    restored = Network.fromPickle(
        cache_path,
        originalDigest=fake_map_digest,
        optionsDigest=fake_options_digest,
    )

    # Sanity checks
    assert isinstance(restored, Network)
    assert restored.elements.keys() == network.elements.keys()


def test_from_pickle_digest_mismatch(tmp_path, network):
    """fromPickle should reject files whose digests don't match."""
    good_digest = b"x" * 64
    good_options = b"y" * 8
    cache_path = tmp_path / "net.snet"
    network.dumpPickle(cache_path, good_digest, good_options)

    # A mismatched digest of the original map is rejected.
    with pytest.raises(Network.DigestMismatchError):
        Network.fromPickle(
            cache_path,
            originalDigest=b"z" * 64,
            optionsDigest=good_options,
        )

    # A mismatched digest of the map options is rejected.
    with pytest.raises(Network.DigestMismatchError):
        Network.fromPickle(
            cache_path,
            originalDigest=good_digest,
            optionsDigest=b"w" * 8,
        )

    # Supplying no digests skips verification (standalone .snet use case).
    Network.fromPickle(cache_path)


def test_cache_regenerated_when_options_change(cached_maps):
    """Changing map options should invalidate and rewrite the cached .snet file."""
    # Get the temp copy of Town01 from cached_maps.
    # TODO: switch fixtures from tmpdir_factory to tmp_path_factory so that we
    # already have Path objects and don't need this conversion.
    xodr_loc = cached_maps[str(mapFolder / "CARLA" / "Town01.xodr")]
    xodr_path = Path(str(xodr_loc))
    pickled_path = xodr_path.with_suffix(Network.pickledExt)

    # Load the map twice with different tolerance values, capturing the cache
    # header's options digest after each load.
    digests = []
    for tolerance in (0.05, 0.123):
        Network.fromFile(
            xodr_path,
            useCache=True,
            writeCache=True,
            tolerance=tolerance,
        )
        digests.append(_read_options_digest(pickled_path))

    # Since the options changed, the cache header should also have changed.
    assert digests[0] != digests[1]
Loading