diff --git a/src/scenic/core/external_params.py b/src/scenic/core/external_params.py index 88848db4b..24cae6acf 100644 --- a/src/scenic/core/external_params.py +++ b/src/scenic/core/external_params.py @@ -96,6 +96,9 @@ """ +from abc import ABC, abstractmethod +import warnings + from dotmap import DotMap import numpy @@ -182,6 +185,7 @@ def __init__(self, params, globalParams): import verifai.server # construct FeatureSpace + timeBound = globalParams.get("timeBound", 0) usingProbs = False self.params = tuple(params) for index, param in enumerate(self.params): @@ -193,11 +197,23 @@ def __init__(self, params, globalParams): param.index = index if param.probs is not None: usingProbs = True + + if timeBound == 0 and any(param.isTimeSeries for param in self.params): + warnings.warn( + "TimeSeries external parameter used but no global parameter `timeBound` is specified. " + "(If using VerifAI’s ScenicSampler, set its maxSteps option)." + ) + space = verifai.features.FeatureSpace( { - self.nameForParam(index): verifai.features.Feature(param.domain) + self.nameForParam(index): ( + verifai.features.Feature(param.domain) + if not param.isTimeSeries + else verifai.features.TimeSeriesFeature(param.domain) + ) for index, param in enumerate(self.params) - } + }, + timeBound=timeBound, ) # set up VerifAI sampler @@ -252,17 +268,48 @@ def __init__(self, params, globalParams): self.rejectionFeedback = 1 self.cachedSample = None + self._lastSample = None + self._lastDynamicSample = None + self._lastSimulation = None + self._lastTime = -1 + def nextSample(self, feedback): - return self.sampler.nextSample(feedback) + if feedback is not None: + assert self._lastSample is not None + self._lastSample.update(feedback) - def update(self, sample, info, rho): - self.sampler.update(sample, info, rho) + self._lastSample = self.sampler.getSample() + return self._lastSample - def getSample(self): - return self.sampler.getSample() + def nextDynamicSample(self): + import scenic.syntax.veneer as veneer + + assert veneer.currentSimulation is not None + + if self._lastSimulation is not veneer.currentSimulation: + self._lastSimulation = veneer.currentSimulation + self._lastTime = -1 + + if veneer.currentSimulation.currentTime > self._lastTime: + feedback = veneer.currentSimulation + self._lastDynamicSample = self.cachedSample.getDynamicSample(feedback) + self._lastTime = veneer.currentSimulation.currentTime + + return self._lastDynamicSample def valueFor(self, param): - return getattr(self.cachedSample, self.nameForParam(param.index)) + if not param.isTimeSeries: + return param.extractOutput( + getattr(self.cachedSample.staticSample, self.nameForParam(param.index)) + ) + else: + callback = lambda: param.extractOutput( + getattr( + self.nextDynamicSample(), + self.nameForParam(param.index), + ) + ) + return TimeSeriesParameter(callback) @staticmethod def nameForParam(i): @@ -276,6 +323,7 @@ class ExternalParameter(Distribution): def __init__(self): super().__init__() self.sampler = None + self.isTimeSeries = False import scenic.syntax.veneer as veneer # TODO improve? veneer.registerExternalParameter(self) @@ -289,6 +337,41 @@ def sampleGiven(self, value): assert self.sampler is not None return self.sampler.valueFor(self) + def extractOutput(self, value): + """ + Given a raw sampled value for a parameter, optionally extract the actual desired value. + + By default just passes the value through unchanged. + """ + return value + + +class TimeSeriesParameter: + def __init__(self, callback): + self._callback = callback + self._lastTime = -1 + + def getSample(self): + import scenic.syntax.veneer as veneer + + assert veneer.currentSimulation is not None + + if veneer.currentSimulation.currentTime <= self._lastTime: + raise RuntimeError( + "Attempted `getSample` for a TimeSeries external parameter twice in one timestep." + ) + + self._lastTime = veneer.currentSimulation.currentTime + return self._callback() + + +def TimeSeries(param): + if not isinstance(param, ExternalParameter): + raise TypeError("Cannot turn a non `ExternalParameter` into a time series") + + param.isTimeSeries = True + return param + class VerifaiParameter(ExternalParameter): """An external parameter sampled using one of VerifAI's samplers.""" @@ -341,8 +424,7 @@ def __init__(self, low, high, buckets=None, weights=None): total = sum(weights) self.probs = tuple(wt / total for wt in weights) - def sampleGiven(self, value): - value = super().sampleGiven(value) + def extractOutput(self, value): assert len(value) == 1 return value[0] @@ -367,8 +449,7 @@ def __init__(self, low, high, weights=None): else: self.probs = None - def sampleGiven(self, value): - value = super().sampleGiven(value) + def extractOutput(self, value): assert len(value) == 1 return value[0] diff --git a/src/scenic/syntax/veneer.py b/src/scenic/syntax/veneer.py index cce52162f..034ff61d8 100644 --- a/src/scenic/syntax/veneer.py +++ b/src/scenic/syntax/veneer.py @@ -120,6 +120,7 @@ "VerifaiRange", "VerifaiDiscreteRange", "VerifaiOptions", + "TimeSeries", "File", "Files", # Constructible types @@ -201,6 +202,7 @@ from scenic.core.dynamics.invocables import BlockConclusion, runTryInterrupt from scenic.core.dynamics.scenarios import DynamicScenario from scenic.core.external_params import ( + TimeSeries, VerifaiDiscreteRange, VerifaiOptions, VerifaiParameter,