Skip to content

Commit 0256452

Browse files
Fixed closures
1 parent 582bfc2 commit 0256452

File tree

2 files changed

+81
-8
lines changed

2 files changed

+81
-8
lines changed

drjit/__init__.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2681,15 +2681,15 @@ def inner(closure, *args, **kwargs):
26812681

26822682
class FrozenFunction:
26832683
def __init__(self, f) -> None:
2684-
closure = inspect.getclosurevars(f)
2685-
self.closure = (closure.nonlocals, closure.globals)
2684+
self.f = f
26862685
self.frozen = detail.FrozenFunction(
26872686
inner, limit, warn_after, backend, auto_opaque
26882687
)
26892688

26902689
def __call__(self, *args, **kwargs):
26912690
_state = state_fn(*args, **kwargs) if state_fn is not None else None
2692-
return self.frozen([self.closure, _state], *args, **kwargs)
2691+
closure = inspect.getclosurevars(f)
2692+
return self.frozen([closure.nonlocals, closure.globals, _state], *args, **kwargs)
26932693

26942694
@property
26952695
def n_recordings(self):
@@ -2725,7 +2725,7 @@ def __get__(self, obj, type=None):
27252725
if obj is None:
27262726
return self
27272727
else:
2728-
return FrozenMethod(self.frozen, self.closure, obj)
2728+
return FrozenMethod(self.f, self.frozen, obj)
27292729

27302730
class FrozenMethod(FrozenFunction):
27312731
"""
@@ -2738,14 +2738,15 @@ class FrozenMethod(FrozenFunction):
27382738
The ``__call__`` method of the ``FrozenMethod`` then supplies the object
27392739
in addition to the arguments to the internal function.
27402740
"""
2741-
def __init__(self, frozen, closure, obj) -> None:
2741+
def __init__(self, f, frozen, obj) -> None:
2742+
self.f = f
27422743
self.obj = obj
27432744
self.frozen = frozen
2744-
self.closure = closure
27452745

27462746
def __call__(self, *args, **kwargs):
27472747
_state = state_fn(self.obj, *args, **kwargs) if state_fn is not None else None
2748-
return self.frozen([self.closure, _state], self.obj, *args, **kwargs)
2748+
closure = inspect.getclosurevars(self.f)
2749+
return self.frozen([closure.nonlocals, closure.globals, _state], self.obj, *args, **kwargs)
27492750

27502751
return functools.wraps(f)(FrozenFunction(f))
27512752

tests/test_freeze.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3085,7 +3085,7 @@ def test80_tensor_mean(t, auto_opaque):
30853085
def func(x):
30863086
return dr.mean(x)
30873087

3088-
frozen = dr.freeze(func)
3088+
frozen = dr.freeze(func, auto_opaque=auto_opaque)
30893089

30903090
for i in range(3):
30913091
shape = ((i + 3), 10)
@@ -3098,3 +3098,75 @@ def func(x):
30983098

30993099
assert frozen.n_recordings == 1
31003100

3101+
3102+
@pytest.test_arrays("float32, jit, shape=(*)")
3103+
@pytest.mark.parametrize("auto_opaque", [False, True])
3104+
def test81_changing_closures(t, auto_opaque):
3105+
3106+
y = 1
3107+
3108+
def func(x):
3109+
return x + y
3110+
3111+
frozen = dr.freeze(func, auto_opaque = auto_opaque)
3112+
3113+
for i in range(3):
3114+
x = dr.arange(t, i + 3)
3115+
3116+
res = frozen(x)
3117+
ref = func(x)
3118+
3119+
assert dr.allclose(res, ref)
3120+
3121+
assert frozen.n_recordings == 1
3122+
3123+
for i in range(3):
3124+
x = dr.arange(t, i + 3)
3125+
3126+
y += 1
3127+
3128+
res = frozen(x)
3129+
ref = func(x)
3130+
3131+
assert dr.allclose(res, ref)
3132+
3133+
assert frozen.n_recordings == 4
3134+
3135+
3136+
@pytest.test_arrays("float32, jit, shape=(*)")
3137+
@pytest.mark.parametrize("auto_opaque", [False, True])
3138+
def test82_changing_closures_methods(t, auto_opaque):
3139+
3140+
y = 1
3141+
3142+
class Test:
3143+
def func(self, x):
3144+
return x + y
3145+
3146+
@dr.freeze
3147+
def frozen(self, x):
3148+
return x + y
3149+
3150+
test = Test()
3151+
3152+
for i in range(3):
3153+
x = dr.arange(t, i + 3)
3154+
3155+
res = test.frozen(x)
3156+
ref = test.func(x)
3157+
3158+
assert dr.allclose(res, ref)
3159+
3160+
assert test.frozen.n_recordings == 1
3161+
3162+
for i in range(3):
3163+
x = dr.arange(t, i + 3)
3164+
3165+
y += 1
3166+
3167+
res = test.frozen(x)
3168+
ref = test.func(x)
3169+
3170+
assert dr.allclose(res, ref)
3171+
3172+
assert test.frozen.n_recordings == 4

0 commit comments

Comments
 (0)