diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3cfc2dae7..6f78ff729 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -63,6 +63,7 @@ target_include_directories(drjit INTERFACE
   $
   $
+  $
   $
   $
 )
diff --git a/docs/freeze.rst b/docs/freeze.rst
new file mode 100644
index 000000000..f1d20f90d
--- /dev/null
+++ b/docs/freeze.rst
@@ -0,0 +1,938 @@
+.. py:currentmodule:: drjit
+
+.. _freeze:
+
+Function Freezing
+=================
+
+.. warning::
+    This feature is still experimental, and we list a number of unsupported cases
+    in the :ref:`pitfalls` section. This feature also only supports a subset of the
+    operations that can be performed with Dr.Jit; we list them in the
+    :ref:`unsupported_operations` section.
+    If you encounter any issues, please feel free to open an issue
+    `here `__
+    (make sure to include a minimal reproducing example).
+
+
+Introduction
+------------
+
+When working with Dr.Jit, your code is typically first traced to obtain a
+computation graph of the operations you intend to perform. When calling
+:py:func:`drjit.eval`, this graph is assembled into an intermediate
+representation, either LLVM IR or CUDA PTX. Finally, this assembly is compiled
+into actual binary code that can be run on the specified hardware. This last
+step can be very expensive, since the underlying compilers perform a lot of
+optimization on the intermediate code. Dr.Jit therefore caches this step by
+default, using a hash of the assembled IR code.
+
+However, the first two steps of tracing the Python code and generating the
+intermediate representation can still be expensive on their own. When a lot of
+Python code has to be traced, such as custom Python functions, the GIL has to be
+locked multiple times. Similarly, when tracing virtual function calls of many
+instances of custom plugins, these calls can cause a large performance
+overhead.
+
+This feature addresses this performance bottleneck by introducing a
+:py:func:`drjit.freeze` decorator. If a function is annotated with this
+decorator, Dr.Jit will attempt to cache not only the compiled kernels as before,
+but also the tracing and assembly steps. When a frozen function is called for the
+first time, Dr.Jit will analyze the inputs and then trace the function once,
+taking note of all kernels launched within that function. On subsequent calls to
+the function, Dr.Jit will check that the new inputs are still compatible with
+the previously-recorded kernels. If so, all tracing and assembly is skipped and
+the kernels are launched directly.
+
+
+Usage
+-----
+
+In supported cases, using this feature is as simple as annotating the function
+with the :py:func:`drjit.freeze` decorator:
+
+.. code-block:: python
+
+    import drjit as dr
+    from drjit.cuda import Float, UInt32
+
+    # Without freezing - traces every time
+    def func(x):
+        # Complex operations...
+        y = x + 1
+        dr.eval(y)
+        z = x * 2
+        return z
+
+    # With freezing - traces only once
+    @dr.freeze
+    def frozen(x):
+        # Same complex operations...
+        y = x + 1
+        dr.eval(y)
+        z = x * 2
+        return z
+
+
+Note that the overhead is not fully eliminated, since analyzing the inputs,
+mapping the kernel's outputs to function outputs, and performing various checks
+to ensure correctness still takes some time.
+
+For debugging purposes, this feature can easily be disabled by setting
+:py:attr:`drjit.JitFlag.KernelFreezing` to ``False``.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(x):
+        ...
+
+    # By default the function is recorded and replayed on subsequent calls.
+    func(x)
+
+    # Function freezing can be disabled by setting a flag to False. Subsequent
+    # calls will not use the recording and run the function as if it was not
+    # annotated.
+    dr.set_flag(dr.JitFlag.KernelFreezing, False)
+    func(x)
+
+To re-enable function freezing, the flag can simply be set to ``True`` again.
+Previous recordings, made while the flag was set, will still be available and
+can be used when replaying the function.
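+
+Whether a call was replayed or re-traced can be verified through the recording
+counter of the decorated function. The following is a minimal sketch; it relies
+on the ``n_recordings`` property described in the API-level documentation of
+:py:func:`drjit.freeze`:
+
+.. code-block:: python
+
+    @dr.freeze
+    def frozen(x):
+        return x * 2 + 1
+
+    frozen(Float(1, 2, 3))  # First call: the function is traced and recorded
+    frozen(Float(4, 5, 6))  # Compatible input: the recording is replayed
+
+    # Only a single recording should have been made at this point
+    assert frozen.n_recordings == 1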
+
+Additional arguments can be specified when using the decorator. These are
+documented in the API-level documentation of :py:func:`drjit.freeze`.
+
+More implementation details are given :ref:`below `.
+
+.. _unsupported_operations:
+
+Unsupported operations
+----------------------
+
+Frozen functions can only contain operations that can be replayed seamlessly
+with new inputs. We describe the main **unsupported** operations below.
+
+
+Array access
+~~~~~~~~~~~~
+
+The input of a frozen function can consist of two kinds of variables:
+
+- Plain Python variables (integers, strings, etc.), which are simply cached.
+  Because changes to these values can affect the generated kernel, e.g. via a
+  Python ``if`` statement, any change in the value of a Python input triggers a
+  re-recording.
+- Dr.Jit variables. Opaque JIT variables are allowed to change from one call to
+  another without requiring re-tracing of the function.
+
+Since JIT variables' values can change from one call to another without
+retracing, the function's behavior (and therefore the generated code) is **not**
+allowed to change based on these values. To prevent incorrect behavior, reading
+the contents of such variables is prohibited inside of a frozen function.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(x, y):
+        # Depending on the content of x, one or the other kernel would be generated.
+        # This cannot be replayed and accessing x is therefore prohibited.
+        if x[1] > 0:
+            return y + 1
+        else:
+            return y - 1
+
+    x = Float(0, 1)
+    y = Float(0, 1, 2)
+
+    func(x, y)
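+
+If the result genuinely depends on the array contents, the decision can often be
+folded into the traced computation itself, e.g. using :py:func:`drjit.select`, so
+that a single recorded kernel covers both cases. The sketch below illustrates
+this; note that, unlike the Python ``if`` statement above, the choice is made per
+element rather than once for the whole array:
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(x, y):
+        # The condition becomes part of the kernel, so no array contents
+        # need to be read while tracing.
+        return dr.select(x > 0, y + 1, y - 1)
+
+    x = Float(0, 1, 2)
+    y = Float(0, 1, 2)
+
+    func(x, y)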
+
+.. _non_recordable_operations:
+
+Non-recordable operations
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Whenever a device-level operation is called inside a frozen function, Dr.Jit
+has to be made aware of it. Kernel launches and other common operations, such as
+reductions, are supported by hooking into a low-level abstraction in the core
+library.
+
+However, applying any operation not known to Dr.Jit on the memory underlying a
+variable is not supported and might result in incorrect outputs or exceptions.
+As an example, such operations are used in the initialization of CUDA textures
+or acceleration structure building in Mitsuba 3.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(data, pos):
+        # On CUDA backends, this will call ``cuMemcpy2DAsync`` on the texture
+        # memory, without notifying the frozen function mechanism, and therefore fail.
+        tex = Texture1f([dr.width(data)], 1)
+        tex.set_value(data)
+        return tex.eval(pos)
+
+    data = Float(0, 1)
+    pos = Float(0.3, 0.6)
+    func(data, pos)
+
+
+Gradient propagation
+~~~~~~~~~~~~~~~~~~~~
+
+Very often, tracing the backward pass of an AD-attached computation is at least
+as complex as the forward pass. Caching both the tracing and assembly steps is
+therefore desirable. The :py:func:`drjit.freeze` decorator supports propagating
+gradients within the function and can propagate gradients to variables that the
+function's inputs depend on.
+
+However, propagating gradients from the result of a frozen function *through*
+the function is not supported. All gradient backpropagation has to start
+within the recorded function.
+
+In terms of automatic differentiation, annotating a function with the
+:py:func:`dr.freeze` decorator is equivalent to wrapping its content with an
+isolated gradient scope.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(y):
+        # Some differentiable operation...
+        z = dr.mean(y)
+        # Propagate the gradients to the input of the function...
+        dr.backward(z)
+
+    x = dr.arange(Float, 3)
+    dr.enable_grad(x)
+
+    y = dr.square(x)
+
+    # The first time the function is called, it will be recorded and the correct
+    # gradients will be accumulated into x.
+    func(y)
+
+    y = x * 2
+
+    # On subsequent calls, the function will be replayed, and gradients will
+    # be accumulated in x.
+    func(y)
+
+The :py:func:`drjit.freeze` decorator adds an implicit
+:py:func:`drjit.isolate_grad` context to the function. The above function is
+then equivalent to the following one.
+
+.. code-block:: python
+
+    def func(y):
+        # The isolate grad scope is added implicitly by the freezing decorator
+        with dr.isolate_grad():
+            # Some differentiable operation...
+            z = dr.mean(y)
+            # Propagate the gradients to the input of the function...
+            dr.backward(z)
+
+
+Compress
+~~~~~~~~
+
+A compress operation (:py:func:`drjit.compress`) will generate results whose
+size (number of entries) is dependent on the content of the input. Therefore, the
+output size cannot be determined ahead of time. Combining :py:func:`drjit.compress`
+with any other operation that needs to know array sizes in advance will cause the
+function to be re-traced on every call, effectively rendering the freezing
+mechanism useless.
+
+Examples of such operations include :py:func:`drjit.block_reduce`,
+:py:func:`drjit.block_prefix_reduce`, and :py:func:`drjit.scatter_reduce` when
+using the LLVM backend.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(x):
+        y = dr.block_reduce(dr.ReduceOp.Add, x, 2)
+        return dr.compress(y > 2)
+
+    # Calling the function the first time will cause it to be traced.
+    x = dr.arange(Float, 4)
+    func(x)
+
+    # Successive calls will also re-trace the function, even when called with the
+    # same input. A warning will also be printed to notify of such cases.
+    x = dr.arange(Float, 4)
+    func(x)
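+
+When the size-dependent operation is the last step of the computation, one
+possible workaround (sketched below) is to keep it outside of the frozen
+function, so that only the fixed-size portion of the computation is recorded:
+
+.. code-block:: python
+
+    @dr.freeze
+    def frozen(x):
+        # The reduction has a predictable output size and can be recorded.
+        return dr.block_reduce(dr.ReduceOp.Add, x, 2)
+
+    x = dr.arange(Float, 4)
+    y = frozen(x)
+
+    # The size-dependent compress operation runs outside of the recording.
+    indices = dr.compress(y > 2)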
+
+Offset pointers
+~~~~~~~~~~~~~~~
+
+Internally, new inputs to pre-recorded kernels are passed using the variables'
+data pointers. This is also how variables are identified and disambiguated
+in the function freezing implementation.
+
+However, this identification mechanism will not work for pointers pointing
+*inside* of a memory region. Therefore, such pointers are not supported inside
+of frozen functions.
+
+.. code-block:: cpp
+
+    // This pattern is not supported inside of frozen functions.
+    UInt32::load_(x.data() + 4)
+
+Note that this pattern might be used in existing C++ code which is called inside
+of the frozen function, which would result in an exception.
+
+
+.. _pitfalls:
+
+Pitfalls
+--------
+
+When using the :py:func:`drjit.freeze` decorator, certain caveats have to be
+considered. The following section explains the most common pitfalls.
+
+Implicit inputs
+~~~~~~~~~~~~~~~
+
+A class can hold JIT arrays as members, and its methods can use them. Likewise,
+a function can access variables of the outer scope (closures). These types of
+implicit inputs to a frozen function are generally not supported:
+
+.. code-block:: python
+
+    class MyClass:
+        def __init__(self, state: Float):
+            self.state = state
+
+        @dr.freeze
+        def method(self, a: Float):
+            # The `self.state` variable is an implicit input to the frozen function.
+            # Attempting to record this function will raise an exception!
+            return self.state + a
+
+    ...
+
+    local_var = Float([1, 2, 3])
+    def func(a: Float):
+        # `local_var` is an implicit input to the frozen function (closure variable).
+        return local_var + a
+
+    @dr.freeze
+    def func2(b: Float):
+        return func(b) + b
+
+    # This will raise an exception. Closure variables are not supported except
+    # in the most straightforward cases.
+    func2(Float([4, 5, 6]))
+
+When freezing such a method or function, these implicit inputs need to be made
+visible to the freezing mechanism. There are two recommended ways to do so:
+
+1. Turn the class into a valid :ref:`PyTree `, e.g. a dataclass
+   (:py:class:`@dataclass`) or a ``DRJIT_STRUCT``.
+2. Or, use the ``state_fn`` argument of the :py:func:`drjit.freeze` decorator to
+   manually specify the implicit inputs. ``state_fn`` will be called as a
+   function with the same arguments as the annotated function, and should return
+   a tuple of all extra inputs to be considered when recording and replaying.
+
+The following snippet illustrates correct usage:
+
+.. code-block:: python
+
+    @dataclass
+    class MyDataClass:
+        # Dataclasses are valid PyTrees, which makes these fields visible to Dr.Jit
+        # and the freezing mechanism.
+        x: Float
+        y: Float
+
+        @dr.freeze
+        def func(self, z: Float):
+            return self.y + z
+
+    def other_func(obj: MyDataClass, z: Float):
+        return obj.x + obj.y + z
+
+    ...
+
+    class OpaqueClass:
+        def __init__(self, x: Float):
+            # This field is not visible to Dr.Jit.
+            self.x = x
+
+    # The ``state_fn`` argument can be used to make implicit inputs visible
+    # without modifying the class.
+    @dr.freeze(state_fn=(lambda obj, **_: obj.x))
+    def func(obj: OpaqueClass):
+        return obj.x + 1
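+
+The ``DRJIT_STRUCT`` route mentioned in the first option works without
+dataclasses. A minimal sketch of the earlier ``MyClass``, rewritten this way
+(see the *Traversal* section below for details on ``DRJIT_STRUCT``):
+
+.. code-block:: python
+
+    class MyClass:
+        # Members listed here become visible to the traversal mechanism.
+        DRJIT_STRUCT = {
+            "state": Float,
+        }
+
+        def __init__(self, state: Float):
+            self.state = state
+
+        @dr.freeze
+        def method(self, a: Float):
+            # `self.state` is now a visible input of the frozen function.
+            return self.state + a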
+
+
+Kernel size inference
+~~~~~~~~~~~~~~~~~~~~~
+
+As explained above, frozen functions can in general be called many times with
+JIT inputs of varying sizes (number of elements) without requiring re-tracing.
+
+In some situations, the size of an input may be used to determine the size of
+another variable:
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(x):
+        indices = dr.arange(UInt32, dr.width(x) // 2)
+        # The size of the result depends on the size of input `x`.
+        return dr.gather(type(x), x, indices)
+
+The freezing mechanism uses a simple heuristic to detect variables whose size
+is a direct multiple or fraction of the input size.
+
+.. code-block:: python
+
+    # When calling the function, Dr.Jit will notice that the size of the output
+    # is a whole fraction of the input. This fact will be recorded by the frozen
+    # function.
+    x = dr.arange(Float, 8)
+    y1 = func(x)
+    assert dr.width(y1) == 4
+
+    # When replaying the function with a differently sized input, the size of
+    # the resulting variable will be derived according to this fraction.
+    x = dr.arange(Float, 16)
+    y2 = func(x)
+    assert dr.width(y2) == 8
+
+Unfortunately, if this heuristic does not succeed (e.g. when creating a variable
+with 3 more entries than the input), the size of the new variable will be assumed
+to be a constant, and will always be set to the size observed during the first
+recording, even in subsequent calls.
+
+.. warning::
+
+    Because there is no way for Dr.Jit to reliably detect it, freezing a function
+    containing this pattern can result in unsafe code or undefined behavior! In
+    particular, there may be out-of-bounds accesses due to the incorrect variable
+    size.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(x):
+        # The size of `indices` is not a simple multiple or fraction of the size
+        # of input `x`.
+        indices = dr.arange(UInt32, dr.width(x) - 1)
+        return dr.gather(type(x), x, indices)
+
+    # When first calling the function with an input of size 8, the constant size
+    # of (8 - 1) = 7 is baked into the frozen function.
+    x = dr.arange(Float, 8)
+    y1 = func(x)
+
+    # When replaying the function, a kernel of the hardcoded size 7 will be
+    # replayed, resulting in an incorrect output. This is unsafe!
+    x = dr.arange(Float, 16)
+    y2 = func(x)
+
+When multiple variables are accessed using :py:func:`drjit.gather` or
+:py:func:`drjit.scatter`, and the kernel size has to be inferred, it is
+possible that Dr.Jit picks the wrong variable to base the kernel size on.
+Such cases might also lead to undefined behavior and may cause out-of-bounds
+memory accesses. In general, Dr.Jit will try to use the largest variable that
+is either a fraction or multiple of the kernel input size.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(x, y):
+        # The size of `indices` is a whole fraction of the size of both
+        # inputs `x` and `y`.
+        indices = dr.arange(UInt32, dr.width(x) // 2)
+        return dr.gather(type(x), x, indices) + dr.gather(type(y), y, indices)
+
+    # When calling the function, Dr.Jit will notice that the size of the output
+    # is a whole fraction of the size of ``x`` as well as ``y``.
+    x = dr.arange(Float, 8)
+    y = dr.arange(Float, 16)
+    z1 = func(x, y)
+    assert dr.width(z1) == 4
+
+    # When replaying the function, Dr.Jit will use the larger of the two inputs
+    # to determine the size of the output.
+    x = dr.arange(Float, 16)
+    y = dr.arange(Float, 32)
+    z2 = func(x, y)
+    assert dr.width(z2) == 8
+
+Excessive recordings
+~~~~~~~~~~~~~~~~~~~~
+
+A common pattern when rendering scenes or running an optimization loop is to use
+the iteration index, e.g. as a seed to initialize a random number generator.
+This is also supported in a frozen function; however, passing the iteration count
+as a plain Python integer will cause the function to be re-recorded each time,
+resulting in lower performance than not using frozen functions.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(scene, it):
+        return render(scene, seed = it)
+
+    for i in range(n):
+        # When this function is called with different int-typed seed values, the
+        # frozen function will be re-traced for each new value of ``i``!
+        func(scene, i)
+
+    for i in range(n):
+        # Re-tracing can be prevented by using an opaque JIT variable instead.
+        i = dr.opaque(UInt32, i)
+        func(scene, i)
+
+
+Auto-opaque
+~~~~~~~~~~~
+
+There is one more subtlety when using a *literal* JIT variable (:py:obj:`UInt32(i)`)
+instead of an opaque one (:py:obj:`dr.opaque(UInt32, i)`). The "auto-opaque"
+feature, which is enabled by default, will detect literal JIT inputs that
+change between calls and make them opaque. However, this means that the function
+has to be traced at least twice, which incurs additional overhead at the start.
+
+.. code-block:: python
+
+    for i in range(n):
+        # By default, this literal JIT variable (non-opaque) will be made opaque
+        # when passed to the frozen function at the second call only.
+        # This means the function is traced twice instead of once.
+        i = UInt32(i)
+        func(scene, i)
+
+Disabling auto-opaque (:py:obj:`drjit.freeze(auto_opaque=False)`) will result
+in a single recording, but all literal inputs will be made opaque regardless of
+whether they would later remain constant or not. This will lead to higher memory
+usage and may also worsen the performance of the kernel itself.
+
+When possible, it is therefore recommended to **use opaque JIT variables for
+inputs that are known to change across calls**.
+
+To help track changing inputs, Dr.Jit can provide a list of such changing
+literals and their "paths" in the input arguments if they are detected:
+
+.. code-block:: python
+
+    # For the literal "paths" to be printed, the log level has to be set to ``Info``
+    dr.set_log_level(dr.LogLevel.Info)
+
+    @dr.freeze
+    def frozen(x, y, l, c):
+        return x + 1
+    ...
+
+    # Members of classes will be printed
+    @dataclass
+    class MyClass:
+        z: Float
+
+    # We call the function twice. The first call will leave all literals untouched.
+    # In the second call, changing literals will be detected and their paths will
+    # be printed.
+    for i in range(2):
+        x = dr.arange(Float, i+2)
+        y = Float(i)
+        l = [Float(1), Float(i)]
+        c = MyClass(Float(i))
+
+        # The function can be called with arguments and keyword arguments. They will
+        # show up as a tuple in the path.
+        frozen(x, y, l, c = c)
+
+The above code will print the following message when the function is called the
+second time:
+
+.. code-block:: text
+
+    While traversing the frozen function input, new literal variables have
+    been discovered which changed from one call to another. These will be made
+    opaque, and the input will be traversed again. This will incur some
+    overhead. To prevent this, make those variables opaque in beforehand. Below,
+    a list of variables that changed will be shown.
+    args[1][0]: The literal value of this variable changed from 0x0 to 0x3f800000
+    args[2][1][0]: The literal value of this variable changed from 0x0 to 0x3f800000
+    kwargs["c"].z[0]: The literal value of this variable changed from 0x0 to 0x3f800000
+
+This output can be used to determine which literals were made opaque.
+As stated above, it can be beneficial to make these literals opaque beforehand.
+In this case, the second argument of the function, the second element of the
+list, and the member ``z`` of the class have been detected as changing literals.
+
+
+Dry-run replay
+~~~~~~~~~~~~~~
+
+Some operations, such as block reductions, require the recording to be replayed
+in a dry-run mode before executing it. This calculates the size of variables and
+ensures that it will be possible to replay the recording later. If such a
+dry-run fails, the function will have to be re-traced; however, instead of adding
+a new recording to the function's cache, the old one will be overwritten. It is
+not possible to add another recording to the cache, since the condition that
+causes a dry-run to fail can be dependent on the size (number of elements) of
+JIT input variables, which is allowed to change.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(x):
+        return dr.block_reduce(dr.ReduceOp.Add, x, 2)
+
+    # The first time the function is called, a new recording is made.
+    x = dr.arange(Float, 4)
+    y = func(x)
+
+    # The block reduction will require a dry-run before launching kernels. In
+    # this case, it is detected that the size of x is not divisible by 2. The
+    # function will be re-traced; however, this will overwrite the old recording.
+    x = dr.arange(Float, 5)
+    y = func(x)
+
+    # Calling the function in a loop with changing input sizes can cause all
+    # dry-runs to fail, rendering the freezing mechanism useless.
+    for i in range(5, 10):
+        x = dr.arange(Float, i)
+        y = func(x)
+
+A warning will be printed once the function has been re-traced more than 10
+times. This limit can be changed using the ``warn_after`` argument of the
+decorator.
+
+Such functions should therefore be used with caution and only called with
+inputs that do not lead to a dry-run failure.
+
+Tensor shapes
+~~~~~~~~~~~~~
+
+When a frozen function is called with a tensor, the first dimension of the
+tensor is assumed to be dynamic. It can change from one call to another without
+triggering re-tracing of the function. However, changes in any other dimension
+will cause it to be re-traced.
+
+This is due to the way tensors are indexed: computing indices to access tensor
+entries does not involve the first (outermost) dimension, which makes it
+possible to reuse the same code as long as trailing dimensions do not change.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(t: TensorXf, i: UInt, j: UInt, k: UInt):
+        # Indexes into the tensor array, getting the entry at (i, j, k)
+        index = i * dr.shape(t)[1] * dr.shape(t)[2] + j * dr.shape(t)[2] + k
+        return dr.gather(Float, t.array, index)
+
+    # The first call will record the function
+    t = TensorXf(dr.arange(Float, 10*7*3), shape=(10, 7, 3))
+    func(t, UInt(1), UInt(1), UInt(1))
+
+    # Subsequent calls with the same trailing dimensions will be replayed
+    t = TensorXf(dr.arange(Float, 25*7*3), shape=(25, 7, 3))
+    func(t, UInt(1), UInt(1), UInt(1))
+
+    # Changes in trailing dimensions will cause the function to be re-traced
+    t = TensorXf(dr.arange(Float, 10*3*7), shape=(10, 3, 7))
+    func(t, UInt(1), UInt(1), UInt(1))
+
+Dr.Jit also supports advanced tensor indexing, allowing you to use arrays to
+index into a tensor, e.g. ``t[UInt(1, 2, 3), :]``. This syntax can also be
+used inside of frozen functions; however, it might lead to kernels with baked-in
+kernel sizes, and therefore incorrect outputs. If tensor indexing with indices
+of changing sizes is required, calculating the array index manually with the
+formula in the above example is recommended.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(t: TensorXf, i: UInt, j: UInt, k: UInt):
+        # Uses advanced indexing to look up the entries at (i, j, k)
+        return t[i, j, k]
+
+    t = TensorXf(dr.arange(Float, 10*7*3), shape=(10, 7, 3))
+
+    # The first call will record the function, and will return a tensor of shape
+    # (3, 2, 1)
+    func(t, UInt(1, 2, 3), UInt(1, 2), UInt(1))
+
+    # Calling the function with a different number of index elements will be
+    # correct, as long as only the array with the largest number of indices
+    # changes.
+    func(t, UInt(1, 2, 3, 4), UInt(1, 2), UInt(1))
+
+    # Calling the function with a different number of index elements on multiple
+    # dimensions can lead to incorrect outputs. The heuristic will use the larger
+    # array to infer the size of the kernel, by multiplication with the recorded
+    # fraction (in this case 2). This call will (incorrectly) return a tensor of
+    # shape (4, 2, 1).
+    func(t, UInt(1, 2, 3, 4), UInt(1, 2, 3), UInt(1))
+
+
+.. warning::
+    Using indexing or slicing inside of a frozen function can easily lead to
+    baked-in kernel sizes and, as a result, to incorrect outputs without any
+    warning. This should be used with caution when replaying frozen functions
+    with JIT inputs of varying sizes (number of elements).
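+
+The sketch below applies the recommendation from above and computes the flat
+index manually; ``i``, ``j`` and ``k`` are assumed to share the same size, and
+the kernel size then follows the size of the index arrays between calls:
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(t: TensorXf, i: UInt, j: UInt, k: UInt):
+        # The flat index is computed explicitly instead of relying on
+        # advanced tensor indexing.
+        index = i * dr.shape(t)[1] * dr.shape(t)[2] + j * dr.shape(t)[2] + k
+        return dr.gather(Float, t.array, index)
+
+    t = TensorXf(dr.arange(Float, 10*7*3), shape=(10, 7, 3))
+
+    # Four lookups...
+    func(t, UInt(0, 1, 2, 3), UInt(0, 1, 2, 3), UInt(0, 1, 2, 0))
+
+    # ...and eight lookups, replayed from the same recording.
+    func(t, UInt(0, 1, 2, 3, 4, 5, 6, 7), UInt(0, 1, 2, 3, 4, 5, 6, 0),
+         UInt(0, 1, 2, 0, 1, 2, 0, 1))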
+
+Textures
+~~~~~~~~
+
+:ref:`Textures ` can be used inside of frozen functions for lookups,
+as well as for gradient calculations. However, because they require special
+memory operations on the CUDA backend, it is not possible to update or
+initialize CUDA textures inside of frozen functions.
+This is a special case of a :ref:`non-recordable operation `.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(tex: Texture1f, pos: Float):
+        return tex.eval(pos)
+
+    tex = Texture1f([2], 1)
+    tex.set_value(Float(0, 1))
+
+    pos = dr.arange(Float, 4) / 4
+
+    # The texture can be evaluated inside the frozen function.
+    func(tex, pos)
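+
+Updating the texture *between* calls is possible, as long as the update itself
+happens outside of the frozen function. A sketch, continuing the example above
+(and assuming, as there, that the texture is passed to the function as an input):
+
+.. code-block:: python
+
+    # Update the texture contents outside of the frozen function...
+    tex.set_value(Float(1, 0))
+
+    # ...so that the recorded function can be replayed with the new data.
+    func(tex, pos)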
+
+
+Virtual function calls
+~~~~~~~~~~~~~~~~~~~~~~
+
+Symbolic virtual function calls are generally supported by frozen functions;
+however, some limitations apply. The following example shows a supported use of
+virtual function calls in frozen functions.
+
+.. code-block:: python
+
+    # `A` and `B` derive from `Base`
+    a, b = A(), B()
+
+    @dr.freeze
+    def func(base: BasePtr, x: Float):
+        return base.f(x)
+
+    base = BasePtr(a, a, None, b, b)
+    x = Float(1, 2, 3, 4, 5)
+    func(base, x)
+
+When a frozen function is called with a variable that can point to a virtual
+base class, Dr.Jit's pointer registry is traversed to find all variables used
+in the frozen function call. Since some objects can be registered but not
+referenced by the pointer, member JIT variables of these objects are traversed
+**and evaluated**, even though they are not used in the function.
+This side effect can be unexpected.
+
+.. code-block:: python
+
+    # `A` and `B` derive from `Base`
+    # These objects are registered with Dr.Jit's pointer registry
+    a, b = A(), B()
+
+    @dr.freeze
+    def func(base: BasePtr, x: Float):
+        return base.f(x)
+
+    # Even though only `a` is referenced, we have to traverse member variables
+    # of `b`. These can be evaluated by the frozen function call.
+    base = BasePtr(a, a, None)
+    x = Float(1, 2, 3)
+    func(base, x)
+
+Nested virtual function calls are supported when the inner base class pointer
+is passed as an argument to the outer function. However, due to implementation
+details, nested calls are not supported when the outer function retrieves the
+callee pointer from a class member variable.
+
+.. code-block:: python
+
+    # Even though `A` is traversable, a frozen function with a call to
+    # ``nested_member`` will fail.
+    class A(Base):
+        DRJIT_STRUCT = {
+            "s": BasePtr,
+        }
+
+        s: BasePtr
+
+        def nested(self, s, x):
+            s.f(x)
+
+        def nested_member(self, x):
+            self.s.f(x)
+
+    a, b = A(), B()
+
+    # This nested vcall is supported because the nested base pointer is an
+    # argument to the nested function.
+    @dr.freeze
+    def supported(base: BasePtr, nested_base: BasePtr, x: Float):
+        return base.nested(nested_base, x)
+
+    a.s = BasePtr(b)
+    dr.make_opaque(a.s)
+
+    # This nested vcall is unsupported because the nested base pointer is an
+    # opaque member of the class `A`.
+    @dr.freeze
+    def unsupported(base: BasePtr, x: Float):
+        return base.nested_member(x)
+
+Runaway recursion
+~~~~~~~~~~~~~~~~~
+
+Passing inputs to a frozen function that contain basic reference cycles is
+supported. However, reference cycles going through C++ classes can lead to a
+runaway recursion when traversing the function inputs and raise an exception.
+
+.. code-block:: python
+
+    @dr.freeze
+    def frozen(l):
+        return l[0] + 1
+
+    # This constructs a list with a reference cycle.
+    l = [Float(1)]
+    l.append(l)
+
+    # Passing an object with a simple reference cycle is supported.
+    frozen(l)
+
+However, this more complex example shows an *unsupported* case of reference
+cycles that can occur when using custom BSDFs in Mitsuba 3.
+
+.. code-block:: python
+
+    # A class inheriting from a trampoline class is automatically traversed.
+    class MyBSDF(mi.BSDF):
+        def set_scene(self, scene):
+            self.scene = scene
+        ...
+
+    @dr.freeze
+    def frozen(scene):
+        ...
+
+    # Construct a scene that includes ``MyBSDF`` as an element.
+    scene = ...
+    # Setting the scene reference in the BSDF completes the reference cycle.
+    mybsdf.set_scene(scene)
+
+    # Calling the function with such an object will lead to a runaway
+    # recursion, and the frozen function will raise an exception.
+    frozen(scene)
+
+
+.. _freezing_implementation_details:
+
+Implementation details
+----------------------
+
+Every time the annotated function is called, its inputs are analyzed. All JIT
+variables are extracted into a flattened and de-duplicated array. Additionally,
+a key describing the "layout" of the inputs is generated. This key will be used
+to distinguish between different recordings of the same frozen function, in case
+some of its inputs qualitatively change in subsequent calls.
+
+If no recording is found for the current key, Dr.Jit enters a "kernel recording"
+mode (:py:obj:`drjit.JitFlag.FreezingScope`) and the actual function code is
+executed. In this mode, all device-level operations, such as kernel launches,
+are recorded as well as executed normally.
+
+The next time the function is called, the newly-provided inputs are traversed,
+and the layout is used to look up compatible recordings. If such a recording is
+found, any tracing is skipped: the various recorded operations and kernels are
+directly replayed.
+
+Traversal
+~~~~~~~~~
+
+In order to map the variables provided to a frozen function as arguments to the
+actual kernel inputs (slots), Dr.Jit must be able to traverse these arguments.
+In addition to basic Python containers such as lists, tuples and dictionaries,
+the following :ref:`PyTrees ` are traversable and can be part of the
+input of a frozen function.
+
+*Dataclasses* are traversable by Dr.Jit and their fields are automatically made
+visible to the traversal algorithm.
+
+.. code-block:: python
+
+    # Fields of dataclasses are traversable
+    @dataclass
+    class MyClass:
+        x: Float
+
+Classes can be annotated with a static ``DRJIT_STRUCT`` field to make them
+traversable.
+
+.. code-block:: python
+
+    class MyClass:
+        x: Float
+
+        # Annotating the class with DRJIT_STRUCT will make the listed members
+        # available to traversal.
+        DRJIT_STRUCT = {
+            "x": Float
+        }
+
+Finally, C++ classes may additionally implement the ``TraversableBase`` interface
+to make themselves traversable. Python classes inheriting from these classes
+through trampolines are automatically traversed. This is useful when implementing
+your own subclasses with virtual function calls.
+
+.. code-block:: python
+
+    # If BSDF is a traversable trampoline class,
+    # then member variables of MyClass will also be traversed.
+    class MyClass(mi.BSDF):
+        x: Float
+
+
+Output construction
+~~~~~~~~~~~~~~~~~~~
+
+After a frozen function has been replayed, the outputs of the replayed operations
+(kernel launches, reductions, etc.) have to be mapped back to outputs of the
+frozen function, respecting the layout observed in the first launch.
+
+Since this output must be constructible at replay time, only a subset of
+traversable types can be returned from frozen functions. This includes:
+
+- JIT and AD variables,
+- Dr.Jit Tensors and Arrays,
+- Python lists, tuples and dictionaries,
+- Dataclasses,
+- ``DRJIT_STRUCT`` annotated classes with a default constructor.
+
+The following example shows an *unsupported* return type: because the constructor
+of ``MyClass`` expects a variable, an object of type ``MyClass`` cannot be
+created at replay time.
+
+.. code-block:: python
+
+    class MyClass:
+        x: Float
+
+        DRJIT_STRUCT = {
+            "x": Float,
+        }
+
+        # Non-default constructor (requires argument `x`)
+        def __init__(self, x: Float):
+            self.x = x
+
+    @dr.freeze
+    def func(x):
+        return MyClass(x + 1)
+
+    # Calling the function will fail, as the output of the frozen function
+    # cannot be constructed without a default constructor.
+    func(Float(1, 2, 3))
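+
+One possible fix (sketched below; other designs are equally valid) is to make
+the class default-constructible, e.g. by giving the constructor argument a
+default value:
+
+.. code-block:: python
+
+    class MyClass:
+        x: Float
+
+        DRJIT_STRUCT = {
+            "x": Float,
+        }
+
+        # The default value makes the class constructible without arguments,
+        # so an instance can be re-created when the function is replayed.
+        def __init__(self, x: Float = Float()):
+            self.x = x
+
+    @dr.freeze
+    def func(x):
+        return MyClass(x + 1)
+
+    # This now succeeds.
+    func(Float(1, 2, 3))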
diff --git a/docs/index.rst b/docs/index.rst
index e10e6611b..b60f0ef29 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -48,6 +48,7 @@ public API.
    textures
    coop_vec
    nn
+   freeze
    faq
 
 .. toctree::
diff --git a/docs/reference.rst b/docs/reference.rst
index 54406e213..a57c06558 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -170,6 +170,11 @@ Just-in-time compilation
    .. automethod:: __enter__
    .. automethod:: __exit__
 
+Function freezing
+-----------------
+
+.. autofunction:: freeze
+
 Type traits
 -----------
diff --git a/drjit/__init__.py b/drjit/__init__.py
index 7e71a8709..529b0bb62 100644
--- a/drjit/__init__.py
+++ b/drjit/__init__.py
@@ -19,12 +19,12 @@ import sys as _sys
 
 if _sys.version_info < (3, 11):
     try:
-        from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable
+        from typing_extensions import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar
     except ImportError:
         raise RuntimeError(
            "Dr.Jit requires the 'typing_extensions' package on Python <3.11")
 else:
-    from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable
+    from typing import overload, Optional, Type, Tuple, List, Sequence, Union, Literal, Callable, TypeVar
 
 from .ast import syntax, hint
 from .interop import wrap
@@ -2494,6 +2494,296 @@ def binary_search(start, end, pred):
 
     return start
 
+
+# Represents the frozen function passed to the decorator without arguments
+F = TypeVar("F")
+# Represents the frozen function passed to the decorator with arguments
+F2 = TypeVar("F2")
+
+
+@overload
+def freeze(
+    f: None = None,
+    *,
+    state_fn: Optional[Callable],
+    limit: Optional[int] = None,
+    warn_after: int = 10,
+    backend: Optional[JitBackend] = None,
+    auto_opaque: bool = True,
+) -> Callable[[F], F]:
+    ...
+
+
+@overload
+def freeze(
+    f: F,
+    *,
+    state_fn: Optional[Callable] = None,
+    limit: Optional[int] = None,
+    warn_after: int = 10,
+    backend: Optional[JitBackend] = None,
+    auto_opaque: bool = True,
+) -> F:
+    ...
+
+
+def freeze(
+    f: Optional[F] = None,
+    *,
+    state_fn: Optional[Callable] = None,
+    limit: Optional[int] = None,
+    warn_after: int = 10,
+    backend: Optional[JitBackend] = None,
+    auto_opaque: bool = True,
+) -> Union[F, Callable[[F2], F2]]:
+    """
+    Decorator to "freeze" functions, which improves efficiency by removing
+    repeated JIT tracing overheads.
+
+    In general, Dr.Jit traces computation and then compiles and launches kernels
+    containing this trace (see the section on :ref:`evaluation ` for
+    details). While the compilation step can often be skipped via caching, the
+    tracing cost can still be significant, especially when repeatedly evaluating
+    complex models, e.g., as part of an optimization loop.
+
+    The :py:func:`@dr.freeze ` decorator addresses this problem by
+    altogether removing the need to trace repeatedly. For example, consider the
+    following decorated function:
+
+    .. code-block:: python
+
+       @dr.freeze
+       def f(x, y, z):
+           return ...  # Complicated code involving the arguments
+
+    Dr.Jit will trace the first call to the decorated function ``f()``, while
+    collecting additional information regarding the nature of the function's inputs
+    and regarding the CPU/GPU kernel launches representing the body of ``f()``.
+
+    If the function is subsequently called with *compatible* arguments (more on
+    this below), it will immediately launch the previously made CPU/GPU kernels
+    without re-tracing, which can substantially improve performance.
+
+    When :py:func:`@dr.freeze ` detects *incompatibilities* (e.g., ``x``
+    having a different type compared to the previous call), it will conservatively
+    re-trace the body and keep track of another potential input configuration.
+
+    Frozen functions support arbitrary :ref:`PyTrees ` as function
+    arguments and return values.
+
+    The following may trigger re-tracing:
+
+    - Changes in the **type** of an argument or :ref:`PyTree ` element.
+    - Changes in the **length** of a container (``list``, ``tuple``, ``dict``).
+    - Changes of **dictionary keys** or **field names** of dataclasses.
+    - Changes in the AD status (:py:func:`dr.grad_enabled() `) of a variable.
+    - Changes of (non-PyTree) **Python objects**, as detected by mismatching ``hash()``
+      or ``id()`` if they are not hashable.
+
+    The following more technical conditions also trigger re-tracing:
+
+    - A Dr.Jit variable changes from/to a **scalar** configuration (size ``1``).
+    - The sets of variables of the same size change. In the example above, this
+      would be the case if ``len(x) == len(y)`` in one call, and ``len(x) != len(y)``
+      subsequently.
+    - When Dr.Jit variables reference external memory (e.g. mapped NumPy arrays), the
+      memory can be aligned or unaligned. A re-tracing step is needed when this
+      status changes.
+
+    These all correspond to situations where the generated kernel code may need to
+    change, and the system conservatively re-traces to ensure correctness.
+
+    Frozen functions support arguments with a different variable *width* (see
+    :py:func:`dr.width() `) without re-tracing, as long as the sets of
+    variables of the same width stay consistent.
+
+    Some constructions are problematic and should be avoided in frozen functions.
+
+    - The function :py:func:`dr.width() ` returns an integer literal
+      that may be merged into the generated code. If the frozen function is later
+      rerun with differently-sized arguments, the executed kernels will still
+      reference the old size. One exception to this rule is a construction like
+      `dr.arange(UInt32, dr.width(a))`, where the result only implicitly depends on
+      the width value.
+
+    When calling a frozen function from within an outer frozen function, the content
+    of the inner function will be executed and recorded by the outer function.
+    No separate recording will be made for the inner function, and its ``n_recordings``
+    count will not change. Calling the inner function separately from outside a
+    frozen function will therefore require re-tracing for the provided inputs.
+
+    **Advanced features**. The :py:func:`@dr.freeze ` decorator takes
+    several optional parameters that are helpful in certain situations.
+
+    - **Warning when re-tracing happens too often**: Incompatible arguments trigger
+      re-tracing, which can mask issues where *accidentally* incompatible arguments
+      keep :py:func:`@dr.freeze ` from producing the expected
+      performance benefits.
+
+      In such situations, it can be helpful to warn and identify changing
+      parameters by name. This feature is enabled by default, with a threshold of
+      ``10`` re-tracing steps.
+
+      .. code-block:: pycon
+
+         >>> @dr.freeze(warn_after=1)
+         ... def f(x):
+         ...     return x
+         ...
+         >>> f(Int(1))
+         >>> f(Float(1))
+         The frozen function has been recorded 2 times, this indicates a problem
+         with how the frozen function is being called. For example, calling it
+         with changing python values such as an index. For more information about
+         which variables changed set the log level to ``LogLevel::Debug``.
+
+    - **Limiting memory usage**. Storing kernels for many possible input
+      configurations requires device memory, which can become problematic. Set the
+      ``limit=`` parameter to enable an LRU cache. This is useful when calls to a
+      function are mostly compatible but require occasional re-tracing.
+
+    Args:
+        limit (Optional[int]): An optional integer specifying the maximum number of
+          stored configurations. Once this limit is reached, incompatible calls
+          requiring re-tracing will cause the least recently used configuration to
+          be dropped.
+
+        warn_after (int): When the number of re-tracing steps exceeds this value,
+          Dr.Jit will generate a warning that explains which variables changed
+          between calls to the function.
+
+        state_fn (Optional[Callable]): This optional callable can specify additional
+          state that identifies the configuration. ``state_fn`` will be called with
+          the same arguments as the decorated function. It should return a
+          traversable object (e.g., a list or tuple) that is conceptually treated
+          as if it was another input of the function.
+
+        backend (Optional[JitBackend]): If no inputs are given when calling the
+          frozen function, the backend used has to be specified using this argument.
+          It must match the backend used for computation within the function.
+
+        auto_opaque (bool): If this flag is set to ``True``, literal input
+          variables whose value or size changes between calls to the function
+          will be detected and made opaque. This reduces the memory usage and
+          traversal overhead, and can improve the performance of generated kernels.
+          If the flag is set to ``False``, all input variables will be made opaque.
+    """
+
+    limit = limit if limit is not None else -1
+    backend = backend if backend is not None else JitBackend.Invalid
+
+    def decorator(f):
+        """
+        Internal decorator, returned if ``dr.freeze`` was used with arguments.
+        """
+        import functools
+        import inspect
+
+        def inner(input: dict):
+            """
+            This inner function is the one that is actually frozen. It calls the
+            wrapped function, and receives the inputs (args, kwargs) as well as
+            any additional inputs such as closures or the state specified with
+            the ``state_fn`` callable, which makes their traversal possible.
+            """
+            args = input["args"]
+            kwargs = input["kwargs"]
+            return f(*args, **kwargs)
+
+        class FrozenFunction:
+            def __init__(self, f) -> None:
+                self.f = f
+                self.frozen = detail.FrozenFunction(
+                    inner, limit, warn_after, backend, auto_opaque
+                )
+
+            def __call__(self, *args, **kwargs):
+                _state = state_fn(*args, **kwargs) if state_fn is not None else None
+                # Capture closure variables to detect when nonlocal symbols change.
+                closure = inspect.getclosurevars(f)
+                return self.frozen(
+                    {
+                        "globals": closure.globals,
+                        "nonlocals": closure.nonlocals,
+                        "state_fn": _state,
+                        "args": args,
+                        "kwargs": kwargs,
+                    }
+                )
+
+            @property
+            def n_recordings(self):
+                """
+                Represents the number of times the function was recorded. This
+                includes occasions where it was recorded due to a dry-run failing.
+                It does not necessarily correspond to the number of recordings
+                currently cached; see ``n_cached_recordings`` for that.
+                """
+                return self.frozen.n_recordings
+
+            @property
+            def n_cached_recordings(self):
+                """
+                Represents the number of recordings of the frozen function that
+                are currently cached. If a recording fails in dry-run mode, it
+                will not create a new recording, but replace the recording that
+                was attempted to be replayed. The number of cached recordings can
+                also be limited with the ``limit`` argument.
+                """
+                return self.frozen.n_cached_recordings
+
+            def clear(self):
+                """
+                Clears the recordings of the frozen function and resets the
+                ``n_recordings`` counter. The reference to the function is still
+                kept, and the frozen function can be called again to re-trace
+                new recordings.
+                """
+                return self.frozen.clear()
+
+            def __get__(self, obj, type=None):
+                if obj is None:
+                    return self
+                else:
+                    return FrozenMethod(self.f, self.frozen, obj)
+
+        class FrozenMethod(FrozenFunction):
+            """
+            A FrozenMethod currying the object into the __call__ method.
+
+            If the ``freeze`` decorator is applied to a method of some class, it
+            has to call the internal frozen function with the ``self`` argument.
+            To this end, we implement the ``__get__`` method of the frozen
+            function to return a ``FrozenMethod``, which holds a reference to the
+            object. The ``__call__`` method of the ``FrozenMethod`` then supplies
+            the object in addition to the arguments to the internal function.
+            """
+            def __init__(self, f, frozen, obj) -> None:
+                self.f = f
+                self.obj = obj
+                self.frozen = frozen
+
+            def __call__(self, *args, **kwargs):
+                _state = state_fn(self.obj, *args, **kwargs) if state_fn is not None else None
+                # Capture closure variables to detect when nonlocal symbols change.
+ closure = inspect.getclosurevars(self.f) + return self.frozen( + { + "globals": closure.globals, + "nonlocals": closure.nonlocals, + "state_fn": _state, + "args": [self.obj, *args], + "kwargs": kwargs, + } + ) + + return functools.wraps(f)(FrozenFunction(f)) + + if f is not None: + return decorator(f) + else: + return decorator + + +del F +del F2 def assert_true( cond, diff --git a/drjit/opt.py b/drjit/opt.py index 1dc2df4a7..fc0733aa5 100644 --- a/drjit/opt.py +++ b/drjit/opt.py @@ -126,6 +126,11 @@ class Optimizer(Generic[Extra], MutableMapping[str, dr.ArrayBase]): # - an arbitrary sequence of additional optimizer-dependent state values state: Dict[str, Tuple[dr.ArrayBase, Optional[LearningRate], Extra]] + DRJIT_STRUCT = { + "lr": LearningRate, + "state": dict, + } + def __init__( self, lr: LearningRate, @@ -960,10 +965,15 @@ def _step( # Compute the step size scale, which is a product of # - EMA debiasing factor # - Adaptive/parameter-specific scaling + Float32 = dr.float32_array_t(dr.leaf_t(grad)) + Float64 = dr.float64_array_t(dr.leaf_t(grad)) + ema_factor = Float32( + -dr.sqrt(1 - Float64(self.beta_2) ** t) / (1 - Float64(self.beta_1) ** t) + ) scale = cache.product( dr.leaf_t(grad), # Desired type lr, - -dr.sqrt(1 - self.beta_2**t) / (1 - self.beta_1**t), + ema_factor, ) # Optional: use maximum of second order term @@ -981,9 +991,11 @@ def _step( def _reset(self, key: str, value: dr.ArrayBase, /) -> None: valarr = value.array tp = type(valarr) + UInt = dr.uint32_array_t(dr.leaf_t(tp)) + t = UInt(0) m_t = dr.opaque(tp, 0, valarr.shape) v_t = dr.opaque(tp, 0, valarr.shape) - self.state[key] = value, None, (0, m_t, v_t) + self.state[key] = value, None, (t, m_t, v_t) # Blend between the old and new versions of the optimizer extra state def _select( diff --git a/ext/drjit-core b/ext/drjit-core index 5c3dabcee..d8967fe6a 160000 --- a/ext/drjit-core +++ b/ext/drjit-core @@ -1 +1 @@ -Subproject commit 5c3dabcee2ee23b30d42ec1dc846d042c88406e8 +Subproject commit d8967fe6acd5d097bb44167952b773107b535c66 diff --git a/include/drjit/array_router.h b/include/drjit/array_router.h index 3ec0e6bc2..a590291f9 100644 --- a/include/drjit/array_router.h +++ b/include/drjit/array_router.h @@ -526,9 +526,15 @@ DRJIT_ROUTE_BINARY_FALLBACK(dot, dot, (E) a1 * (E) a2) template DRJIT_INLINE auto mean(const Array &a) { - if constexpr (is_array_v) - return sum(a) * (1.f / a.derived().size()); - else + if constexpr (is_array_v){ + if (jit_flag(JitFlag::FreezingScope)) { + // Inside of frozen functions, we have to avoid baking the size of + // the array into the kernel as a literal. + return sum(a) * (1.f / a.derived().opaque_size_()); + } else { + return sum(a) * (1.f / a.derived().size()); + } + } else return a; } diff --git a/include/drjit/array_static.h b/include/drjit/array_static.h index acc7011cf..c41cd1fa7 100644 --- a/include/drjit/array_static.h +++ b/include/drjit/array_static.h @@ -79,6 +79,8 @@ struct StaticArrayBase : ArrayBaseT { DRJIT_INLINE constexpr size_t size() const { return Derived::Size; } + DRJIT_INLINE constexpr size_t opaque_size_() const { return Derived::Size; } + DRJIT_INLINE void init_(size_t) { } static Derived empty_(size_t size) { diff --git a/include/drjit/array_traverse.h b/include/drjit/array_traverse.h index 8fee8681e..88d640332 100644 --- a/include/drjit/array_traverse.h +++ b/include/drjit/array_traverse.h @@ -15,6 +15,8 @@ #pragma once +#include + #define DRJIT_STRUCT_NODEF(Name, ...) 
\ Name(const Name &) = default; \ Name(Name &&) = default; \ @@ -140,6 +142,18 @@ namespace detail { using det_traverse_1_cb_rw = decltype(T(nullptr)->traverse_1_cb_rw(nullptr, nullptr)); + template + using det_get = decltype(std::declval().get()); + + template + using det_const_get = decltype(std::declval().get()); + + template + using det_begin = decltype(std::declval().begin()); + + template + using det_end = decltype(std::declval().end()); + inline drjit::string get_label(const char *s, size_t i) { auto skip = [](char c) { return c == ' ' || c == '\r' || c == '\n' || c == '\t' || c == ','; @@ -168,7 +182,7 @@ template using traversable_t = detail::traversable> template static constexpr bool is_traversable_v = traversable_t::value; template using enable_if_traversable_t = enable_if_t>; -template static constexpr bool is_dynamic_traversable_v = +template static constexpr bool is_dynamic_traversable_v = is_jit_v && is_dynamic_array_v && is_vector_v && !is_tensor_v; template DRJIT_INLINE auto fields(T &&v) { @@ -179,11 +193,33 @@ template auto labels(const T &v) { return traversable_t::labels(v); } +/** + * This function traverses C++ objects that have one of the following features: + * + * 1. They represent Jit arrays, in which case the callback is called with + * optional domain and variant arguments. + * 2. They fall under the \c traversable trait (see above), for example + * DRJIT_STRUCTs or tuples. + * 3. They represent dynamic arrays. + * 4. They themselves implement the function \c traverse_1_cb_ro, in which case + * this function is called. + * 5. They represent iterables with a \c begin and \c end function, such as + * \c std::vector or \c drjit::vector. + * 6. They represent unique pointers with a constant get method, such as + * \c std::unique_ptr. + */ template -void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint64_t)) { - (void) payload; (void) fn; +void traverse_1_fn_ro(const Value &value, void *payload, + void (*fn)(void *, uint64_t, const char *, + const char *)) { + DRJIT_MARK_USED(payload); + DRJIT_MARK_USED(fn); if constexpr (is_jit_v && depth_v == 1) { - fn(payload, value.index_combined()); + if constexpr(Value::IsClass) + fn(payload, value.index_combined(), Value::CallSupport::Variant, + Value::CallSupport::Domain); + else + fn(payload, value.index_combined(), "", ""); } else if constexpr (is_traversable_v) { traverse_1(fields(value), [payload, fn](auto &x) { traverse_1_fn_ro(x, payload, fn); @@ -198,14 +234,48 @@ void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint is_detected_v) { if (value) value->traverse_1_cb_ro(payload, fn); + + } else if constexpr (is_detected_v && + is_detected_v) { + for (auto elem : value) + traverse_1_fn_ro(elem, payload, fn); + } else if constexpr (is_detected_v) { + const auto *tmp = value.get(); + traverse_1_fn_ro(tmp, payload, fn); + } else if constexpr (is_detected_v) { + value.traverse_1_cb_ro(payload, fn); } } +/** + * This function traverses C++ objects that have one of the following features: + * + * 1. They represent Jit arrays, in which case the callback is called with + * optional domain and variant arguments. + * 2. They fall under the \c traversable trait (see above), for example + * DRJIT_STRUCTs or tuples. + * 3. They represent dynamic arrays. + * 4. They themselves implement the function \c traverse_1_cb_rw, in which case + * this function is called. + * 5. They represent iterables with a \c begin and \c end function, such as + * \c std::vector or \c drjit::vector.
+ * 6. They represent unique pointers with a get method, such as + * \c std::unique_ptr. + */ template -void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64_t)) { - (void) payload; (void) fn; +void traverse_1_fn_rw(Value &value, void *payload, + uint64_t (*fn)(void *, uint64_t, const char *, + const char *)) { + DRJIT_MARK_USED(payload); + DRJIT_MARK_USED(fn); if constexpr (is_jit_v && depth_v == 1) { - value = Value::borrow((typename Value::Index) fn(payload, value.index_combined())); + if constexpr(Value::IsClass) + value = Value::borrow((typename Value::Index) fn( + payload, value.index_combined(), Value::CallSupport::Variant, + Value::CallSupport::Domain)); + else + value = Value::borrow((typename Value::Index) fn( + payload, value.index_combined(), "", "")); } else if constexpr (is_traversable_v) { traverse_1(fields(value), [payload, fn](auto &x) { traverse_1_fn_rw(x, payload, fn); @@ -220,6 +290,15 @@ void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64 is_detected_v) { if (value) value->traverse_1_cb_rw(payload, fn); + } else if constexpr (is_detected_v && + is_detected_v) { + for (auto elem : value) + traverse_1_fn_rw(elem, payload, fn); + } else if constexpr (is_detected_v) { + auto *tmp = value.get(); + traverse_1_fn_rw(tmp, payload, fn); + } else if constexpr (is_detected_v) { + value.traverse_1_cb_rw(payload, fn); } } diff --git a/include/drjit/autodiff.h b/include/drjit/autodiff.h index ad22d6303..58e9b87b0 100644 --- a/include/drjit/autodiff.h +++ b/include/drjit/autodiff.h @@ -709,6 +709,11 @@ struct DRJIT_TRIVIAL_ABI DiffArray size_t size() const { return jit_var_size((uint32_t) m_index); } + auto opaque_size_() const { + using UInt32 = JitArray; + return UInt32::steal(jit_var_opaque_width(m_index)); + } + bool grad_enabled_() const { if constexpr (IsFloat) return (m_index >> 32) != 0; @@ -973,7 +978,9 @@ NAMESPACE_BEGIN(detail) /// Internal operations for traversing nested data structures and fetching or /// storing indices. Used in ``call.h`` and ``loop.h``.
-template void collect_indices_fn(void *p, uint64_t index) { +template +void collect_indices_fn(void *p, uint64_t index, const char * /*variant*/, + const char * /*domain*/) { vector &indices = *(vector *) p; if constexpr (IncRef) index = ad_var_inc_ref(index); @@ -985,7 +992,8 @@ struct update_indices_payload { size_t &pos; }; -inline uint64_t update_indices_fn(void *p, uint64_t) { +inline uint64_t update_indices_fn(void *p, uint64_t, const char * /*variant*/, + const char * /*domain*/) { update_indices_payload &payload = *(update_indices_payload *) p; return payload.indices[payload.pos++]; } diff --git a/include/drjit/custom.h b/include/drjit/custom.h index 7a38ec3dd..143a2d7d8 100644 --- a/include/drjit/custom.h +++ b/include/drjit/custom.h @@ -20,6 +20,7 @@ #include #include +#include NAMESPACE_BEGIN(drjit) NAMESPACE_BEGIN(detail) diff --git a/include/drjit/dynamic.h b/include/drjit/dynamic.h index e895c88dd..a39bb7263 100644 --- a/include/drjit/dynamic.h +++ b/include/drjit/dynamic.h @@ -123,6 +123,7 @@ struct DynamicArray } DRJIT_INLINE size_t size() const { return m_size; } + DRJIT_INLINE size_t opaque_size_() const { return m_size; } DRJIT_INLINE DynamicArray copy() { return DynamicArray(*this); } DRJIT_INLINE Value &entry(size_t i) { diff --git a/include/drjit/extra.h b/include/drjit/extra.h index 6ba877e9f..127a9e333 100644 --- a/include/drjit/extra.h +++ b/include/drjit/extra.h @@ -250,6 +250,11 @@ extern DRJIT_EXTRA_EXPORT bool ad_release_one_output(drjit::detail::CustomOpBase extern DRJIT_EXTRA_EXPORT void ad_copy_implicit_deps(drjit::vector &, bool input); +/// Retrieve a list of ad indices that are the target of edges that have been +/// postponed by the current scope +extern DRJIT_EXTRA_EXPORT void ad_scope_postponed(drjit::vector *dst); + + /// Kahan-compensated floating point atomic scatter-addition extern DRJIT_EXTRA_EXPORT void ad_var_scatter_add_kahan(uint64_t *target_1, uint64_t *target_2, uint64_t value, diff --git a/include/drjit/jit.h b/include/drjit/jit.h index 7922a8cda..df46df914 100644 --- a/include/drjit/jit.h +++ b/include/drjit/jit.h @@ -596,6 +596,10 @@ struct DRJIT_TRIVIAL_ABI JitArray bool valid() const { return m_index != 0; } size_t size() const { return jit_var_size(m_index); } + auto opaque_size_() const { + using UInt32 = JitArray; + return UInt32::steal(jit_var_opaque_width(m_index)); + } uint32_t index() const { return m_index; } uint32_t index_ad() const { return 0; } uint64_t index_combined() const { return m_index; } diff --git a/include/drjit/python.h b/include/drjit/python.h index f65755cce..0b01bcf7d 100644 --- a/include/drjit/python.h +++ b/include/drjit/python.h @@ -54,6 +54,7 @@ #include #include #include +#include NAMESPACE_BEGIN(drjit) struct ArrayBinding; @@ -1060,25 +1061,87 @@ template void bind_all(ArrayBinding &b) { // Expose already existing object tree traversal callbacks (T::traverse_1_..) in Python. // This functionality is needed to traverse custom/opaque C++ classes and correctly // update their members when they are used in vectorized loops, function calls, etc.
-template auto& bind_traverse(nanobind::class_ &cls) { +template auto &bind_traverse(nanobind::class_ &cls) +{ namespace nb = nanobind; - struct Payload { nb::callable c; }; + struct Payload { + nb::callable c; + }; + + static_assert(std::is_base_of_v); cls.def("_traverse_1_cb_ro", [](const T *self, nb::callable c) { Payload payload{ std::move(c) }; - self->traverse_1_cb_ro((void *) &payload, [](void *p, uint64_t index) { - ((Payload *) p)->c(index); - }); + self->traverse_1_cb_ro((void *) &payload, + [](void *p, uint64_t index, const char *variant, const char *domain) { + ((Payload *) p)->c(index, variant, domain); + }); }); cls.def("_traverse_1_cb_rw", [](T *self, nb::callable c) { Payload payload{ std::move(c) }; - self->traverse_1_cb_rw((void *) &payload, [](void *p, uint64_t index) { - return nb::cast(((Payload *) p)->c(index)); + self->traverse_1_cb_rw((void *) &payload, [](void *p, uint64_t index, + const char *variant, + const char *domain) { + return nb::cast( + ((Payload *) p)->c(index, variant, domain)); }); }); return cls; } +/** + * \brief This function traverses a Python object that inherits from a + * trampoline class. + * + * Internally, this function calls the ``traverse_py_cb_ro_impl`` function, + * exposed through ``drjit.detail``, with the object and the callback. + */ +inline void traverse_py_cb_ro(const TraversableBase *base, void *payload, + void (*fn)(void *, uint64_t, const char *variant, + const char *domain)) { + namespace nb = nanobind; + nb::handle self = base->self_py(); + if (!self) + return; + + auto detail = nb::module_::import_("drjit.detail"); + nb::callable traverse_py_cb_ro_fn = + nb::borrow(nb::getattr(detail, "traverse_py_cb_ro")); + + traverse_py_cb_ro_fn(self, + nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + fn(payload, index, variant, domain); + })); +} + +/** + * \brief This function traverses a Python object that inherits from a + * trampoline class. + * + * Internally, this function calls the ``traverse_py_cb_rw_impl`` function, + * exposed through ``drjit.detail``, with the object and the callback.
+ */ +inline void traverse_py_cb_rw(TraversableBase *base, void *payload, + uint64_t (*fn)(void *, uint64_t, const char *, + const char *)) { + + namespace nb = nanobind; + nb::handle self = base->self_py(); + if (!self) + return; + + auto detail = nb::module_::import_("drjit.detail"); + nb::callable traverse_py_cb_rw_fn = + nb::borrow<nb::callable>(nb::getattr(detail, "traverse_py_cb_rw")); + + traverse_py_cb_rw_fn(self, + nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + return fn(payload, index, variant, domain); + })); +} + NAMESPACE_END(drjit) diff --git a/include/drjit/texture.h b/include/drjit/texture.h index ff6b6f3a0..6ed5b3f81 100644 --- a/include/drjit/texture.h +++ b/include/drjit/texture.h @@ -18,6 +18,8 @@ #include #include #include +#include +#include #pragma once @@ -42,7 +44,7 @@ enum class CudaTextureFormat : uint32_t { Float16 = 1, /// Half precision storage format }; -template <typename Value, size_t Dimension> class Texture { +template <typename Value, size_t Dimension> class Texture : TraversableBase { public: static constexpr bool IsCUDA = is_cuda_v<Value>; static constexpr bool IsDiff = is_diff_v<Value>; @@ -1591,6 +1593,48 @@ template <typename Value, size_t Dimension> class Texture { mutable bool m_tensor_dirty = false; /* Flag to indicate whether public-facing unpadded tensor needs to be updated */ + +public: + void + traverse_1_cb_ro(void *payload, + drjit::detail::traverse_callback_ro fn) const override { + // Traverse the texture to react to changes when freezing code via + // @dr.freeze. In all other contexts, the texture is read-only and does + // not require traversal + if (!jit_flag(JitFlag::EnableObjectTraversal)) + return; + + DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, m_value, m_unpadded_value, + m_resolution_opaque, m_inv_resolution); + if constexpr (HasCudaTexture) { + uint32_t n_textures = 1 + ((m_channels - 1) / 4); + std::vector<uint32_t> indices(n_textures); + jit_cuda_tex_get_indices(m_handle, indices.data()); + for (uint32_t i = 0; i < n_textures; i++) + fn(payload, indices[i], "", ""); + } + } + void traverse_1_cb_rw(void *payload, + drjit::detail::traverse_callback_rw fn) override { + // Only traverse the texture for frozen functions, since accidentally + // traversing the scene in loops or vcalls can cause issues. + if (!jit_flag(JitFlag::EnableObjectTraversal)) + return; + + DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, m_value, m_unpadded_value, + m_resolution_opaque, m_inv_resolution); + if constexpr (HasCudaTexture) { + uint32_t n_textures = 1 + ((m_channels - 1) / 4); + std::vector<uint32_t> indices(n_textures); + jit_cuda_tex_get_indices(m_handle, indices.data()); + for (uint32_t i = 0; i < n_textures; i++) { + uint64_t new_index = fn(payload, indices[i], "", ""); + if (new_index != indices[i]) + jit_raise("A texture was changed by traversing it. This is " + "not supported!"); + } + } + } }; NAMESPACE_END(drjit) diff --git a/include/drjit/traversable_base.h b/include/drjit/traversable_base.h new file mode 100644 index 000000000..3b50b401b --- /dev/null +++ b/include/drjit/traversable_base.h @@ -0,0 +1,260 @@ +#pragma once + +#include "fwd.h" +#include +#include +#include +#include +#include +#include + +NAMESPACE_BEGIN(drjit) + +NAMESPACE_BEGIN(detail) +/** + * \brief The callback used to traverse all JIT arrays of a C++ object. + * + * \param payload: + * To wrap closures, a payload can be provided to the ``traverse_1_cb_ro`` + * function, which is then passed to the callback. + * + * \param index: + * A non-owning index of the traversed JIT array. + * + * \param variant: + * If a ``JitArray`` has the attribute ``IsClass``, it refers to a + * drjit class.
When such a variable is traversed, the ``variant`` and + * ``domain`` strings of its ``CallSupport`` are provided to the callback + * using this argument. Otherwise, the string is equal to "". + * + * \param domain: + * The domain of the ``CallSupport`` when traversing a class variable. + */ +using traverse_callback_ro = void (*)(void *payload, uint64_t index, + const char *variant, const char *domain); +/** + * \brief The callback used to traverse and modify all JIT arrays of a C++ object. + * + * \param payload: + * To wrap closures, a payload can be provided to the ``traverse_1_cb_rw`` + * function, which is then passed to the callback. + * + * \param index: + * A non-owning index of the traversed JIT array. + * + * \param variant: + * If a ``JitArray`` has the attribute ``IsClass``, it refers to a + * drjit class. When such a variable is traversed, the ``variant`` and + * ``domain`` strings of its ``CallSupport`` are provided to the callback + * using this argument. Otherwise, the string is equal to "". + * + * \param domain: + * The domain of the ``CallSupport`` when traversing a class variable. + * + * \return + * The new index of the traversed variable. This index is borrowed and + * should therefore be non-owning. + */ +using traverse_callback_rw = uint64_t (*)(void *payload, uint64_t index, + const char *variant, + const char *domain); + +inline void log_member_open(bool rw, const char *member) { + DRJIT_MARK_USED(rw); + DRJIT_MARK_USED(member); +#ifndef NDEBUG + jit_log(LogLevel::Debug, "%s%s{", rw ? "rw " : "ro ", member); +#endif +} + +inline void log_member_close() { +#ifndef NDEBUG + jit_log(LogLevel::Debug, "}"); +#endif +} + +NAMESPACE_END(detail) + +/** + * \brief Interface for traversing C++ objects. + * + * This interface should be inherited by any class that can be added to the + * registry. We try to ensure this by wrapping the function ``jit_registry_put`` + * in the function ``drjit::registry_put`` that takes a ``TraversableBase`` for + * the pointer argument. + */ +struct DRJIT_EXTRA_EXPORT TraversableBase : public nanobind::intrusive_base { + /** + * \brief Traverse all JIT arrays in this C++ object. For every JIT + * variable, the callback should be called with the provided payload + * pointer. + * + * \param payload: + * A pointer to a payload struct. The callback ``cb`` is called with this + * pointer. + * + * \param cb: + * A function pointer that is called with the ``payload`` pointer, the + * index of the JIT variable, and optionally the domain and variant of a + * ``Class`` variable. + */ + virtual void traverse_1_cb_ro(void *payload, + detail::traverse_callback_ro cb) const = 0; + + /** + * \brief Traverse all JIT arrays in this C++ object, and assign the output of the + * callback to them. For every JIT variable, the callback should be called + * with the provided payload pointer. + * + * \param payload: + * A pointer to a payload struct. The callback ``cb`` is called with this + * pointer. + * + * \param cb: + * A function pointer that is called with the ``payload`` pointer, the + * index of the JIT variable, and optionally the domain and variant of a + * ``Class`` variable. The index returned by this function pointer will + * be assigned to the traversed variable; the return value is treated as + * a borrowed (non-owning) reference during this assignment. + */ + virtual void traverse_1_cb_rw(void *payload, + detail::traverse_callback_rw cb) = 0; +}; + +/** + * \brief Macro for generating a call to \c traverse_1_fn_ro for a class member.
+ * + * This is only a utility macro for the DR_TRAVERSE_CB_RO macro. It can only be + * used in a context where the ``payload`` and ``fn`` variables are present. + */ +#define DR_TRAVERSE_MEMBER_RO(member) \ + drjit::detail::log_member_open(false, #member); \ + drjit::traverse_1_fn_ro(member, payload, fn); \ + drjit::detail::log_member_close(); + +/** + * \brief Macro for generating a call to \c traverse_1_fn_rw for a class member. + * + * This is only a utility macro for the DR_TRAVERSE_CB_RW macro. It can only be + * used in a context where the ``payload`` and ``fn`` variables are present. + */ +#define DR_TRAVERSE_MEMBER_RW(member) \ + drjit::detail::log_member_open(true, #member); \ + drjit::traverse_1_fn_rw(member, payload, fn); \ + drjit::detail::log_member_close(); + +/** + * \brief Macro generating the implementation of the ``traverse_1_cb_ro`` + * method. + * + * The first argument should be the base class from which the current class + * inherits. The other arguments should list all members of that class that + * are supposed to be read-only traversable. + */ +#define DR_TRAVERSE_CB_RO(Base, ...) \ + void traverse_1_cb_ro(void *payload, \ + drjit::detail::traverse_callback_ro fn) \ + const override { \ + static_assert( \ + std::is_base_of<drjit::TraversableBase, \ + std::remove_pointer_t<decltype(this)>>::value); \ + if constexpr (!std::is_same_v<Base, drjit::TraversableBase>) \ + Base::traverse_1_cb_ro(payload, fn); \ + DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, __VA_ARGS__) \ + } + +/** + * \brief Macro generating the implementation of the ``traverse_1_cb_rw`` + * method. + * + * The first argument should be the base class from which the current class + * inherits. The other arguments should list all members of that class that + * are supposed to be read-write traversable. + */ +#define DR_TRAVERSE_CB_RW(Base, ...) \ + void traverse_1_cb_rw(void *payload, \ + drjit::detail::traverse_callback_rw fn) override { \ + static_assert( \ + std::is_base_of<drjit::TraversableBase, \ + std::remove_pointer_t<decltype(this)>>::value); \ + if constexpr (!std::is_same_v<Base, drjit::TraversableBase>) \ + Base::traverse_1_cb_rw(payload, fn); \ + DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, __VA_ARGS__) \ + } + +/** + * \brief Macro generating both the implementations of the + * ``traverse_1_cb_ro`` and ``traverse_1_cb_rw`` methods. + * + * The first argument should be the base class from which the current class + * inherits. The other arguments should list all members of that class that + * are supposed to be read-write traversable. + */ +#define DR_TRAVERSE_CB(Base, ...) \ +public: \ + DR_TRAVERSE_CB_RO(Base, __VA_ARGS__) \ + DR_TRAVERSE_CB_RW(Base, __VA_ARGS__) + +/** + * \brief Macro generating the implementations of ``traverse_1_cb_ro`` and + * ``traverse_1_cb_rw`` of a nanobind trampoline class. + * + * This macro should only be instantiated on trampoline classes that serve as + * the base class for derived types in Python. Adding this macro to a trampoline + * class allows for the automatic traversal of all Python members in any + * derived Python class.
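+ *
+ * A minimal sketch of the intended use (``MyClass``, ``PyMyClass``, and
+ * ``Float`` are hypothetical names standing in for a user type, its nanobind
+ * trampoline, and a Dr.Jit array type):
+ *
+ * \code
+ * struct MyClass : drjit::TraversableBase {
+ *     Float value;
+ *     virtual void f() = 0;
+ *     DR_TRAVERSE_CB(drjit::TraversableBase, value)
+ * };
+ *
+ * // Trampoline, allowing MyClass to be subclassed from Python; members of
+ * // such Python subclasses are then traversed automatically.
+ * struct PyMyClass : MyClass {
+ *     NB_TRAMPOLINE(MyClass, 1);
+ *     void f() override { NB_OVERRIDE_PURE(f); }
+ *     DR_TRAMPOLINE_TRAVERSE_CB(MyClass)
+ * };
+ * \endcode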
+ */ +#define DR_TRAMPOLINE_TRAVERSE_CB(Base) \ +public: \ + void traverse_1_cb_ro(void *payload, \ + drjit::detail::traverse_callback_ro fn) \ + const override { \ + DRJIT_MARK_USED(payload); \ + DRJIT_MARK_USED(fn); \ + if constexpr (!std::is_same_v<Base, drjit::TraversableBase>) \ + Base::traverse_1_cb_ro(payload, fn); \ + drjit::traverse_py_cb_ro(this, payload, fn); \ + } \ + void traverse_1_cb_rw(void *payload, \ + drjit::detail::traverse_callback_rw fn) override { \ + DRJIT_MARK_USED(payload); \ + DRJIT_MARK_USED(fn); \ + if constexpr (!std::is_same_v<Base, drjit::TraversableBase>) \ + Base::traverse_1_cb_rw(payload, fn); \ + drjit::traverse_py_cb_rw(this, payload, fn); \ + } + +/** + * \brief Register a \c TraversableBase pointer with Dr.Jit's pointer registry + * + * This should be used instead of \c jit_registry_put, as it enforces that the + * pointers are of type \c TraversableBase, allowing for traversal of + * registry-bound pointers. + * + * Dr.Jit provides a central registry that maps registered pointer values to + * low-valued 32-bit IDs. The main application is efficient virtual function + * dispatch via \ref jit_var_call(), though the registry could be used for + * other applications as well. + * + * This function registers the specified pointer \c ptr with the registry, + * returning the associated ID value, which is guaranteed to be unique within + * the specified domain identified by the \c (variant, domain) strings. + * The domain is normally an identifier that is associated with the "flavor" + * of the pointer (e.g. instances of a particular class), and which ensures + * that the returned ID values are as low as possible. + * + * Caution: for reasons of efficiency, the \c domain parameter is assumed to be + * a static constant that will remain alive. The RTTI identifier + * typeid(MyClass).name() is a reasonable choice that satisfies this + * requirement. + * + * Raises an exception when ``ptr`` is ``nullptr``, or when it has already been + * registered with *any* domain.
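+ *
+ * A minimal sketch (``MyClass`` is a hypothetical traversable type and
+ * ``"cuda"`` an example variant string):
+ *
+ * \code
+ * MyClass *ptr = new MyClass();
+ * uint32_t id = drjit::registry_put("cuda", typeid(MyClass).name(), ptr);
+ * \endcode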
+ */ +inline uint32_t registry_put(const char *variant, const char *domain, + TraversableBase *ptr) { + return jit_registry_put(variant, domain, (void *) ptr); +} + +NAMESPACE_END(drjit) diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp index 07d249a59..ba2074b8a 100644 --- a/src/extra/autodiff.cpp +++ b/src/extra/autodiff.cpp @@ -43,7 +43,8 @@ */ #include "common.h" -#include "drjit-core/jit.h" +#include +#include #include #include #include @@ -1839,6 +1840,17 @@ void ad_scope_leave(bool process_postponed) { } } +void ad_scope_postponed(drjit::vector<uint32_t> *dst) { + LocalState &ls = local_state; + std::vector<Scope> &scopes = ls.scopes; + if (scopes.empty()) + ad_raise("ad_scope_postponed(): scope underflow!"); + Scope &scope = scopes.back(); + + for (auto &er : scope.postponed) + dst->push_back(er.target); +} + /// Check if gradient tracking is enabled for the given variable int ad_grad_enabled(Index index) { ADIndex ad_index = ::ad_index(index); diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 43ba1cfaf..bacf98c08 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -68,6 +68,7 @@ nanobind_add_module( reduce.h reduce.cpp apply.h apply.cpp eval.h eval.cpp + freeze.h freeze.cpp memop.h memop.cpp slice.h slice.cpp dlpack.h dlpack.cpp diff --git a/src/python/apply.cpp b/src/python/apply.cpp index fbfc0c471..18674dc44 100644 --- a/src/python/apply.cpp +++ b/src/python/apply.cpp @@ -591,27 +591,34 @@ nb::object apply_ret_pair(ArrayOp op, const char *name, nb::handle_t= 50) { PyErr_SetString(PyExc_RecursionError, "runaway recursion detected"); nb::raise_python_error(); } + // NOTE: the recursion_level has to be incremented after potentially + // throwing an exception, as throwing an exception in the constructor + // prevents the destructor from being called.
+ recursion_level++; } ~recursion_guard() { recursion_level--; } }; +} // namespace -void TraverseCallback::operator()(uint64_t) { } +uint64_t TraverseCallback::operator()(uint64_t, const char *, const char *) { return 0; } void TraverseCallback::traverse_unknown(nb::handle) { } /// Invoke the given callback on leaf elements of the pytree 'h' -void traverse(const char *op, TraverseCallback &tc, nb::handle h) { - nb::handle tp = h.type(); +void traverse(const char *op, TraverseCallback &tc, nb::handle h, bool rw) { recursion_guard guard; + nb::handle tp = h.type(); + try { if (is_drjit_type(tp)) { const ArraySupplement &s = supp(tp); @@ -623,29 +630,60 @@ void traverse(const char *op, TraverseCallback &tc, nb::handle h) { len = s.len(inst_ptr(h)); for (Py_ssize_t i = 0; i < len; ++i) - traverse(op, tc, nb::steal(s.item(h.ptr(), i))); + traverse(op, tc, nb::steal(s.item(h.ptr(), i)), rw); } else { tc(h); } } else if (tp.is(&PyTuple_Type)) { for (nb::handle h2 : nb::borrow(h)) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else if (tp.is(&PyList_Type)) { for (nb::handle h2 : nb::borrow(h)) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else if (tp.is(&PyDict_Type)) { for (nb::handle h2 : nb::borrow(h).values()) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { for (auto [k, v] : ds) - traverse(op, tc, nb::getattr(h, k)); + traverse(op, tc, nb::getattr(h, k), rw); } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { for (nb::handle field : df) { nb::object k = field.attr(DR_STR(name)); - traverse(op, tc, nb::getattr(h, k)); + traverse(op, tc, nb::getattr(h, k), rw); + } + } else if (auto traversable = get_traversable_base(h); traversable) { + struct Payload { + TraverseCallback &tc; + }; + Payload p{ tc }; + if (rw) { + traversable->traverse_1_cb_rw( + (void *) &p, + [](void *p, uint64_t index, const char *variant, + const char *domain) -> uint64_t { + Payload *payload = (Payload *) p; + uint64_t new_index = + payload->tc(index, variant, domain); + return new_index; + }); + } else { + traversable->traverse_1_cb_ro( + (void *) &p, [](void *p, uint64_t index, + const char *variant, const char *domain) { + Payload *payload = (Payload *) p; + payload->tc(index, variant, domain); + }); } - } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { - cb(h, nb::cpp_function([&](uint64_t index) { tc(index); })); + } else if (auto cb = get_traverse_cb_ro(tp); cb.is_valid() && !rw) { + cb(h, nb::cpp_function( + [&](uint64_t index, const char *variant, + const char *domain) { tc(index, variant, domain); })); + } else if (nb::object cb = get_traverse_cb_rw(tp); + cb.is_valid() && rw) { + cb(h, nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + return tc(index, variant, domain); + })); } else { tc.traverse_unknown(h); } @@ -897,14 +935,15 @@ nb::object transform(const char *op, TransformCallback &tc, nb::handle h) { nb::object tmp = nb::dict(); for (nb::handle field : df) { nb::object k = field.attr(DR_STR(name)); - tmp[k] = transform(op, tc, nb::getattr(h, k)); + tmp[k] = transform(op, tc, nb::getattr(h, k)); } result = tp(**tmp); } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { - cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); })); + cb(h, nb::cpp_function([&](uint64_t index, const char *, + const char *) { return tc(index); })); result = nb::borrow(h); } else if (!result.is_valid()) { - result = tc.transform_unknown(h); + result = 
tc.transform_unknown(h); } return result; } catch (nb::python_error &e) { diff --git a/src/python/apply.h b/src/python/apply.h index 8e67ee089..c30289640 100644 --- a/src/python/apply.h +++ b/src/python/apply.h @@ -57,7 +57,8 @@ struct TraverseCallback { // Type-erased form which is needed in some cases to traverse into opaque // C++ code. This one just gets called with Jit/AD variable indices; an // associated Python instance/type is not available. - virtual void operator()(uint64_t index); + virtual uint64_t operator()(uint64_t index, const char *variant = nullptr, + const char *domain = nullptr); // Traverse an unknown object virtual void traverse_unknown(nb::handle h); @@ -80,9 +81,16 @@ struct TransformCallback { /// Initialize 'h2' (already allocated) based on 'h1' virtual void operator()(nb::handle h1, nb::handle h2) = 0; - // Type-erased form which is needed in some cases to traverse into opaque - // C++ code. This one just gets called with Jit/AD variable indices; an - // associated Python instance/type is not available. + /** Type-erased form which is needed in some cases to traverse into opaque + * C++ code. This one just gets called with Jit/AD variable indices; an + * associated Python instance/type is not available. + * This can optionally return a non-owning jit_index that will be assigned + * to the underlying variable if \c traverse is called with the \c rw + * argument set to \c true. This can be used to modify JIT variables of + * PyTrees and their C++ objects in-place, for example when applying + * operations such as \c jit_var_schedule_force to every JIT variable in a + * PyTree. + */ virtual uint64_t operator()(uint64_t index); }; @@ -96,9 +104,27 @@ struct TransformPairCallback { virtual nb::object transform_unknown(nb::handle h1, nb::handle h2) const; }; -/// Invoke the given callback on leaf elements of the pytree 'h' -extern void traverse(const char *op, TraverseCallback &callback, - nb::handle h); +/** + * \brief Invoke the given callback on leaf elements of the pytree 'h', + * including JIT indices in C++ objects inheriting from + * \c drjit::TraversableBase. + * + * \param op: + * Name of the operation that is performed; it will be used in the + * exceptions that might be raised during traversal. + * + * \param callback: + * The \c TraverseCallback, called for every Jit variable in the pytree. + * + * \param rw: + * Boolean indicating whether C++ objects should be traversed in read-write + * mode. If this is set to \c true, the result from the method + * \c operator()(uint64_t) of the callback will be assigned to the + * underlying variable. This does not change how Python objects are + * traversed.
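+ *
+ * As an illustration, a callback that merely collects all JIT indices of a
+ * PyTree might look as follows (a sketch; the names are chosen for this
+ * example only):
+ *
+ * \code
+ * struct CollectIndices : TraverseCallback {
+ *     std::vector<uint64_t> indices;
+ *
+ *     // Called for Dr.Jit array objects encountered in the PyTree
+ *     void operator()(nb::handle h) override {
+ *         const ArraySupplement &s = supp(h.type());
+ *         if (s.index)
+ *             indices.push_back(s.index(inst_ptr(h)));
+ *     }
+ *
+ *     // Type-erased form, called for indices of opaque C++ objects
+ *     uint64_t operator()(uint64_t index, const char *, const char *) override {
+ *         indices.push_back(index);
+ *         return 0;
+ *     }
+ * };
+ *
+ * CollectIndices cb;
+ * traverse("my_op", cb, some_pytree); // read-only traversal (rw = false)
+ * \endcode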
+ */ +extern void traverse(const char *op, TraverseCallback &callback, nb::handle h, + bool rw = false); /// Parallel traversal of two compatible pytrees 'h1' and 'h2' extern void traverse_pair(const char *op, TraversePairCallback &callback, diff --git a/src/python/common.h b/src/python/common.h index 215e840ed..07fc5c854 100644 --- a/src/python/common.h +++ b/src/python/common.h @@ -88,6 +88,13 @@ inline nb::object get_dataclass_fields(nb::handle tp) { } return result; } +/// Return a pointer to the underlying C++ class if the Python object inherits +/// from TraversableBase, or null otherwise +inline drjit::TraversableBase *get_traversable_base(nb::handle h) { + drjit::TraversableBase *result = nullptr; + nb::try_cast(h, result); + return result; +} /// Extract a read-only callback to traverse custom data structures inline nb::object get_traverse_cb_ro(nb::handle tp) { diff --git a/src/python/detail.cpp b/src/python/detail.cpp index d9fc2cb26..158d71954 100644 --- a/src/python/detail.cpp +++ b/src/python/detail.cpp @@ -111,13 +111,15 @@ void collect_indices(nb::handle h, dr::vector<uint64_t> &indices, bool inc_ref) void operator()(nb::handle h) override { auto index_fn = supp(h.type()).index; if (index_fn) - operator()(index_fn(inst_ptr(h))); + operator()(index_fn(inst_ptr(h)), nullptr, nullptr); } - void operator()(uint64_t index) override { + uint64_t operator()(uint64_t index, const char *, + const char *) override { if (inc_ref) ad_var_inc_ref(index); result.push_back(index); + return 0; } }; @@ -281,6 +283,112 @@ bool leak_warnings() { return nb::leak_warnings() || jit_leak_warnings() || ad_leak_warnings(); } +// Have to wrap this in an unnamed namespace to prevent collisions with the +// other declaration of ``recursion_guard``. +namespace { +static int recursion_level = 0; + +// PyTrees could theoretically include cycles. Catch infinite recursion below +struct recursion_guard { + recursion_guard() { + if (recursion_level >= 50) { + PyErr_SetString(PyExc_RecursionError, "runaway recursion detected"); + nb::raise_python_error(); + } + // NOTE: the recursion_level has to be incremented after potentially + // throwing an exception, as throwing an exception in the constructor + // prevents the destructor from being called. + recursion_level++; + } + ~recursion_guard() { recursion_level--; } +}; +} // namespace + +/** + * \brief Traverses all variables of a Python object. + * + * This function is used to traverse variables of Python objects that inherit + * from trampoline classes. This allows the user to freeze a custom Python + * version of a C++ class without having to declare its variables.
+ */ +void traverse_py_cb_ro_impl(nb::handle self, nb::callable c) { + recursion_guard guard; + + struct PyTraverseCallback : TraverseCallback { + void operator()(nb::handle h) override { + const ArraySupplement &s = supp(h.type()); + auto index_fn = s.index; + if (index_fn) { + if (s.is_class) { + nb::str variant = + nb::borrow<nb::str>(nb::getattr(h, "Variant")); + nb::str domain = + nb::borrow<nb::str>(nb::getattr(h, "Domain")); + operator()(index_fn(inst_ptr(h)), variant.c_str(), + domain.c_str()); + } + else + operator()(index_fn(inst_ptr(h)), "", ""); + } + } + uint64_t operator()(uint64_t index, const char *variant, + const char *domain) override { + m_callback(index, variant, domain); + return 0; + } + nb::callable m_callback; + + PyTraverseCallback(nb::callable c) : m_callback(c) {} + }; + + PyTraverseCallback traverse_cb(std::move(c)); + + auto dict = nb::borrow<nb::dict>(nb::getattr(self, "__dict__")); + for (auto value : dict.values()) + traverse("traverse_py_cb_ro", traverse_cb, value); +} + +/** + * \brief Traverses all variables of a Python object. + * + * This function is used to traverse variables of Python objects that inherit + * from trampoline classes. This allows the user to freeze a custom Python + * version of a C++ class without having to declare its variables. + */ +void traverse_py_cb_rw_impl(nb::handle self, nb::callable c) { + recursion_guard guard; + + struct PyTraverseCallback : TraverseCallback { + void operator()(nb::handle h) override { + const ArraySupplement &s = supp(h.type()); + auto index_fn = s.index; + if (index_fn) { + uint64_t new_index; + if (s.is_class) { + nb::str variant = + nb::borrow<nb::str>(nb::getattr(h, "Variant")); + nb::str domain = nb::borrow<nb::str>(nb::getattr(h, "Domain")); + new_index = operator()(index_fn(inst_ptr(h)), + variant.c_str(), domain.c_str()); + } else + new_index = operator()(index_fn(inst_ptr(h)), "", ""); + s.reset_index(new_index, inst_ptr(h)); + } + } + uint64_t operator()(uint64_t index, const char *variant, const char *domain) override { + return nb::cast<uint64_t>(m_callback(index, variant, domain)); + } + nb::callable m_callback; + + PyTraverseCallback(nb::callable c) : m_callback(c) {} + }; + + PyTraverseCallback traverse_cb(std::move(c)); + + auto dict = nb::borrow<nb::dict>(nb::getattr(self, "__dict__")); + for (auto value : dict.values()) + traverse("traverse_py_cb_rw", traverse_cb, value, true); +} void export_detail(nb::module_ &) { nb::module_ d = nb::module_::import_("drjit.detail"); @@ -351,6 +459,8 @@ void export_detail(nb::module_ &) { d.def("leak_warnings", &leak_warnings, doc_leak_warnings); d.def("set_leak_warnings", &set_leak_warnings, doc_set_leak_warnings); + d.def("traverse_py_cb_ro", &traverse_py_cb_ro_impl); + d.def("traverse_py_cb_rw", &traverse_py_cb_rw_impl); trace_func_handle = d.attr("trace_func"); } diff --git a/src/python/docstr.rst b/src/python/docstr.rst index 266e07a0d..b06ce1070 100644 --- a/src/python/docstr.rst +++ b/src/python/docstr.rst @@ -6056,6 +6056,35 @@ This flag is *enabled* by default. +.. topic:: JitFlag_KernelFreezing + + Enable recording and replay of functions annotated with :py:func:`freeze`. + + If KernelFreezing is enabled, all Dr.Jit operations executed in a function + annotated with :py:func:`freeze` are recorded during its first execution + and replayed without re-tracing on subsequent calls. + + If this flag is disabled, replay of previously frozen functions is disabled + as well. + +.. topic:: JitFlag_FreezingScope + + This flag is set to ``True`` when Dr.Jit is currently recording a frozen + function.
The flag is automatically managed and should not be updated by + application code. + + User code may query this flag to conditionally optimize kernels for frozen + function recording, for example by re-seeding a sampler used for rendering. + +.. topic:: JitFlag_EnableObjectTraversal + + This flag is set to ``True`` when Dr.Jit is currently traversing + inputs and outputs of a frozen function. The flag is automatically managed + and should not be updated by application code. + + When set, traversal of complex objects that are usually opaque to + loops and conditionals is enabled. + .. topic:: JitFlag_Default The default set of optimization flags consisting of @@ -8136,7 +8165,7 @@ tensor, or iterable along the specified axis/axes. The function returns an output array of the same shape as the input. The - ``op`` paramater selects the operation to be performed. + ``op`` parameter selects the operation to be performed. For example, when reducing a 1D array using ``exclusive=True`` (the default), this produces the following output diff --git a/src/python/freeze.cpp b/src/python/freeze.cpp new file mode 100644 index 000000000..fbbd65d73 --- /dev/null +++ b/src/python/freeze.cpp @@ -0,0 +1,2167 @@ +/** + * This file implements the frontend for the frozen function feature. + * The `FrozenFunction` class represents a function that has been annotated with + * the `@dr.freeze` decorator. When calling the `operator()` of this object for + * the first time, the wrapped callable is recorded. On subsequent calls, the + * input is checked, and if compatible, the previously recorded function is + * replayed. The `FrozenFunction` class should not be used directly, and is + * wrapped in a higher-level class in `__init__.py`. + * + * When calling the `operator()` of a `FrozenFunction` object, the + * inputs of the function have to be traversed. This collects the JIT variables + * in the input and information about the layout of the PyTree. We store + * both in a `FlatVariables` struct. So that all side effects are resolved and + * the freezing backend can handle these JIT variables, they are evaluated. + * + * Only evaluated variables can change between calls to the function + * without re-tracing it. Therefore, we need to make changing literals opaque. + * The auto-opaque feature detects changes in literal variables from one call of + * the function to another. For this reason, traversal of the input is split up + * into six steps. These steps are performed every time the frozen function is + * called, and depending on the result, a previous recording can be replayed or a + * new recording has to be made. + * + * 1. The input is traversed using the `FlatVariables::traverse_with_registry` + * function. It traverses the PyTree, including a subset of the registry when + * class variables are present in the input (i.e., pointers to classes with + * Dr.Jit virtual function calls). Information about the layout of the PyTree + * is stored in the `Layout` structs of the `layout` vector of the + * `FlatVariables` class. If a JIT variable is encountered during this + * traversal step, its index will be stored in the `index` field of the + * corresponding `Layout` entry. + * 2. If the `auto_opaque` feature is enabled, the key of the previous iteration + * is checked to see if it is compatible with the currently recorded key. If this + * is the case, the `opaque_mask` of the last iteration will be used in the + * next step; otherwise, it will be cleared. + * 3. 
The flattened input variables are traversed by + * `FlatVariables::schedule_jit_variables`, and all JIT variables are either + * scheduled or force-scheduled, depending on the `opaque_mask` value for + * that layout node. After scheduling the variables, they are evaluated, + * clearing all side effects. + * 4. In order to be able to record or replay the function, an array of + * deduplicated indices of evaluated JIT variables has to be provided to + * either `jit_freeze_start` or `jit_freeze_replay` respectively. The + * function `FlatVariables::record_jit_variables` iterates over the flattened + * variables, deduplicates indices of evaluated JIT variables, and + * additionally records information about them that only becomes available + * after evaluation. Additional information is stored in the + * `FlatVariables::var_layout` vector. + * 5. If the `auto_opaque` feature is enabled, we create a mask of literal + * variables that have to be evaluated because they changed from one call + * to another. To this end, we use the `FlatVariables::fill_opaque_mask` + * function, which iterates over the flattened variables and finds such + * literals, setting the corresponding boolean to true. + * 6. If we detect that the number of such literal variables changed, steps 1 + * to 5 are repeated one time. + * + * After traversing the input, the frozen function decides whether to record a + * new version of the function or replay an old version. The + * `FrozenFunction::recordings` hashmap is used to look up old versions of the + * function. The flattened input PyTree serves as a key to the hashmap. However, + * only a subset of the flattened PyTree is used for the key. Refer to the + * `FlatVariables::operator==` function to see which changes to the input can + * cause the function to be re-traced. + * + * If no matching recording was found in the hashmap, the following steps will + * be executed to record the function. Steps 2 to 6 are located in the + * `FunctionRecording::record` function and can be invoked from the + * `FunctionRecording::replay` function as well if a dry-run failed. + * + * 1. The flattened version of the input is assigned back to the input PyTree. + * This is required so that evaluating variables is reflected in the PyTree. + * 2. Kernel recording is started with the `jit_freeze_start` function by + * providing it with the flattened evaluated variables of the input. + * 3. The inner function is executed while recording all launched kernels. + * 4. While kernels are still recorded, the outputs as well as the inputs of the + * function are traversed and collected into a single `FlatVariables` + * object. The inputs could have changed when executing the function, and we + * handle these new inputs in a similar manner to outputs of the function. + * Analogous to the above section, the outputs and new inputs are traversed, + * scheduled, evaluated, and recorded. + * 5. Using the deduplicated JIT indices collected by traversing the outputs, + * kernel freezing is stopped with the `jit_freeze_stop` function. + * 6. Finally, we catch any potential errors when assigning variables or + * constructing outputs early (i.e., after recording a function rather than + * after replaying it later). Therefore, we assign the flattened inputs to + * the input PyTree and construct the output from its flattened version. For + * assigning the input variables, `FlatVariables::assign_with_registry` is + * used. For constructing the output, `FlatVariables::construct` is used. 
+ * The layout of the output is also stored in the recording. + * + * In the above case, the new recording is added to the `recordings` hashmap + * with the input `FlatVariables` as the key. + * + * If, on the other hand, the function has already been recorded with compatible + * inputs, this recording will be replayed. The following steps outline + * replaying a function recording: + * + * 1. A dry-run of the recording is launched if this is required, using + * `jit_freeze_dry_run`. + * 2. If the dry-run failed, the current recording will be overwritten by + * calling `FunctionRecording::record` on this recording, executing steps 2 + * to 6 of the above section. + * 3. If the dry-run succeeded, the recording will be replayed by calling + * `jit_freeze_replay` with the flattened evaluated variables of the input + * PyTree. The `variables` field of the output flat variables will be used to + * store the output variables from the replay. + * 4. The output from replaying the recording as well as the stored layout is + * used to both construct the output PyTree and assign the JIT + * variables to the input PyTree. + */ +#include "freeze.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "autodiff.h" +#include "base.h" +#include "common.h" +#include "listobject.h" +#include "object.h" +#include "pyerrors.h" +#include "reduce.h" +#include "shape.h" +#include "tupleobject.h" + +#include "../ext/nanobind/src/buffer.h" + +using Buffer = nanobind::detail::Buffer; + +/** + * \brief Helper struct to profile and log frozen functions. + */ +struct ProfilerPhase { + std::string m_message; + ProfilerPhase(const char *message) : m_message(message) { + jit_log(LogLevel::Debug, "profiler start: %s", message); +#if defined(DRJIT_ENABLE_NVTX) + jit_profile_range_push(message); +#endif + } + + ProfilerPhase(const drjit::TraversableBase *traversable) { + char message[1024] = { 0 }; + const char *name = typeid(*traversable).name(); + snprintf(message, 1024, "traverse_cb %s", name); + + jit_log(LogLevel::Debug, "profiler start: %s", message); + jit_profile_range_push(message); + m_message = message; + } + + ~ProfilerPhase() { +#if defined(DRJIT_ENABLE_NVTX) + jit_profile_range_pop(); +#endif + jit_log(LogLevel::Debug, "profiler end: %s", m_message.c_str()); + } +}; + +struct ADScopeContext { + bool process_postponed; + ADScopeContext(drjit::ADScope type, size_t size, const uint64_t *indices, + int symbolic, bool process_postponed) + : process_postponed(process_postponed) { + ad_scope_enter(type, size, indices, symbolic); + } + ~ADScopeContext() { ad_scope_leave(process_postponed); } +}; + +struct scoped_set_flag { + uint32_t backup; + scoped_set_flag(JitFlag flag, bool enabled) : backup(jit_flags()) { + uint32_t flags = backup; + if (enabled) + flags |= (uint32_t) flag; + else + flags &= ~(uint32_t) flag; + + jit_set_flags(flags); + } + + ~scoped_set_flag() { jit_set_flags(backup); } +}; + +struct state_lock_guard { + state_lock_guard() { jit_state_lock(); } + ~state_lock_guard() { jit_state_unlock(); } +}; +struct state_unlock_guard { + state_unlock_guard() { jit_state_unlock(); } + ~state_unlock_guard() { jit_state_lock(); } +}; + +using namespace detail; + +bool Layout::operator==(const Layout &rhs) const { + if (((bool) this->type != (bool) rhs.type) || !(this->type.equal(rhs.type))) + return false; + + if (this->num != rhs.num) + return false; + + if (this->fields.size() != rhs.fields.size()) + return false; + + for (uint32_t i = 0; i < 
this->fields.size(); ++i) { + if (!(this->fields[i].equal(rhs.fields[i]))) + return false; + } + + if (this->index != rhs.index) + return false; + + if (this->flags != rhs.flags) + return false; + + if (this->literal != rhs.literal) + return false; + + if (this->vt != rhs.vt) + return false; + + if (((bool) this->py_object != (bool) rhs.py_object) || + !this->py_object.equal(rhs.py_object)) + return false; + + return true; +} + +bool VarLayout::operator==(const VarLayout &rhs) const { + if (this->vt != rhs.vt) + return false; + + if (this->vs != rhs.vs) + return false; + + if (this->flags != rhs.flags) + return false; + + if (this->size_index != rhs.size_index) + return false; + + return true; +} + +/** + * \brief Add a variant-domain pair to be traversed using the registry. + * + * When traversing a JIT variable that references a pointer to a class, + * such as a BSDF or Shape in Mitsuba, we have to traverse all objects + * registered with that variant-domain pair in the registry. This function + * adds the variant-domain pair, deduplicating the domain. Whether a + * variable references a class is represented by its ``IsClass`` const + * attribute. If the domain is an empty string (""), this function skips + * adding the variant-domain pair. + */ +void FlatVariables::add_domain(const char *variant, const char *domain) { + // Since it is not possible to pass nullptr strings to nanobind functions, + // we assume that a valid domain indicates a valid variant. If the variant + // is empty at the end of traversal, we know that no Class variable was + // traversed, and registry traversal is not necessary. + if (domain && variant && strcmp(domain, "") != 0) { + jit_log(LogLevel::Debug, "variant=%s, domain=%s", variant, domain); + + if (domains.empty()) { + this->variant = variant; + } else if (this->variant != variant) + jit_raise("traverse(): Variant mismatch! All arguments to a " + "frozen function have to have the same variant. " + "Variant %s of a previous argument does not match " + "variant %s of this argument.", + this->variant.c_str(), variant); + + bool contains = false; + for (std::string &d : domains) { + if (d == domain) { + contains = true; + break; + } + } + if (!contains) + domains.push_back(domain); + } +} + +/** + * Adds a JIT index to the flattened array, deduplicating it. + * This makes it possible to check for aliasing conditions, where two + * variables actually refer to the same index. The function should only be + * called for scheduled non-literal variable indices. + */ +uint32_t FlatVariables::add_jit_index(uint32_t index) { + uint32_t next_slot = this->variables.size(); + auto result = this->index_to_slot.try_emplace(index, next_slot); + auto it = result.first; + bool inserted = result.second; + + if (inserted) { + this->variables.push_back(index); + // Borrow the variable + jit_var_inc_ref(index); + this->var_layout.emplace_back(); + return next_slot; + } else { + return it.value(); + } +} + +/** + * + * The auto opaque feature is able to track changes of literal values in the + * PyTree between calls to the function. If a literal value changes between two + * calls, it should be made opaque before the second call. To accomplish this, a + * boolean array (``opaque_mask``) is used, which indicates which variables + * should be made opaque before recording a function. Tracking the literals + * that should be made opaque can only be done as long as the structure + * of the input PyTree does not change significantly.
If such a change is + * detected, the opaque_mask is reset to ``false``, and the next call does not + * force-evaluate literals. + * This function is responsible for detecting changes between two flattened + * PyTrees (``FlatVariables``) that force us to reset the ``opaque_mask`` + * array. + */ +bool compatible_auto_opaque(FlatVariables &cur, FlatVariables &prev) { + // NOTE: We only test the size of the layout, as a full test is somewhat + // expensive, and the worst case is that we make too many variables opaque, + // which does not impact correctness. If this causes problems, more + // extensive tests might have to be reintroduced. + if (cur.layout.size() != prev.layout.size()) { + return false; + } + return true; +} + +bool FlatVariables::fill_opaque_mask(FlatVariables &prev, + drjit::vector<bool> &opaque_mask) { + // If we notice that only a literal has changed, we can set the + // corresponding bit in the mask, indicating that this literal should be + // made opaque next time. + uint32_t opaque_counter = 0; + bool new_opaques = false; + for (uint32_t i = 0; i < this->layout.size(); i++) { + Layout &layout = this->layout[i]; + Layout &prev_layout = prev.layout[i]; + + bool requires_opaque = + (layout.flags & (uint32_t) LayoutFlag::Literal) && + (prev_layout.flags & (uint32_t) LayoutFlag::Literal) && + (layout.literal != prev_layout.literal || + layout.literal_size != prev_layout.literal_size); + + opaque_mask[i] |= requires_opaque; + new_opaques |= requires_opaque; + opaque_counter += requires_opaque; + } + + jit_log(LogLevel::Debug, + "fill_opaque_mask(): %u variables will be made opaque", + opaque_counter); + + return new_opaques; +} + +void FlatVariables::schedule_jit_variables( + bool schedule_force, const drjit::vector<bool> *opaque_mask) { + + ProfilerPhase profiler("schedule_jit_variables"); + for (uint32_t i = layout_index; i < layout.size(); i++) { + Layout &layout = this->layout[i]; + + if (!(layout.flags & (uint32_t) LayoutFlag::JitIndex)) + continue; + + uint32_t index = layout.index; + + int rv = 0; + // Undefined variables (i.e. ones created with ``dr.empty``) are handled + // similarly to literals, and can be allocated when replaying. + if (schedule_force || + (opaque_mask && (*opaque_mask)[i - layout_index])) { + // Returns owning reference + index = jit_var_schedule_force(index, &rv); + } else { + // Schedule and create owning reference + rv = jit_var_schedule(index); + jit_var_inc_ref(index); + } + + VarInfo info = jit_var_info(index); + if (backend == info.backend || this->backend == JitBackend::None) { + backend = info.backend; + } else { + jit_raise("freeze(): backend mismatch error (backend of this " + "variable %s does not match backend of others %s)!", + info.backend == JitBackend::CUDA ? "CUDA" : "LLVM", + backend == JitBackend::CUDA ? "CUDA" : "LLVM"); + } + + if (info.state == VarState::Literal) { + // Special case, where the variable is a literal. + layout.literal = info.literal; + // Store the size separately; the index field is not used for + // literals. + layout.literal_size = info.size; + layout.vt = (uint32_t) info.type; + layout.literal_index = index; + + layout.flags |= (uint32_t) LayoutFlag::Literal; + } else if (info.state == VarState::Undefined) { + // Special case, where the variable is undefined. + // Store the size separately; the index field is not used for + // undefined variables either.
+ layout.literal_size = info.size; + layout.vt = (uint32_t) info.type; + layout.literal_index = index; + + layout.flags |= (uint32_t) LayoutFlag::Undefined; + } else { + layout.index = this->add_jit_index(index); + layout.vt = (uint32_t) info.type; + jit_var_dec_ref(index); + } + } + layout_index = layout.size(); +} + +/** + * \brief Records information about JIT variables that have been traversed. + * + * After traversing the PyTree, collecting non-literal indices in + * ``variables`` and evaluating the collected indices, we can collect + * information about the underlying variables. This information is used in + * the key of the ``RecordingMap`` to determine which recording should be + * replayed or if the function has to be re-traced. This function iterates + * over the collected indices and collects that information. + */ +void FlatVariables::record_jit_variables() { + ProfilerPhase profiler("record_jit_variables"); + assert(variables.size() == var_layout.size()); + for (uint32_t i = 0; i < var_layout.size(); i++) { + uint32_t index = variables[i]; + VarLayout &layout = var_layout[i]; + + VarInfo info = jit_var_info(index); + if (info.type == VarType::Pointer) { + // We do not support pointers as inputs. It might be possible with + // some extra handling, but they are never used directly. + jit_raise("Pointer inputs not supported!"); + } + + layout.vs = info.state; + layout.vt = info.type; + layout.size_index = this->add_size(info.size); + + if (info.state == VarState::Evaluated) { + // Special case, handling evaluated/opaque variables. + + layout.flags |= + (info.size == 1 ? (uint32_t) LayoutFlag::SingletonArray : 0); + layout.flags |= + (info.unaligned ? (uint32_t) LayoutFlag::Unaligned : 0); + + } else { + jit_raise("collect(): found variable %u in unsupported state %u!", + index, (uint32_t) info.state); + } + } +} + +/** + * This function returns an index of an equivalence class for the variable + * size in the flattened variables. + * It uses a hashmap and vector to deduplicate sizes. + * + * This is necessary to catch cases where two variables had the same size + * when freezing a function and two different sizes when replaying. + * In that case, one kernel would be recorded that evaluates both variables. + * However, when replaying, two kernels would have to be launched, since the + * now differently sized variables cannot be evaluated by the same kernel. + */ +uint32_t FlatVariables::add_size(uint32_t size) { + uint32_t next_slot = this->sizes.size(); + auto result = this->size_to_slot.try_emplace(size, next_slot); + auto it = result.first; + bool inserted = result.second; + + if (inserted) { + this->sizes.push_back(size); + return next_slot; + } else { + return it.value(); + } +} + +/** + * Traverse a variable referenced by a JIT index and add it to the flat + * variables. An optional Python type can be supplied if it is known. + * Depending on ``TraverseContext::schedule_force``, the underlying + * variable is either scheduled (``jit_var_schedule``) or force scheduled + * (``jit_var_schedule_force``). If the variable after evaluation is a + * literal, it is directly recorded in the ``layout``; otherwise, it is added + * to the ``variables`` array, allowing the variables to be used when + * recording the frozen function.
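+ *
+ * Conceptually, the scheduling step that later consumes these indices
+ * behaves as follows (a sketch; the actual logic lives in
+ * ``schedule_jit_variables`` above):
+ *
+ * \code
+ * int rv = 0;
+ * if (schedule_force)
+ *     index = jit_var_schedule_force(index, &rv); // evaluates literals, too
+ * else
+ *     rv = jit_var_schedule(index);               // schedules if unevaluated
+ * \endcode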
+ */ +void FlatVariables::traverse_jit_index(uint32_t index, TraverseContext &ctx, + nb::handle tp) { + (void) ctx; + Layout &layout = this->layout.emplace_back(); + + if (tp) + layout.type = nb::borrow(tp); + + layout.flags |= (uint32_t) LayoutFlag::JitIndex; + layout.index = index; + layout.vt = (uint32_t) jit_var_type(index); +} + +/** + * Construct a variable, given its layout. + * This is the counterpart to `traverse_jit_index`. + * + * Optionally, the index of a variable can be provided that will be + * overwritten with the result of this function. In that case, the function + * will check for compatible variable types. + */ +uint32_t FlatVariables::construct_jit_index(uint32_t prev_index) { + Layout &layout = this->layout[layout_index++]; + + uint32_t index; + VarType vt; + if ((layout.flags & (uint32_t) LayoutFlag::Literal) || + (layout.flags & (uint32_t) LayoutFlag::Undefined)) { + index = layout.literal_index; + jit_var_inc_ref(index); + vt = (VarType) layout.vt; + } else { + VarLayout &var_layout = this->var_layout[layout.index]; + index = this->variables[layout.index]; + jit_log(LogLevel::Debug, " uses output[%u] = r%u", layout.index, + index); + + jit_var_inc_ref(index); + vt = var_layout.vt; + } + + if (prev_index) { + if (vt != (VarType) jit_var_type(prev_index)) + jit_fail("VarType mismatch %u != %u while assigning (r%u) " + "-> (r%u)!", + (uint32_t) vt, (uint32_t) jit_var_type(prev_index), + (uint32_t) prev_index, (uint32_t) index); + } + return index; +} + +/** + * Add an AD variable by its index. Both the value and gradient are added + * to the flattened variables. If the AD index has been marked as postponed + * in the \c TraverseContext.postponed field, we mark the resulting layout + * with that flag. This will cause the gradient edges to be propagated when + * assigning to the input. The function takes an optional Python type if + * it is known. + */ +void FlatVariables::traverse_ad_index(uint64_t index, TraverseContext &ctx, + nb::handle tp) { + // NOTE: instead of emplacing a Layout representing the ad variable always, + // we only do so if the gradients have been enabled. We use this format, + // since most variables will not be ad enabled. The layout therefore has to + // be peeked in ``construct_ad_index`` before deciding if an ad or jit + // index should be constructed/assigned. + int grad_enabled = ad_grad_enabled(index); + if (grad_enabled) { + Layout &layout = this->layout.emplace_back(); + uint32_t ad_index = (uint32_t) (index >> 32); + + if (tp) + layout.type = nb::borrow(tp); + layout.num = 2; + + // Set flags + layout.flags |= (uint32_t) LayoutFlag::GradEnabled; + // If the edge with this node as its target has been postponed by + // the isolate gradient scope, it has been enqueued and we mark the + // ad variable as such. + if (ctx.postponed && ctx.postponed->contains(ad_index)) { + layout.flags |= (uint32_t) LayoutFlag::Postponed; + } + + traverse_jit_index((uint32_t) index, ctx, tp); + uint32_t grad = ad_grad(index); + traverse_jit_index(grad, ctx, tp); + ctx.free_list.push_back_steal(grad); + } else { + traverse_jit_index(index, ctx, tp); + } +} + +/** + * Construct/assign the variable index given a layout. + * This corresponds to `traverse_ad_index`. + * + * This function is also used for assignment to AD variables. + * If a `prev_index` is provided, and it is an AD variable, the gradient and + * value of the flat variables will be applied to the AD variable, + * preserving the `ad_index`. + * + * It returns an owning reference. 
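+ *
+ * Because the returned reference is owning, a typical caller releases it once
+ * the index has been assigned (a sketch; ``flat``, ``prev``, ``s`` and
+ * ``dst`` are placeholder names):
+ *
+ * \code
+ * uint64_t index = flat.construct_ad_index(prev);
+ * s.reset_index(index, inst_ptr(dst));  // borrows from 'index'
+ * ad_var_dec_ref(index);                // release the owning reference
+ * \endcode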
+ */ +uint64_t FlatVariables::construct_ad_index(uint64_t prev_index) { + Layout &layout = this->layout[this->layout_index]; + + uint64_t index; + if ((layout.flags & (uint32_t) LayoutFlag::GradEnabled) != 0) { + Layout &layout = this->layout[this->layout_index++]; + bool postponed = (layout.flags & (uint32_t) LayoutFlag::Postponed); + + uint32_t val = construct_jit_index(prev_index); + uint32_t grad = construct_jit_index(prev_index); + + // Resize the gradient if it is a literal + if ((VarState) jit_var_state(grad) == VarState::Literal) { + uint32_t new_grad = jit_var_resize(grad, jit_var_size(val)); + jit_var_dec_ref(grad); + grad = new_grad; + } + + // If the prev_index variable is provided we assign the new value + // and gradient to the ad variable of that index instead of creating + // a new one. + uint32_t ad_index = (uint32_t) (prev_index >> 32); + if (ad_index) { + index = (((uint64_t) ad_index) << 32) | ((uint64_t) val); + ad_var_inc_ref(index); + } else + index = ad_var_new(val); + + jit_log(LogLevel::Debug, " -> ad_var r%zu", index); + jit_var_dec_ref(val); + + // Equivalent to set_grad + ad_clear_grad(index); + ad_accum_grad(index, grad); + jit_var_dec_ref(grad); + + // Variables, that have been postponed by the isolate gradient scope + // will be enqueued, which propagates their gradient to previous + // functions. + if (ad_index && postponed) { + ad_enqueue(drjit::ADMode::Backward, index); + } + } else { + index = construct_jit_index(prev_index); + } + + return index; +} + +/** + * Wrapper around traverse_ad_index for a Python handle. + */ +void FlatVariables::traverse_ad_var(nb::handle h, TraverseContext &ctx) { + auto s = supp(h.type()); + + if (s.is_class) { + auto variant = nb::borrow(nb::getattr(h, "Variant")); + auto domain = nb::borrow(nb::getattr(h, "Domain")); + add_domain(variant.c_str(), domain.c_str()); + } + + raise_if(s.index == nullptr, "freeze(): ArraySupplement index function " + "pointer is nullptr."); + + uint64_t index = s.index(inst_ptr(h)); + + this->traverse_ad_index(index, ctx, h.type()); +} + +/** + * Construct an AD variable given its layout. + * This corresponds to `traverse_ad_var`. + */ +nb::object FlatVariables::construct_ad_var(const Layout &layout) { + uint64_t index = construct_ad_index(); + + auto result = nb::inst_alloc_zero(layout.type); + const ArraySupplement &s = supp(result.type()); + s.init_index(index, inst_ptr(result)); + nb::inst_mark_ready(result); + + // We have to release the reference, since assignment will borrow from + // it. + ad_var_dec_ref(index); + + return result; +} + +/** + * Assigns an AD variable. + * Corresponds to `traverse_ad_var`. + * This uses `construct_ad_index` to either construct a new AD variable or + * assign the value and gradient to an already existing one. + */ +void FlatVariables::assign_ad_var(Layout &layout, nb::handle dst) { + const ArraySupplement &s = supp(layout.type); + + uint64_t index; + if (s.index) { + // ``construct_ad_index`` is used for assignment + index = construct_ad_index(s.index(inst_ptr(dst))); + } else + index = construct_ad_index(); + + s.reset_index(index, inst_ptr(dst)); + jit_log(LogLevel::Debug, "index=%zu, grad_enabled=%u, ad_grad_enabled=%u", + index, grad_enabled(dst), ad_grad_enabled(index)); + + // Release reference, since ``construct_ad_index`` returns owning + // reference and ``s.reset_index`` borrows from it. + ad_var_dec_ref(index); +} + +/** + * Traverse a C++ tree using its `traverse_1_cb_ro` callback. 
+ */ +void FlatVariables::traverse_cb(const drjit::TraversableBase *traversable, + TraverseContext &ctx, nb::object type) { + // ProfilerPhase profiler(traversable); + + uint32_t layout_index = this->layout.size(); + Layout &layout = this->layout.emplace_back(); + layout.type = nb::borrow(type); + + struct Payload { + TraverseContext &ctx; + FlatVariables *flat_variables = nullptr; + uint32_t num_fields = 0; + }; + + Payload p{ ctx, this, 0 }; + + traversable->traverse_1_cb_ro( + (void *) &p, + [](void *p, uint64_t index, const char *variant, const char *domain) { + if (!index) + return; + Payload *payload = (Payload *) p; + payload->flat_variables->add_domain(variant, domain); + payload->flat_variables->traverse_ad_index(index, payload->ctx); + payload->num_fields++; + }); + + this->layout[layout_index].num = p.num_fields; +} + +/** + * Helper function, used to assign a callback variable. + * + * \param tmp + * This vector is populated with the indices to variables that have been + * constructed. It is required to release the references, since the + * references created by `construct_ad_index` are owning and they are + * borrowed after the callback returns. + */ +uint64_t FlatVariables::assign_cb_internal(uint64_t index, + index64_vector &tmp) { + if (!index) + return index; + + uint64_t new_index = this->construct_ad_index(index); + + tmp.push_back_steal(new_index); + return new_index; +} + +/** + * Assigns variables using its `traverse_cb_rw` callback. + * This corresponds to `traverse_cb`. + */ +void FlatVariables::assign_cb(drjit::TraversableBase *traversable) { + Layout &layout = this->layout[layout_index++]; + + struct Payload { + FlatVariables *flat_variables = nullptr; + Layout &layout; + index64_vector tmp; + uint32_t field_counter = 0; + }; + Payload p{ this, layout, index64_vector(), 0 }; + traversable->traverse_1_cb_rw( + (void *) &p, + [](void *p, uint64_t index, const char *, const char *) { + if (!index) + return index; + Payload *payload = (Payload *) p; + if (payload->field_counter >= payload->layout.num) + jit_raise("While traversing an object " + "for assigning inputs, the number of variables to " + "assign (>%u) did not match the number of variables " + "traversed when recording (%u)!", + payload->field_counter, payload->layout.num); + payload->field_counter++; + return payload->flat_variables->assign_cb_internal(index, payload->tmp); + }); + + if (p.field_counter != layout.num) + jit_raise("While traversing and object for assigning inputs, the " + "number of variables to assign did not match the number " + "of variables traversed when recording!"); +} + +/** + * Helper struct to construct path strings to variables. + * Used to provide helpful logs and error messages. + */ +struct scoped_path { + TraverseContext &m_ctx; + + uint32_t m_size; + scoped_path(TraverseContext &ctx, const char *suffix, bool dict = false) + : m_ctx(ctx), m_size(ctx.path.size()) { + if (dict) { + if (ctx.recursion_level == 0) + ctx.path.fmt("%s", suffix); + else + ctx.path.fmt("[\"%s\"]", suffix); + } else { + ctx.path.fmt(".%s", suffix); + } + ctx.recursion_level++; + } + scoped_path(TraverseContext &ctx, uint32_t suffix) + : m_ctx(ctx), m_size(ctx.path.size()) { + ctx.path.fmt("[%u]", suffix); + ctx.recursion_level++; + } + ~scoped_path() { + m_ctx.path.rewind(m_ctx.path.size() - m_size); + m_ctx.recursion_level--; + } +}; + +/** + * Traverses a PyTree in DFS order, and records its layout in the + * `layout` vector. 
 + *
+ * When hitting a Dr.Jit primitive type, it calls the
+ * `traverse_ad_var` method, which adds its indices to the flattened
+ * ``variables`` array. The traversal also records metadata
+ * about the Dr.Jit variable in the layout. Therefore, the layout can be used
+ * as a key to look up the recording of the frozen function.
+ */
+void FlatVariables::traverse(nb::handle h, TraverseContext &ctx) {
+    recursion_guard guard(this);
+
+    scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true);
+
+    ProfilerPhase profiler("traverse");
+    nb::handle tp = h.type();
+
+    nb::str tp_name = nb::type_name(tp);
+    jit_log(LogLevel::Debug, "FlatVariables::traverse(): %s {",
+            tp_name.c_str());
+
+    uint32_t layout_index = this->layout.size();
+    Layout &layout = this->layout.emplace_back();
+
+    const void *key = h.ptr();
+    auto [it, inserted] = ctx.visited.try_emplace(key, nb::borrow(h));
+    if (!inserted) {
+        layout.flags |= (uint32_t) LayoutFlag::RecursiveRef;
+        return;
+    }
+    try {
+        layout.type = nb::borrow<nb::type_object>(tp);
+        if (is_drjit_type(tp)) {
+            const ArraySupplement &s = supp(tp);
+            if (s.is_tensor) {
+                nb::handle array = s.tensor_array(h.ptr());
+
+                auto full_shape = nb::borrow<nb::tuple>(shape(h));
+
+                // Instead of adding the whole shape of a tensor to the key, we
+                // only add the inner part, excluding dimension 0. When
+                // indexing into a tensor, this is the only dimension that is
+                // not used in the index calculation. When constructing a
+                // tensor, this dimension is reconstructed from the width of
+                // the underlying array.
+
+                nb::list inner_shape;
+                if (full_shape.size() > 0)
+                    for (uint32_t i = 1; i < full_shape.size(); i++) {
+                        inner_shape.append(full_shape[i]);
+                    }
+
+                layout.py_object = nb::tuple(inner_shape);
+
+                traverse(nb::steal(array), ctx);
+            } else if (s.ndim != 1) {
+                Py_ssize_t len = s.shape[0];
+                if (len == DRJIT_DYNAMIC)
+                    len = s.len(inst_ptr(h));
+
+                layout.num = len;
+
+                for (Py_ssize_t i = 0; i < len; ++i) {
+                    scoped_path ps(ctx, i);
+                    traverse(nb::steal(s.item(h.ptr(), i)), ctx);
+                }
+            } else {
+                layout.num = 1;
+                traverse_ad_var(h, ctx);
+            }
+        } else if (tp.is(&PyTuple_Type)) {
+            nb::tuple tuple = nb::borrow<nb::tuple>(h);
+
+            layout.num = tuple.size();
+
+            for (uint32_t i = 0; i < tuple.size(); i++) {
+                scoped_path ps(ctx, i);
+                auto h2 = tuple[i];
+                traverse(h2, ctx);
+            }
+        } else if (tp.is(&PyList_Type)) {
+            nb::list list = nb::borrow<nb::list>(h);
+
+            layout.num = list.size();
+
+            for (uint32_t i = 0; i < list.size(); i++) {
+                scoped_path ps(ctx, i);
+                auto h2 = list[i];
+                traverse(h2, ctx);
+            }
+        } else if (tp.is(&PyDict_Type)) {
+            nb::dict dict = nb::borrow<nb::dict>(h);
+
+            layout.num = dict.size();
+            layout.fields.reserve(layout.num);
+            for (auto k : dict.keys()) {
+                layout.fields.push_back(nb::borrow(k));
+            }
+
+            for (auto [k, v] : dict) {
+                scoped_path ps(ctx, nb::str(k).c_str(), true);
+                traverse(v, ctx);
+            }
+        } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) {
+
+            layout.num = ds.size();
+            layout.fields.reserve(layout.num);
+            for (auto k : ds.keys()) {
+                layout.fields.push_back(nb::borrow(k));
+            }
+
+            for (auto [k, v] : ds) {
+                scoped_path ps(ctx, nb::str(k).c_str());
+                traverse(nb::getattr(h, k), ctx);
+            }
+        } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) {
+
+            for (auto field : df) {
+                nb::object k = field.attr(DR_STR(name));
+                layout.fields.push_back(nb::borrow(k));
+            }
+            layout.num = layout.fields.size();
+
+            for (nb::handle field : df) {
+                nb::object k = field.attr(DR_STR(name));
+                scoped_path ps(ctx, nb::str(k).c_str());
+                traverse(nb::getattr(h,
k), ctx); + } + } else if (auto traversable = get_traversable_base(h); traversable) { + traverse_cb(traversable, ctx, nb::borrow(tp)); + } else if (auto cb = get_traverse_cb_ro(tp); cb.is_valid()) { + ProfilerPhase profiler("traverse cb"); + + uint32_t num_fields = 0; + + // Traverse the opaque C++ object + cb(h, nb::cpp_function([&](uint64_t index, const char *variant, + const char *domain) { + if (!index) + return; + add_domain(variant, domain); + num_fields++; + this->traverse_ad_index(index, ctx, nb::none()); + return; + })); + + // Update layout number of fields + this->layout[layout_index].num = num_fields; + } else { + jit_log(LogLevel::Info, + "traverse(): You passed a value of type %s to a frozen " + "function, it could not be converted to a Dr.Jit type. " + "Changing this value in future calls to the frozen " + "function will cause it to be re-traced. The value is " + "located at %s.", + nb::str(tp).c_str(), ctx.path.get()); + + layout.py_object = nb::borrow(h); + } + } catch (nb::python_error &e) { + auto ts = nb::str(tp); + nb::raise_from( + e, PyExc_RuntimeError, + "FlatVariables::traverse(): error encountered while " + "processing an argument of type '%s' at location %s (see above).", + ts.c_str(), ctx.path.get()); + } catch (const std::exception &e) { + auto ts = nb::str(tp); + nb::chain_error( + PyExc_RuntimeError, + "FlatVariables::traverse(): error encountered " + "while processing an argument of type '%s' at location %s: %s", + ts.c_str(), ctx.path.get(), e.what()); + nb::raise_python_error(); + } + + if (!ctx.deduplicate_pytree) + ctx.visited.erase(key); + + jit_log(LogLevel::Debug, "}"); +} + +/** + * This is the counterpart to the ``traverse`` method, used to construct the + * output of a frozen function. Given a layout vector and flat_variables, it + * re-constructs the PyTree. + */ +nb::object FlatVariables::construct() { + recursion_guard guard(this); + + if (this->layout.size() == 0) { + return nb::none(); + } + + const Layout &layout = this->layout[layout_index++]; + + auto tp_name = nb::type_name(layout.type).c_str(); + jit_log(LogLevel::Debug, "FlatVariables::construct(): %s {", tp_name); + + if (layout.type.is(nb::none().type())) { + return nb::none(); + } + try { + if (is_drjit_type(layout.type)) { + const ArraySupplement &s = supp(layout.type); + if (s.is_tensor) { + nb::object array = construct(); + + // Reconstruct the full shape from the inner part, stored in the + // layout and the width of the underlying array. 
+ auto inner_shape = nb::borrow(layout.py_object); + auto first_dim = prod(shape(array), nb::none()) + .floor_div(prod(inner_shape, nb::none())); + + nb::list full_shape; + full_shape.append(first_dim); + for (uint32_t i = 0; i < inner_shape.size(); i++) { + full_shape.append(inner_shape[i]); + } + + nb::object tensor = layout.type(array, nb::tuple(full_shape)); + return tensor; + } else if (s.ndim != 1) { + auto result = nb::inst_alloc_zero(layout.type); + dr::ArrayBase *p = inst_ptr(result); + size_t size = s.shape[0]; + if (size == DRJIT_DYNAMIC) { + size = layout.num; + s.init(size, p); + } + for (size_t i = 0; i < size; ++i) { + result[i] = construct(); + } + nb::inst_mark_ready(result); + return result; + } else { + return construct_ad_var(layout); + } + } else if (layout.type.is(&PyTuple_Type)) { + nb::list list; + for (uint32_t i = 0; i < layout.num; ++i) { + list.append(construct()); + } + return nb::tuple(list); + } else if (layout.type.is(&PyList_Type)) { + nb::list list; + for (uint32_t i = 0; i < layout.num; ++i) { + list.append(construct()); + } + return std::move(list); + } else if (layout.type.is(&PyDict_Type)) { + nb::dict dict; + for (auto k : layout.fields) { + dict[k] = construct(); + } + return std::move(dict); + } else if (nb::dict ds = get_drjit_struct(layout.type); ds.is_valid()) { + nb::object tmp = layout.type(); + // TODO: validation against `ds` + for (auto k : layout.fields) { + nb::setattr(tmp, k, construct()); + } + return tmp; + } else if (nb::object df = get_dataclass_fields(layout.type); + df.is_valid()) { + nb::dict dict; + for (auto k : layout.fields) { + dict[k] = construct(); + } + return layout.type(**dict); + } else if (layout.py_object) { + return layout.py_object; + } else { + nb::raise("Tried to construct a variable of type %s that is not " + "constructable!", + nb::type_name(layout.type).c_str()); + } + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "FlatVariables::construct(): error encountered while " + "processing an argument of type '%U' (see above).", + nb::type_name(layout.type).ptr()); + } catch (const std::exception &e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::construct(): error encountered " + "while processing an argument of type '%U': %s", + nb::type_name(layout.type).ptr(), e.what()); + nb::raise_python_error(); + } + + jit_log(LogLevel::Debug, "}"); +} + +/** + * Assigns the flattened variables to an already existing PyTree. + * This is used when input variables have changed. + */ +void FlatVariables::assign(nb::handle dst, TraverseContext &ctx) { + recursion_guard guard(this); + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + nb::handle tp = dst.type(); + Layout &layout = this->layout[layout_index++]; + + if (layout.flags & (uint32_t) LayoutFlag::RecursiveRef) + return; + + jit_log(LogLevel::Debug, "FlatVariables::assign(): %s with %s {", + nb::type_name(tp).c_str(), nb::type_name(layout.type).c_str()); + + if (!layout.type.equal(tp)) + jit_raise("Type mismatch! 
Type of the object at location %s when "
+                  "recording (%s) does not match the type of the object that "
+                  "is assigned (%s).",
+                  ctx.path.get(), nb::type_name(layout.type).c_str(),
+                  nb::type_name(tp).c_str());
+
+    try {
+        if (is_drjit_type(tp)) {
+            const ArraySupplement &s = supp(tp);
+
+            if (s.is_tensor) {
+                nb::handle array = s.tensor_array(dst.ptr());
+                assign(nb::steal(array), ctx);
+            } else if (s.ndim != 1) {
+                Py_ssize_t len = s.shape[0];
+                if (len == DRJIT_DYNAMIC)
+                    len = s.len(inst_ptr(dst));
+
+                for (Py_ssize_t i = 0; i < len; ++i) {
+                    scoped_path ps(ctx, i);
+                    assign(dst[i], ctx);
+                }
+            } else {
+                assign_ad_var(layout, dst);
+            }
+        } else if (tp.is(&PyTuple_Type)) {
+            nb::tuple tuple = nb::borrow<nb::tuple>(dst);
+            raise_if(
+                tuple.size() != layout.num,
+                "The number of objects in this tuple changed from %u to %u "
+                "since the function was recorded.",
+                layout.num, (uint32_t) tuple.size());
+
+            for (uint32_t i = 0; i < tuple.size(); i++) {
+                scoped_path ps(ctx, i);
+                auto h2 = tuple[i];
+                assign(h2, ctx);
+            }
+        } else if (tp.is(&PyList_Type)) {
+            nb::list list = nb::borrow<nb::list>(dst);
+            raise_if(
+                list.size() != layout.num,
+                "The number of objects in a list at %s changed from %u to %u "
+                "since the function was recorded.",
+                ctx.path.get(), layout.num, (uint32_t) list.size());
+
+            for (uint32_t i = 0; i < list.size(); i++) {
+                scoped_path ps(ctx, i);
+                auto h2 = list[i];
+                assign(h2, ctx);
+            }
+        } else if (tp.is(&PyDict_Type)) {
+            nb::dict dict = nb::borrow<nb::dict>(dst);
+            for (auto &k : layout.fields) {
+                scoped_path ps(ctx, nb::str(k).c_str(), true);
+                if (dict.contains(k))
+                    assign(dict[k], ctx);
+                else
+                    dst[k] = construct();
+            }
+        } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) {
+            for (auto &k : layout.fields) {
+                scoped_path ps(ctx, nb::str(k).c_str());
+                if (nb::hasattr(dst, k))
+                    assign(nb::getattr(dst, k), ctx);
+                else
+                    nb::setattr(dst, k, construct());
+            }
+        } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) {
+            for (auto k : layout.fields) {
+                scoped_path ps(ctx, nb::str(k).c_str());
+                if (nb::hasattr(dst, k))
+                    assign(nb::getattr(dst, k), ctx);
+                else
+                    nb::setattr(dst, k, construct());
+            }
+        } else if (auto traversable = get_traversable_base(dst); traversable) {
+            assign_cb(traversable);
+        } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) {
+            index64_vector tmp;
+            uint32_t num_fields = 0;
+
+            cb(dst, nb::cpp_function([&](uint64_t index, const char *,
+                                         const char *) {
+                   if (!index)
+                       return index;
+                   jit_log(LogLevel::Debug,
+                           "assign(): traverse_cb[%u] was a%u r%u", num_fields,
+                           (uint32_t) (index >> 32), (uint32_t) index);
+                   num_fields++;
+                   if (num_fields > layout.num)
+                       jit_raise(
+                           "While traversing the object of type %s at location "
+                           "%s for assigning inputs, the number of variables "
+                           "to assign (>%u) did not match the number of "
+                           "variables traversed when recording (%u)!",
+                           nb::str(tp).c_str(), ctx.path.get(), num_fields,
+                           layout.num);
+                   return assign_cb_internal(index, tmp);
+               }));
+            if (num_fields != layout.num)
+                jit_raise(
+                    "While traversing the object of type %s at location %s "
+                    "for assigning inputs, the number of variables "
+                    "to assign did not match the number of variables "
+                    "traversed when recording!",
+                    nb::str(tp).c_str(), ctx.path.get());
+        } else {
+        }
+    } catch (nb::python_error &e) {
+        nb::raise_from(e, PyExc_RuntimeError,
+                       "FlatVariables::assign(): error encountered while "
+                       "processing an argument at %s "
+                       "of type '%U' (see above).",
+                       ctx.path.get(), nb::type_name(tp).ptr());
+    } catch (const std::exception
&e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::assign(): error encountered " + "while processing an argument at %s " + "of type '%U': %s", + ctx.path.get(), nb::type_name(tp).ptr(), e.what()); + nb::raise_python_error(); + } + + jit_log(LogLevel::Debug, "}"); +} + +/** + * First traverses the PyTree, then the registry. This ensures that + * additional data to vcalls is tracked correctly. + */ +void FlatVariables::traverse_with_registry(nb::handle h, TraverseContext &ctx) { + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + // Traverse the handle + traverse(h, ctx); + + // Traverse the registry (if a class variable was traversed) + if (!domains.empty()) { + ProfilerPhase profiler("traverse_registry"); + uint32_t layout_index = this->layout.size(); + Layout &layout = this->layout.emplace_back(); + layout.type = nb::borrow(nb::none()); + + uint32_t num_fields = 0; + + jit_log(LogLevel::Debug, "registry{"); + + drjit::vector registry_pointers; + for (std::string &domain : domains) { + uint32_t registry_bound = + jit_registry_id_bound(variant.c_str(), domain.c_str()); + uint32_t offset = registry_pointers.size(); + registry_pointers.resize(registry_pointers.size() + registry_bound, + nullptr); + jit_registry_get_pointers(variant.c_str(), domain.c_str(), + ®istry_pointers[offset]); + } + + jit_log(LogLevel::Debug, "registry_bound=%u", registry_pointers.size()); + jit_log(LogLevel::Debug, "layout_index=%u", this->layout.size()); + for (void *ptr : registry_pointers) { + jit_log(LogLevel::Debug, "ptr=%p", ptr); + if (!ptr) + continue; + + // WARN: very unsafe cast! + // We assume, that any object added to the registry inherits from + // TraversableBase. This is ensured by the signature of the + // ``drjit::registry_put`` function. + auto traversable = (drjit::TraversableBase *) ptr; + auto self = traversable->self_py(); + + if (self) + traverse(self, ctx); + else + traverse_cb(traversable, ctx); + + num_fields++; + } + jit_log(LogLevel::Debug, "}"); + + this->layout[layout_index].num = num_fields; + } +} + +/** + * First assigns the registry and then the PyTree. + * Corresponds to `traverse_with_registry`. + */ +void FlatVariables::assign_with_registry(nb::handle dst, TraverseContext &ctx) { + scoped_set_flag traverse_scope(JitFlag::EnableObjectTraversal, true); + + // Assign the handle + assign(dst, ctx); + + // Assign registry (if a class variable was traversed) + if (!domains.empty()) { + Layout &layout = this->layout[layout_index++]; + + jit_log(LogLevel::Debug, "registry{"); + + drjit::vector registry_pointers; + for (std::string &domain : domains) { + uint32_t registry_bound = + jit_registry_id_bound(variant.c_str(), domain.c_str()); + uint32_t offset = registry_pointers.size(); + registry_pointers.resize(registry_pointers.size() + registry_bound, + nullptr); + jit_registry_get_pointers(variant.c_str(), domain.c_str(), + ®istry_pointers[offset]); + } + + uint32_t num_fields = 0; + + for (void *ptr : registry_pointers) + if (ptr) + num_fields++; + + if (num_fields != layout.num) + jit_raise("assign_with_registry(): The number of registry " + "entries (%zu) did not match the number of registry " + "entries recorded (%u)!", + registry_pointers.size(), layout.num); + + jit_log(LogLevel::Debug, "registry_bound=%u", registry_pointers.size()); + jit_log(LogLevel::Debug, "layout_index=%u", this->layout_index); + for (void *ptr : registry_pointers) { + jit_log(LogLevel::Debug, "ptr=%p", ptr); + if (!ptr) + continue; + + // WARN: very unsafe cast! 
+ // We assume, that any object added to the registry inherits from + // TraversableBase. This is ensured by the signature of the + // ``drjit::registry_put`` function. + auto traversable = (drjit::TraversableBase *) ptr; + auto self = traversable->self_py(); + + if (self) + assign(self, ctx); + else + assign_cb(traversable); + } + jit_log(LogLevel::Debug, "}"); + } +} + +FlatVariables::~FlatVariables() { + state_lock_guard guard; + for (uint32_t i = 0; i < layout.size(); ++i) { + Layout &l = layout[i]; + if (((l.flags & (uint32_t) LayoutFlag::Literal) || + (l.flags & (uint32_t) LayoutFlag::Undefined)) && + l.literal_index) { + jit_var_dec_ref(l.literal_index); + } + } +} + +void FlatVariables::borrow() { + state_lock_guard guard; + for (uint32_t &index : this->variables) + jit_var_inc_ref(index); +} + +void FlatVariables::release() { + state_lock_guard guard; + for (uint32_t &index : this->variables) + jit_var_dec_ref(index); +} + +bool log_diff_variable(LogLevel level, const FlatVariables &curr, + const FlatVariables &prev, uint32_t slot, + TraverseContext &ctx) { + const VarLayout &curr_l = curr.var_layout[slot]; + const VarLayout &prev_l = prev.var_layout[slot]; + + if (curr_l.vt != prev_l.vt) { + jit_log(level, "%s: The variable type changed from %u to %u.", + ctx.path.get(), prev_l.vt, curr_l.vt); + return false; + } + if (curr_l.size_index != prev_l.size_index) { + jit_log(level, + "%s: The size equivalence class of the variable changed from " + "%u to %u.", + ctx.path.get(), prev_l.size_index, curr_l.size_index); + return false; + } + + return true; +} + +/** + * Log the difference of the layout nodes at ``index`` for the two + * FlatVariables. + */ +bool log_diff(LogLevel level, const FlatVariables &curr, + const FlatVariables &prev, uint32_t &index, + TraverseContext &ctx) { + + const Layout &curr_l = curr.layout[index]; + const Layout &prev_l = prev.layout[index]; + index++; + + if (curr_l.flags != prev_l.flags) { + jit_log(level, "%s: The flags of this node changed from 0x%lx to 0x%lx", + ctx.path.get(), prev_l.flags, curr_l.flags); + return false; + } + + if (curr_l.index != prev_l.index) { + jit_log(level, + "%s: The index into the array of deduplicated variables " + "changed from s%u to s%u. 
This can occur if two variables "
+                "referred to the same JIT index, but no longer do.",
+                ctx.path.get(), prev_l.index, curr_l.index);
+        return false;
+    }
+
+    if (curr_l.flags & (uint32_t) LayoutFlag::JitIndex &&
+        !(curr_l.flags & (uint32_t) LayoutFlag::Literal) &&
+        !(curr_l.flags & (uint32_t) LayoutFlag::Undefined)) {
+        uint32_t slot = curr_l.index;
+        if (!log_diff_variable(level, curr, prev, slot, ctx))
+            return false;
+    }
+
+    if (((bool) curr_l.type != (bool) prev_l.type) ||
+        !(curr_l.type.equal(prev_l.type))) {
+        jit_log(level, "%s: The type of this node changed from %s to %s",
+                ctx.path.get(), nb::str(prev_l.type).c_str(),
+                nb::str(curr_l.type).c_str());
+        return false;
+    }
+
+    if (curr_l.literal != prev_l.literal) {
+        jit_log(level,
+                "%s: The literal value of this variable changed from 0x%llx to "
+                "0x%llx",
+                ctx.path.get(), prev_l.literal, curr_l.literal);
+        return false;
+    }
+
+    if (((bool) curr_l.py_object != (bool) prev_l.py_object) ||
+        !curr_l.py_object.equal(prev_l.py_object)) {
+        jit_log(level, "%s: The object changed from %s to %s", ctx.path.get(),
+                nb::str(prev_l.py_object).c_str(),
+                nb::str(curr_l.py_object).c_str());
+        return false;
+    }
+
+    if (curr_l.num != prev_l.num) {
+        jit_log(level,
+                "%s: The number of elements of this container changed from %u "
+                "to %u",
+                ctx.path.get(), prev_l.num, curr_l.num);
+        return false;
+    }
+
+    if (curr_l.fields.size() != prev_l.fields.size()) {
+        jit_log(level,
+                "%s: The number of fields of this container changed from %u "
+                "to %u",
+                ctx.path.get(), (uint32_t) prev_l.fields.size(),
+                (uint32_t) curr_l.fields.size());
+        return false;
+    }
+
+    for (uint32_t i = 0; i < curr_l.fields.size(); ++i) {
+        if (!(curr_l.fields[i].equal(prev_l.fields[i]))) {
+            jit_log(level, "%s: The %i-th key changed from \"%s\" to \"%s\"",
+                    ctx.path.get(), i, nb::str(prev_l.fields[i]).c_str(),
+                    nb::str(curr_l.fields[i]).c_str());
+        }
+    }
+
+    if (curr_l.fields.size() > 0) {
+        for (uint32_t i = 0; i < curr_l.fields.size(); i++) {
+            auto &field = curr_l.fields[i];
+
+            scoped_path ps(ctx, nb::str(field).c_str(),
+                           curr_l.type.is(&PyDict_Type));
+
+            log_diff(level, curr, prev, index, ctx);
+        }
+    } else {
+        for (uint32_t i = 0; i < curr_l.num; i++) {
+            scoped_path ps(ctx, i);
+
+            log_diff(level, curr, prev, index, ctx);
+        }
+    }
+
+    return true;
+}
+
+/**
+ * Log the difference of the two FlatVariables.
+ */
+bool log_diff(LogLevel level, const FlatVariables &curr,
+              const FlatVariables &prev) {
+    // Skip expensive logging if the configured log level is not high enough.
+    if (level > jit_log_level_callback() && level > jit_log_level_stderr())
+        return true;
+
+    if (curr.flags != prev.flags) {
+        jit_log(level, "The flags of the input changed from 0x%x to 0x%x",
+                prev.flags, curr.flags);
+        return false;
+    }
+    if (curr.layout.size() != prev.layout.size()) {
+        jit_log(level,
+                "The number of elements in the input changed from %u to %u.",
+                (uint32_t) prev.layout.size(), (uint32_t) curr.layout.size());
+        return false;
+    }
+    if (curr.var_layout.size() != prev.var_layout.size()) {
+        jit_log(level,
+                "The number of opaque variables in the input changed from %u "
+                "to %u.",
+                (uint32_t) prev.var_layout.size(),
+                (uint32_t) curr.var_layout.size());
+        return false;
+    }
+
+    uint32_t index = 0;
+    TraverseContext ctx;
+    log_diff(level, curr, prev, index, ctx);
+
+    return true;
+}
+
+size_t FlatVariablesHasher::operator()(
+    const std::shared_ptr<FlatVariables> &key) const {
+    ProfilerPhase profiler("hash");
+    // Hash the layout
+
+    // TODO: Maybe we can use xxh by first collecting in vector?
 +
+    drjit::vector<uint64_t> data;
+    data.reserve(2 + key->layout.size() * 4 + key->var_layout.size());
+
+    data.push_back(key->flags);
+    data.push_back((uint64_t) (key->layout.size() << 32) |
+                   (uint64_t) (key->var_layout.size() << 2));
+
+    for (const Layout &layout : key->layout) {
+        // If layout.fields is non-empty, then layout.num ==
+        // layout.fields.size(), so we can omit layout.fields.size().
+        // This makes the assumption that we don't have more than 2^26-1
+        // elements in one layout or variables in the FlatVariables. If more
+        // elements are part of the layout, hash collisions might occur,
+        // impacting performance but not correctness.
+        if (layout.num >> 26)
+            jit_log(LogLevel::Warn,
+                    "The layout consists of more than 2^26 (~67M) elements, "
+                    "which might lead to hash collisions when looking up "
+                    "previous recordings of frozen functions.");
+        if (layout.index >> 26)
+            jit_log(
+                LogLevel::Warn,
+                "The layout consists of more than 2^26 (~67M) opaque "
+                "variables, which might lead to hash collisions when looking "
+                "up previous recordings of frozen functions.");
+        union {
+            struct {
+                uint64_t num : 26;
+                uint64_t index : 26;
+                uint64_t flags : 8;
+                uint64_t vt : 4;
+            };
+            uint64_t data;
+        } lkey;
+        static_assert(sizeof(lkey) == sizeof(uint64_t));
+        lkey.num = layout.num;
+        lkey.index = layout.index;
+        lkey.flags = layout.flags;
+        lkey.vt = layout.vt;
+
+        data.push_back(lkey.data);
+        if (layout.flags & (uint32_t) LayoutFlag::JitIndex)
+            data.push_back(layout.literal);
+
+        uint32_t type_hash = 0;
+        if (layout.type)
+            type_hash = nb::hash(layout.type);
+
+        uint32_t object_hash = 0;
+        if (layout.py_object) {
+            PyObject *ptr = layout.py_object.ptr();
+            Py_hash_t rv = PyObject_Hash(ptr);
+
+            // Try to hash the object, and otherwise fall back to ``id()``
+            if (rv == -1 && PyErr_Occurred()) {
+                PyErr_Clear();
+                object_hash = (uintptr_t) ptr;
+            } else {
+                object_hash = rv;
+            }
+        }
+        if (type_hash && object_hash)
+            data.push_back(((uint64_t) type_hash << 32) |
+                           ((uint64_t) (uint32_t) object_hash));
+        for (auto &field : layout.fields)
+            data.push_back(nb::hash(field.ptr()));
+    }
+
+    for (const VarLayout &layout : key->var_layout) {
+        // layout.vt: 4
+        // layout.vs: 4
+        // layout.flags: 8
+        data.push_back(((uint64_t) layout.size_index << 32) |
+                       ((uint64_t) layout.flags << 8) |
+                       ((uint64_t) layout.vs << 4) | ((uint64_t) layout.vt));
+    }
+
+    uint64_t hash = XXH3_64bits(data.data(), data.size() * sizeof(uint64_t));
+
+    return hash;
+}
+
+/*
+ * Record a function, given its Python input and flattened input.
+ */
+nb::object FunctionRecording::record(nb::callable func,
+                                     FrozenFunction *frozen_func,
+                                     nb::dict input,
+                                     const FlatVariables &in_variables) {
+    ProfilerPhase profiler("record");
+    JitBackend backend = in_variables.backend;
+
+    frozen_func->recording_counter++;
+    if (frozen_func->recording_counter > frozen_func->warn_recording_count &&
+        frozen_func->recordings.size() >= 1) {
+        if (frozen_func->recordings.size() < frozen_func->recording_counter) {
+            jit_log(LogLevel::Warn,
+                    "The frozen function has been recorded %u times; this "
+                    "indicates a problem with how the frozen function is being "
+                    "called. The number of cached recordings %u is smaller "
+                    "than the number of times this function has been recorded "
+                    "%u, indicating that dry-running the recording failed at "
+                    "least %u times.",
+                    frozen_func->recording_counter,
+                    (uint32_t) frozen_func->recordings.size(),
+                    frozen_func->recording_counter,
+                    (uint32_t) (frozen_func->recording_counter -
+                                frozen_func->recordings.size()));
+        } else {
+            jit_log(
+                LogLevel::Warn,
+                "The frozen function has been recorded %u times; this "
+                "indicates a problem with how the frozen function is being "
+                "called, for example calling it with changing Python values "
+                "such as an index. For more information about which variables "
+                "changed, set the log level to ``LogLevel::Debug``.",
+                frozen_func->recording_counter);
+        }
+        log_diff(LogLevel::Info, in_variables, *frozen_func->prev_key);
+    }
+
+    jit_log(LogLevel::Debug, "Recording (n_inputs=%u):",
+            (uint32_t) in_variables.variables.size());
+    jit_freeze_start(backend, in_variables.variables.data(),
+                     in_variables.variables.size());
+
+    // Record the function
+    nb::object output;
+    {
+        ProfilerPhase profiler("function");
+        state_unlock_guard guard;
+        output = func(input);
+    }
+
+    // Collect nodes that have been postponed by the ``Isolate`` scope in a
+    // hash set. These are the targets of postponed edges; since the isolate
+    // gradient scope only handles backward-mode differentiation, they have
+    // to be re-enqueued when replaying the recording.
+    tsl::robin_set<uint32_t> postponed;
+    {
+        drjit::vector<uint32_t> postponed_vec;
+        ad_scope_postponed(&postponed_vec);
+        for (uint32_t index : postponed_vec)
+            postponed.insert(index);
+    }
+
+    {
+        ProfilerPhase profiler("traverse output");
+        // Enter Resume scope, so we can track gradients
+        ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false);
+
+        {
+            TraverseContext ctx;
+            ctx.postponed = &postponed;
+            ctx.deduplicate_pytree = false;
+            out_variables.traverse(output, ctx);
+            out_variables.schedule_jit_variables(false, nullptr);
+        }
+
+        {
+            TraverseContext ctx;
+            ctx.postponed = &postponed;
+            out_variables.traverse_with_registry(input, ctx);
+            out_variables.schedule_jit_variables(false, nullptr);
+        }
+
+        out_variables.layout_index = 0;
+
+        { // Evaluate the variables, scheduled when traversing
+            nb::gil_scoped_release guard;
+            jit_eval();
+        }
+
+        out_variables.record_jit_variables();
+    }
+
+    jit_freeze_pause(backend);
+
+    if ((out_variables.variables.size() > 0 &&
+         in_variables.variables.size() > 0) &&
+        out_variables.backend != backend) {
+        Recording *recording = jit_freeze_stop(backend, nullptr, 0);
+        jit_freeze_destroy(recording);
+
+        nb::raise(
+            "freeze(): backend mismatch error (backend %u of "
+            "output variables did not match backend %u of input variables)",
+            (uint32_t) out_variables.backend, (uint32_t) backend);
+    }
+
+    // Exceptions thrown by the recorded function will be stored and re-thrown
+    // when calling ``jit_freeze_stop``. Since the output variables are
+    // borrowed, we have to catch these exceptions and release the variables
+    // in that case.
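+    // (A sketch of the control flow: if ``jit_freeze_stop`` throws, the
+    // ``recording`` member is never assigned, so only the borrowed output
+    // indices need to be released before re-raising.)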
 +    try {
+        recording = jit_freeze_stop(backend, out_variables.variables.data(),
+                                    out_variables.variables.size());
+    } catch (nb::python_error &e) {
+        out_variables.release();
+        nb::raise_from(e, PyExc_RuntimeError,
+                       "record(): error encountered while recording a function "
+                       "(see above).");
+    } catch (const std::exception &e) {
+        out_variables.release();
+        nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what());
+        nb::raise_python_error();
+    }
+
+    jit_log(LogLevel::Debug, "Recording done (n_outputs=%u)",
+            (uint32_t) out_variables.variables.size());
+
+    // To catch input-assignment mismatches early, we construct the output
+    // and re-assign the input here.
+    {
+        state_lock_guard guard;
+        // Enter Resume scope, so we can track gradients
+        ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false);
+
+        out_variables.layout_index = 0;
+        jit_log(LogLevel::Debug, "Construct:");
+        try {
+            output = nb::borrow(out_variables.construct());
+        } catch (std::exception &e) {
+            out_variables.release();
+            throw;
+        }
+        // NOTE: temporarily disable this to not enqueue twice
+        try {
+            TraverseContext ctx;
+            out_variables.assign(input, ctx);
+        } catch (std::exception &e) {
+            out_variables.release();
+            throw;
+        }
+        out_variables.layout_index = 0;
+    }
+
+    // Traversal takes owning references, so here we need to release them.
+    out_variables.release();
+
+    return output;
+}
+/*
+ * Replays the recording.
+ *
+ * This constructs the output and re-assigns the input.
+ */
+nb::object FunctionRecording::replay(nb::callable func,
+                                     FrozenFunction *frozen_func,
+                                     nb::dict input,
+                                     const FlatVariables &in_variables) {
+    ProfilerPhase profiler("replay");
+
+    jit_log(LogLevel::Info, "Replaying:");
+    int dryrun_success;
+    {
+        ProfilerPhase profiler("dry run");
+        dryrun_success =
+            jit_freeze_dry_run(recording, in_variables.variables.data());
+    }
+    if (!dryrun_success) {
+        // Dry run has failed. Re-record the function.
+        jit_log(LogLevel::Info, "Dry run failed! re-recording");
+        this->clear();
+        try {
+            return this->record(func, frozen_func, input, in_variables);
+        } catch (nb::python_error &e) {
+            nb::raise_from(e, PyExc_RuntimeError,
+                           "replay(): error encountered while re-recording a "
+                           "function (see above).");
+        } catch (const std::exception &e) {
+            jit_freeze_abort(in_variables.backend);
+
+            nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what());
+            nb::raise_python_error();
+        }
+    } else {
+        ProfilerPhase profiler("jit replay");
+        nb::gil_scoped_release guard;
+        jit_freeze_replay(recording, in_variables.variables.data(),
+                          out_variables.variables.data());
+    }
+    jit_log(LogLevel::Info, "Replaying done:");
+
+    // Construct Output variables
+    nb::object output;
+    {
+        state_lock_guard guard;
+        // Enter Resume scope, so we can track gradients
+        ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false);
+        out_variables.layout_index = 0;
+        try {
+            ProfilerPhase profiler("construct output");
+            output = nb::borrow(out_variables.construct());
+        } catch (std::exception &e) {
+            out_variables.release();
+            throw;
+        }
+        try {
+            ProfilerPhase profiler("assign input");
+            TraverseContext ctx;
+            out_variables.assign_with_registry(input, ctx);
+        } catch (std::exception &e) {
+            out_variables.release();
+            throw;
+        }
+    }
+
+    // out_variables is assigned by ``jit_freeze_replay``, which transfers
+    // ownership to this array. Therefore, we have to drop the variables
+    // afterwards.
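+    // (Illustrative reasoning: ``construct()`` and ``assign_with_registry()``
+    // above have already created their own references inside the returned
+    // PyTree, so the borrowed indices held by ``out_variables`` can be
+    // dropped safely here.)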
 +    out_variables.release();
+
+    return output;
+}
+
+nb::object FrozenFunction::operator()(nb::dict input) {
+    ProfilerPhase profiler("frozen function");
+    state_lock_guard guard;
+    nb::object result;
+    {
+        // Enter Isolate grad scope, so that gradients are not propagated
+        // outside of the function scope.
+        ADScopeContext ad_scope(drjit::ADScope::Isolate, 0, nullptr, -1, true);
+
+        // Kernel freezing can be enabled or disabled with the
+        // ``JitFlag::KernelFreezing`` flag. Alternatively, when calling a
+        // frozen function from another one, we simply record the inner
+        // function.
+        if (!jit_flag(JitFlag::KernelFreezing) ||
+            jit_flag(JitFlag::FreezingScope) || max_cache_size == 0) {
+            ProfilerPhase profiler("function");
+            state_unlock_guard guard;
+            return func(input);
+        }
+
+        call_counter++;
+
+        auto in_variables =
+            std::make_shared<FlatVariables>(FlatVariables(in_heuristics));
+        uint32_t flags = jit_flags();
+        in_variables->flags = flags;
+        in_variables->backend = this->default_backend;
+        // Evaluate and traverse input variables (args and kwargs).
+        // This is repeated at most twice if the set of variables that should
+        // be made opaque changes.
+        for (uint32_t i = 0; i < 2; i++) {
+            state_lock_guard guard;
+            // Enter Resume scope, so we can track gradients
+            ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, 0,
+                                    true);
+
+            // Traverse input variables
+            ProfilerPhase profiler("traverse input");
+
+            TraverseContext ctx;
+            in_variables->traverse_with_registry(input, ctx);
+
+            // If this is the first time the frozen function has been called or
+            // the layout is not compatible with the previous one, we clear the
+            // opaque_mask.
+            bool auto_opaque = false;
+            if (prev_key) {
+                auto_opaque = compatible_auto_opaque(*in_variables, *prev_key);
+                if (!auto_opaque) {
+                    // The mask is reset if they are not compatible
+                    opaque_mask.resize(in_variables->layout.size());
+                    for (uint32_t i = 0; i < opaque_mask.size(); i++)
+                        opaque_mask[i] = false;
+                    jit_log(LogLevel::Debug, "auto-opaque incompatible");
+                }
+            } else
+                opaque_mask.resize(in_variables->layout.size(), false);
+
+            in_variables->schedule_jit_variables(!this->auto_opaque,
+                                                 &opaque_mask);
+
+            in_variables->layout_index = 0;
+
+            { // Evaluate the variables, scheduled when traversing
+                ProfilerPhase profiler("eval");
+                nb::gil_scoped_release guard;
+                jit_eval();
+            }
+
+            in_variables->record_jit_variables();
+            bool new_opaques = false;
+            if (prev_key && auto_opaque)
+                new_opaques =
+                    in_variables->fill_opaque_mask(*prev_key, opaque_mask);
+
+            if (new_opaques) {
+                // If new variables have been discovered that should be made
+                // opaque, we repeat traversal of the input to make them opaque.
+                // This reduces the number of variants that are saved by one.
+                jit_log(LogLevel::Info,
+                        "While traversing the frozen function input, new "
+                        "literal variables have been discovered which changed "
+                        "from one call to another. These will be made opaque, "
+                        "and the input will be traversed again. This will "
+                        "incur some overhead. To prevent this, make those "
+                        "variables opaque beforehand. 
Below, a list of " + "variables that changed will be shown."); + if (prev_key) + log_diff(LogLevel::Info, *in_variables, *prev_key); + in_variables->release(); + in_variables = std::make_shared( + FlatVariables(in_heuristics)); + in_variables->flags = flags; + } else { + break; + } + } + + in_heuristics = in_heuristics.max(in_variables->heuristic()); + + raise_if(in_variables->backend == JitBackend::None, + "freeze(): Cannot infer backend without providing input " + "variable to frozen function!"); + + auto it = this->recordings.find(in_variables); + + // Evict the least recently used recording if the cache is "full" + if (max_cache_size > 0 && + recordings.size() >= (uint32_t) max_cache_size && + it == this->recordings.end()) { + + uint32_t lru_last_used = UINT32_MAX; + RecordingMap::iterator lru_it = recordings.begin(); + + for (auto it = recordings.begin(); it != recordings.end(); it++) { + auto &recording = it.value(); + if (recording->last_used < lru_last_used) { + lru_last_used = recording->last_used; + lru_it = it; + } + } + recordings.erase(lru_it); + + it = this->recordings.find(in_variables); + } + + if (it == this->recordings.end()) { + { + // TODO: single traverse + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, 0, + true); + TraverseContext ctx; + in_variables->assign_with_registry(input, ctx); + } + + // FunctionRecording recording; + auto recording = std::make_unique(); + recording->last_used = call_counter - 1; + + try { + result = recording->record(func, this, input, *in_variables); + } catch (nb::python_error &e) { + in_variables->release(); + jit_freeze_abort(in_variables->backend); + nb::raise_from( + e, PyExc_RuntimeError, + "record(): error encountered while recording a frozen " + "function (see above)."); + } catch (const std::exception &e) { + in_variables->release(); + jit_freeze_abort(in_variables->backend); + + nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what()); + nb::raise_python_error(); + }; + + in_variables->release(); + + this->prev_key = in_variables; + this->recordings.insert( + { std::move(in_variables), std::move(recording) }); + + } else { + FunctionRecording *recording = it.value().get(); + + recording->last_used = call_counter - 1; + + try { + result = recording->replay(func, this, input, *in_variables); + } catch (std::exception &e) { + in_variables->release(); + throw; + } + + // Drop references to variables + in_variables->release(); + } + } + ad_traverse(drjit::ADMode::Backward, + (uint32_t) drjit::ADFlag::ClearVertices); + return result; +} + +void FrozenFunction::clear() { + recordings.clear(); + prev_key = std::make_shared(FlatVariables()); + recording_counter = 0; + call_counter = 0; +} + +/** + * This function inspects the content of the frozen function to detect reference + * cycles, that could lead to memory or type leaks. It can be called by the + * garbage collector by adding it to the ``type_slots`` of the + * ``FrozenFunction`` definition. 
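+ *
+ * (Illustrative example: a cycle arises when the wrapped callable's
+ * closure captures the frozen function itself, e.g. a recursive frozen
+ * function. ``tp_traverse`` lets Python's cyclic garbage collector
+ * discover such cycles, and ``tp_clear`` below breaks them.)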
+ */ +int frozen_function_tp_traverse(PyObject *self, visitproc visit, void *arg) { + FrozenFunction *f = nb::inst_ptr(self); + + nb::handle func = nb::find(f->func); + Py_VISIT(func.ptr()); + + for (auto &it : f->recordings) { + for (auto &layout : it.first->layout) { + nb::handle type = nb::find(layout.type); + Py_VISIT(type.ptr()); + nb::handle object = nb::find(layout.py_object); + Py_VISIT(object.ptr()); + } + for (auto &layout : it.second->out_variables.layout) { + nb::handle type = nb::find(layout.type); + Py_VISIT(type.ptr()); + nb::handle object = nb::find(layout.py_object); + Py_VISIT(object.ptr()); + } + } + + return 0; +} + +/** + * This function releases the internal function of the ``FrozenFunction`` + * object. It is used by the garbage collector to "break" potential reference + * cycles, resulting from the frozen function being referenced in the closure of + * the wrapped variable. + */ +int frozen_function_clear(PyObject *self) { + FrozenFunction *f = nb::inst_ptr(self); + + f->func.release(); + + return 0; +} + +// Slot data structure referencing the above two functions +static PyType_Slot slots[] = { { Py_tp_traverse, + (void *) frozen_function_tp_traverse }, + { Py_tp_clear, (void *) frozen_function_clear }, + { 0, nullptr } }; + +void export_freeze(nb::module_ & /*m*/) { + + nb::module_ d = nb::module_::import_("drjit.detail"); + auto traversable_base = + nb::class_(d, "TraversableBase"); + nb::class_(d, "FrozenFunction", nb::type_slots(slots)) + .def(nb::init()) + .def_prop_ro( + "n_cached_recordings", + [](FrozenFunction &self) { return self.n_cached_recordings(); }) + .def_ro("n_recordings", &FrozenFunction::recording_counter) + .def("clear", &FrozenFunction::clear) + .def("__call__", &FrozenFunction::operator()); +} diff --git a/src/python/freeze.h b/src/python/freeze.h new file mode 100644 index 000000000..0216ea4d1 --- /dev/null +++ b/src/python/freeze.h @@ -0,0 +1,666 @@ +/* + freeze.h -- Bindings for drjit.freeze() + + Dr.Jit: A Just-In-Time-Compiler for Differentiable Rendering + Copyright 2023, Realistic Graphics Lab, EPFL. + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE.txt file. +*/ + +#pragma once + +#include "common.h" +#include +#include +#include +#include +#include +#include +#include + +#include "../ext/nanobind/src/buffer.h" + +struct FrozenFunction; + +namespace detail { + +using Buffer = nanobind::detail::Buffer; + +using index64_vector = drjit::detail::index64_vector; +using index32_vector = drjit::detail::index32_vector; + +// This enum defines flags used in the input layout nodes, representing the +// PyTree. +enum class LayoutFlag : uint32_t { + /// Whether this variable has size 1 + SingletonArray = (1 << 0), + /// Whether this variable is unaligned in memory + Unaligned = (1 << 1), + /// Whether this layout represents a literal variable + Literal = (1 << 2), + /// Whether this layout represents an undefined variable (they behave + /// similarly to literals) + Undefined = (1 << 3), + /// Whether this variable has gradients enabled + GradEnabled = (1 << 4), + /// Did this variable have gradient edges attached when recording, that + /// were postponed by the ``isolate_grad`` function? + Postponed = (1 << 5), + /// Does this node represent a JIT Index? + JitIndex = (1 << 6), + /// This layout node is a recursive reference to another node. 
+ RecursiveRef = (1 << 7), +}; + +/// Stores information about python objects, such as their type, their number of +/// sub-elements or their field keys. This can be used to reconstruct a PyTree +/// from a flattened variable array. +struct Layout { + + /// The literal data + uint64_t literal = 0; + + /// Optional field identifiers of the container + /// for example: keys in dictionary + drjit::vector fields; + + /// Number of members in this container. + /// Can be used to traverse the layout without knowing the type. + uint32_t num = 0; + + /// Either store the index if this is an opaque variable or the size of + /// the variable if this is a Literal or Undefined variable. This will be + /// hashed as part of the key. + union { + /// The index in the flat_variables array of this variable. + /// This can be used to determine aliasing. + uint32_t index = 0; + /// If this node is representing a literal or undefined variable, the + /// size is stored here instead. + uint32_t literal_size; + }; + + /// Flags, storing information about variables and literals. + uint32_t flags : 8; // LayoutFlag + + /// Optional drjit type of the variable + uint32_t vt : 4; // VarType + + /// Variable index of literal. Instead of constructing a literal every time, + /// we keep a reference to it. + uint32_t literal_index = 0; + + /// If a non drjit type is passed as function arguments or result, we simply + /// cache it here. + nb::object py_object; + + /// Nanobind type of the container/variable + nb::type_object type; + + bool operator==(const Layout &rhs) const; + bool operator!=(const Layout &rhs) const { return !(*this == rhs); } + + Layout() + : literal(0), fields(), num(0), index(0), flags(0), vt(0), + literal_index(0), py_object(), type() {}; + + Layout(const Layout &) = delete; + Layout &operator=(const Layout &) = delete; + + Layout(Layout &&) = default; + Layout &operator=(Layout &&) = default; +}; + +/** + * \brief Stores information about opaque variables. + * + * When traversing a PyTree, literal variables are stored directly and + * non-literal variables are first scheduled and their indices deduplicated and + * added to the ``FlatVariables::variables`` field. After calling ``jit_eval``, + * information about variables can be recorded using + * ``FlatVariables::record_jit_variables``. This struct stores that information + * per deduplicated variable. + */ +struct VarLayout { + /// Optional drjit type of the variable + VarType vt = VarType::Void; + + /// Optional evaluation state of the variable + VarState vs = VarState::Invalid; + + /// Flags, storing information about variables (see `LayoutFlag` enum above) + uint32_t flags = 0; + + /// We have to track the condition, where two variables have the same size + /// during recording but don't when replaying, however we do not want to + /// bake the size into the recording key.Therefore we construct equivalence + /// classes of sizes using a hashmap. When a variable with a new size is + /// traversed, this size is added to the `size_to_slot` map in + /// `FlatVariables`, and a unique index representing this size is assigned + /// to this field. 
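+    ///
+    /// (Illustrative example: inputs of widths (16, 16, 8) receive size
+    /// indices (0, 0, 1); a later call with widths (16, 8, 8) yields
+    /// (0, 1, 1), which no longer matches, so the recording cannot be
+    /// replayed and the function is traced again.)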
 +    uint32_t size_index = 0;
+
+    VarLayout() = default;
+
+    VarLayout(const VarLayout &) = delete;
+    VarLayout &operator=(const VarLayout &) = delete;
+
+    VarLayout(VarLayout &&) = default;
+    VarLayout &operator=(VarLayout &&) = default;
+
+    bool operator==(const VarLayout &rhs) const;
+    bool operator!=(const VarLayout &rhs) const { return !(*this == rhs); }
+};
+
+// Additional context required when traversing the inputs
+struct TraverseContext {
+    /// Set of postponed AD nodes, used to mark inputs to functions.
+    const tsl::robin_set<uint32_t> *postponed = nullptr;
+    tsl::robin_map<const void *, nb::object> visited;
+    index32_vector free_list;
+    /// If this flag is set to ``false``, the PyTree will not be deduplicated
+    /// during traversal. Refcycles will still be prevented, but some objects
+    /// might be traversed multiple times.
+    bool deduplicate_pytree = true;
+    uint32_t recursion_level = 0;
+    Buffer path;
+
+    TraverseContext() : path(1024) {}
+};
+
+/**
+ * \brief A flattened representation of the PyTree.
+ *
+ * This struct stores a flattened representation of a PyTree together with a
+ * layout describing its structure. It can therefore be used either to
+ * construct the PyTree or to assign the flattened variables to an existing
+ * one. Furthermore, this struct can also be used as a key to the
+ * ``RecordingMap``, determining which recording should be used given an input
+ * to a frozen function.
+ * Information about the PyTree is stored in DFS order. Every node of the
+ * tree is represented by a ``Layout`` element in the ``layout`` vector.
+ */
+struct FlatVariables {
+    // Stores the JIT flags (see `jit_flags`), set when traversing the inputs.
+    uint32_t flags = 0;
+
+    /// The flattened and de-duplicated variable indices of the input/output to
+    /// a frozen function
+    drjit::vector<uint32_t> variables;
+
+    /// Mapping from a Dr.Jit JIT index to its slot in the ``variables``
+    /// array. Used to deduplicate JIT indices.
+    tsl::robin_map<uint32_t, uint32_t> index_to_slot;
+
+    /// We have to track the condition where two variables have the same size
+    /// during recording but not when replaying.
+    /// Therefore we construct equivalence classes of sizes.
+    /// This vector represents the different sizes encountered during
+    /// traversal. The algorithm used to "add" a size is the same as for adding
+    /// a variable index.
+    drjit::vector<uint32_t> sizes;
+
+    /// Mapping from the size to its index in the ``sizes`` vector. This is used
+    /// to construct size equivalence classes (i.e. deduplicating sizes).
+    tsl::robin_map<uint32_t, uint32_t> size_to_slot;
+
+    /// This saves information about the type, size and fields of PyTree
+    /// objects. The information is stored in DFS order.
+    drjit::vector<Layout> layout;
+
+    /// Stores information about non-literal JIT variables.
+    drjit::vector<VarLayout> var_layout;
+
+    /// The collective backend for all input variables. It can be used to ensure
+    /// that all variables have the same backend.
+    JitBackend backend = JitBackend::None;
+
+    /// The variant, if any, used to traverse the registry.
+    std::string variant;
+
+    /// All domains (deduplicated) encountered while traversing the PyTree and
+    /// its C++ objects. This can be used to traverse the registry. We use a
+    /// vector instead of a hash set, since we expect the number of domains not
+    /// to exceed 100.
+ drjit::vector domains; + + // Index, used to iterate over the variables/layouts when constructing + // python objects + uint32_t layout_index = 0; + + uint32_t recursion_level = 0; + + struct recursion_guard { + FlatVariables *flat_variables; + recursion_guard(FlatVariables *flat_variables) + : flat_variables(flat_variables) { + if (flat_variables->recursion_level >= 50) { + PyErr_SetString(PyExc_RecursionError, + "runaway recursion detected"); + nb::raise_python_error(); + } + // NOTE: the recursion_level has to be incremented after potentially + // throwing an exception, as throwing an exception in the + // constructor prevents the destructor from being called. + flat_variables->recursion_level++; + } + ~recursion_guard() { flat_variables->recursion_level--; } + }; + + /** + * Describes how many elements have to be pre-allocated for the ``layout``, + * ``index_to_slot`` and ``size_to_slot`` containers. + */ + struct Heuristic { + size_t layout = 0; + size_t index_to_slot = 0; + size_t size_to_slot = 0; + + Heuristic max(Heuristic rhs) { + return Heuristic{ + std::max(layout, rhs.layout), + std::max(index_to_slot, rhs.index_to_slot), + std::max(size_to_slot, rhs.size_to_slot), + }; + } + }; + + FlatVariables() {} + FlatVariables(Heuristic heuristic) { + layout.reserve(heuristic.layout); + index_to_slot.reserve(heuristic.index_to_slot); + size_to_slot.reserve(heuristic.size_to_slot); + } + + FlatVariables(const FlatVariables &) = delete; + FlatVariables &operator=(const FlatVariables &) = delete; + + FlatVariables(FlatVariables &&) = default; + FlatVariables &operator=(FlatVariables &&) = default; + + ~FlatVariables(); + + void clear() { + layout_index = 0; + variables.clear(); + index_to_slot.clear(); + layout.clear(); + backend = JitBackend::None; + } + /// Borrow all variables held by this struct. + void borrow(); + /// Release all variables held by this struct. + void release(); + + /** + * Generates a mask of variables that should be made opaque in the next + * iteration. This should only be called if \c compatible_auto_opaque + * returns true for the corresponding \c FlatVariables pair. + * + * Returns true if new variables have been discovered that should be made + * opaque, otherwise returns false. + */ + bool fill_opaque_mask(FlatVariables &prev, + drjit::vector &opaque_mask); + + /** + * Schedule variables that have been collected when traversing the PyTree. + * + * This function iterates over all ``Layout`` nodes that represent JIT + * indices and either calls ``jit_var_schedule`` or + * ``jit_var_schedule_force`` on them, depending on whether + * ``schedule_force`` is true or the boolean in the ``opaque_mask`` + * corresponding to that variable is true. + * + * \param schedule_force + * Overrides the use of \c opaque_mask and makes all variables opaque + * + * \param opaque_mask + * A pointer to a compatible boolean array, indicating if some of the + * variables should be made opaque. Can be \c nullptr, in which case it + * will be ignored. + */ + void schedule_jit_variables(bool schedule_force, + const drjit::vector *opaque_mask); + + /** + * \brief Records information about JIT variables that have been traversed. + * + * After traversing the PyTree, collecting non-literal indices in + * ``variables`` and evaluating the collected indices, we can collect + * information about the underlying variables. This information is used in + * the key of the ``RecordingMap`` to determine which recording should be + * replayed or if the function has to be re-traced. 
This function iterates + * over the collected indices and collects that information. + */ + void record_jit_variables(); + + /** + * Returns a struct representing heuristics to pre-allocate memory for the + * layout, of the flat variables. This accelerates subsequent traversals and + * replays. + */ + Heuristic heuristic() { + return Heuristic{ + layout.size(), + index_to_slot.size(), + size_to_slot.size(), + }; + }; + + /** + * \brief Add a variant domain pair to be traversed using the registry. + * + * When traversing a jit variable, that references a pointer to a class, + * such as a BSDF or Shape in Mitsuba, we have to traverse all objects + * registered with that variant-domain pair in the registry. This function + * adds the variant-domain pair, deduplicating the domain. Whether a + * variable references a class is represented by its ``IsClass`` const + * attribute. If the domain is an empty string (""), this function skips + * adding the variant-domain pair. + */ + void add_domain(const char *variant, const char *domain); + + /** + * Adds a JIT index to the flattened array, deduplicating it. + * This allows to check for aliasing conditions, where two variables + * actually refer to the same index. The function should only be called for + * scheduled non-literal variable indices. + */ + uint32_t add_jit_index(uint32_t variable_index); + + /** + * This function returns an index into the ``sizes`` vector, representing an + * equivalence class of variable sizes. It uses a HashMap and vector to + * deduplicate sizes. + * + * This is necessary, to catch cases, where two variables had the same size + * when recording a function and two different sizes when replaying. + * In that case one kernel would be recorded, that evaluates both variables. + * However, when replaying two kernels would have to be launched since the + * now differently sized variables cannot be evaluated by the same kernel. + */ + uint32_t add_size(uint32_t size); + + /** + * Traverse a variable referenced by a JIT index and add it to the flat + * variables. An optional Python type can be supplied if it is known. + * Depending on the ``TraverseContext::schedule_force`` the underlying + * variable is either scheduled (``jit_var_schedule``) or force scheduled + * (``jit_var_schedule_force``). If the variable after evaluation is a + * literal, it is directly recorded in the ``layout``, otherwise it is added + * to the ``variables`` array, allowing the variables to be used when + * recording the frozen function. + */ + void traverse_jit_index(uint32_t index, TraverseContext &ctx, + nb::handle tp = {}); + /** + * Add an AD variable by its index. Both the value and gradient are added + * to the flattened variables. If the AD index has been marked as postponed + * in the \c TraverseContext.postponed field, we mark the resulting layout + * with that flag. This will cause the gradient edges to be propagated when + * assigning to the input. The function takes an optional Python type if + * it is known. + */ + void traverse_ad_index(uint64_t index, TraverseContext &ctx, + nb::handle tp = {}); + + /** + * Wrapper around traverse_ad_index for a Python handle. + */ + void traverse_ad_var(nb::handle h, TraverseContext &ctx); + + /** + * Traverse a C++ tree using its `traverse_1_cb_ro` callback. + */ + void traverse_cb(const drjit::TraversableBase *traversable, + TraverseContext &ctx, nb::object type = nb::none()); + + /** + * Traverses a PyTree in DFS order, and records its layout in the + * `layout` vector. 
 +     *
+     * When hitting a Dr.Jit primitive type, it calls the
+     * `traverse_ad_var` method, which adds its indices to the flattened
+     * ``variables`` array. The traversal also records metadata
+     * about the Dr.Jit variable in the layout. Therefore, the layout can be
+     * used as a key to look up the recording of the frozen function.
+     */
+    void traverse(nb::handle h, TraverseContext &ctx);
+
+    /**
+     * First traverses the PyTree, then the registry. This ensures that
+     * additional data to vcalls is tracked correctly.
+     */
+    void traverse_with_registry(nb::handle h, TraverseContext &ctx);
+
+    /**
+     * Construct a variable, given its layout.
+     * This is the counterpart to `traverse_jit_index`.
+     *
+     * Optionally, the index of a variable can be provided that will be
+     * overwritten with the result of this function. In that case, the function
+     * will check for compatible variable types.
+     */
+    uint32_t construct_jit_index(uint32_t prev_index = 0);
+
+    /**
+     * Construct/assign the variable index given a layout.
+     * This corresponds to `traverse_ad_index`.
+     *
+     * This function is also used for assignment to AD variables.
+     * If a `prev_index` is provided, and it is an AD variable, the gradient
+     * and value of the flat variables will be applied to the AD variable,
+     * preserving the `ad_index`.
+     *
+     * It returns an owning reference.
+     */
+    uint64_t construct_ad_index(uint64_t prev_index = 0);
+
+    /**
+     * Construct an AD variable given its layout.
+     * This corresponds to `traverse_ad_var`.
+     */
+    nb::object construct_ad_var(const Layout &layout);
+
+    /**
+     * This is the counterpart to the traverse method, used to construct the
+     * output of a frozen function. Given a layout vector and flat_variables,
+     * it re-constructs the PyTree.
+     */
+    nb::object construct();
+
+    /**
+     * Assigns an AD variable.
+     * Corresponds to `traverse_ad_var`.
+     * This uses `construct_ad_index` to either construct a new AD variable or
+     * assign the value and gradient to an already existing one.
+     */
+    void assign_ad_var(Layout &layout, nb::handle dst);
+
+    /**
+     * Helper function, used to assign a callback variable.
+     *
+     * \param tmp
+     *     This vector is populated with the indices to variables that have
+     *     been constructed. It is required to release the references, since
+     *     the references created by `construct_ad_index` are owning and they
+     *     are borrowed after the callback returns.
+     */
+    uint64_t assign_cb_internal(uint64_t index, index64_vector &tmp);
+
+    /**
+     * Assigns variables using the object's `traverse_1_cb_rw` callback.
+     * This corresponds to `traverse_cb`.
+     */
+    void assign_cb(drjit::TraversableBase *traversable);
+
+    /**
+     * Assigns the flattened variables to an already existing PyTree.
+     * This is used when input variables have changed.
+     */
+    void assign(nb::handle dst, TraverseContext &ctx);
+
+    /**
+     * First assigns the registry and then the PyTree.
+     * Corresponds to `traverse_with_registry`.
+     */
+    void assign_with_registry(nb::handle dst, TraverseContext &ctx);
+
+    bool operator==(const FlatVariables &rhs) const {
+        return this->layout == rhs.layout &&
+               this->var_layout == rhs.var_layout && this->flags == rhs.flags;
+    }
+};
+
+/// Helper struct to hash input variables
+struct FlatVariablesHasher {
+    size_t operator()(const std::shared_ptr<FlatVariables> &key) const;
+};
+
+/// Helper struct to compare input variables
+struct FlatVariablesEqual {
+    using is_transparent = void;
+    bool operator()(const std::shared_ptr<FlatVariables> &lhs,
+                    const std::shared_ptr<FlatVariables> &rhs) const {
+        return *lhs.get() == *rhs.get();
+    }
+};
+
+/**
+ * \brief A recording of a frozen function, recorded with a certain layout of
+ * input variables.
+ */
+struct FunctionRecording {
+    /// The index of the \c call_counter when this recording was last used
+    /// (recorded or replayed). If the \c max_cache_size variable is set, this
+    /// will be used to evict the least recently used recording.
+    uint32_t last_used = 0;
+
+    /// The opaque JIT recording that was recorded with \c jit_freeze_start
+    /// and \c jit_freeze_stop, and is held by this wrapper.
+    Recording *recording = nullptr;
+
+    /// The layout of the output variables of this version of the function
+    /// recording. The JIT variables of this object have to be released after
+    /// use in \c record and \c replay.
+    FlatVariables out_variables;
+
+    FunctionRecording() : out_variables() {}
+    FunctionRecording(const FunctionRecording &) = delete;
+    FunctionRecording &operator=(const FunctionRecording &) = delete;
+    FunctionRecording(FunctionRecording &&) = default;
+    FunctionRecording &operator=(FunctionRecording &&) = default;
+
+    ~FunctionRecording() {
+        if (this->recording)
+            jit_freeze_destroy(this->recording);
+
+        this->recording = nullptr;
+    }
+
+    /// Clears the recording.
+    void clear() {
+        if (this->recording)
+            jit_freeze_destroy(this->recording);
+
+        this->recording = nullptr;
+        this->out_variables = FlatVariables();
+    }
+
+    /*
+     * Record a function, given its Python input and flattened input.
+     */
+    nb::object record(nb::callable func, FrozenFunction *frozen_func,
+                      nb::dict input, const FlatVariables &in_variables);
+    /*
+     * Replays the recording.
+     *
+     * This constructs the output and re-assigns the input.
+     */
+    nb::object replay(nb::callable func, FrozenFunction *frozen_func,
+                      nb::dict input, const FlatVariables &in_variables);
+};
+
+using RecordingMap = tsl::robin_map<std::shared_ptr<FlatVariables>,
+                                    std::unique_ptr<FunctionRecording>,
+                                    FlatVariablesHasher, FlatVariablesEqual>;
+
+} // namespace detail
+
+struct FrozenFunction {
+    /// The inner function that is wrapped by this frozen function.
+    nb::callable func;
+
+    /// Previously made recordings, keyed by the layout of the input
+    /// variables.
+    detail::RecordingMap recordings;
+
+    /// The layout of the previous recording, used for computing diffs and
+    /// auto-opaque masks.
+    std::shared_ptr<detail::FlatVariables> prev_key;
+
+    /// This is used by the auto opaque feature to tag variables that should be
+    /// made opaque before calling the function.
+    drjit::vector<bool> opaque_mask;
+
+    /// The number of times this function has been recorded. Note that this can
+    /// differ from the number of recordings actually cached in \c recordings
+    /// when dry-running a recording failed.
+    uint32_t recording_counter = 0;
+
+    /// A counter, incremented whenever this function is called. It is used to
+    /// determine the least recently used recording in order to evict it if the
+    /// \c max_cache_size is set.
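+    /// Each time a recording is made or replayed, its \c last_used field is
+    /// stamped with the current value of this counter, so the recording with
+    /// the smallest \c last_used value is the least recently used one.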
+    uint32_t call_counter = 0;
+
+    /// Maximum number of recordings that should be made before evicting the
+    /// least recently used one. If this value is -1, recordings can be made
+    /// without limit.
+    int max_cache_size = -1;
+
+    /// The number of recordings after which a warning message will be
+    /// displayed. This is useful to detect cases in which changing Python
+    /// values prevents replay.
+    uint32_t warn_recording_count = 10;
+
+    /// If no JIT variable inputs are given to the function, this indicates a
+    /// default backend on which the function is recorded and replayed.
+    JitBackend default_backend = JitBackend::None;
+
+    /// Whether the auto opaque feature is enabled. It allows us to find
+    /// literal values that change between calls to the frozen function and
+    /// selectively make them opaque.
+    bool auto_opaque = true;
+
+    /// The maximum sizes previously seen for the vectors in \c FlatVariables.
+    /// Pre-allocating these vectors helps with performance.
+    detail::FlatVariables::Heuristic in_heuristics;
+
+    FrozenFunction(nb::callable func, int max_cache_size = -1,
+                   uint32_t warn_recording_count = 10,
+                   JitBackend backend = JitBackend::None,
+                   bool auto_opaque = false)
+        : func(func), max_cache_size(max_cache_size),
+          warn_recording_count(warn_recording_count), default_backend(backend),
+          auto_opaque(auto_opaque) {}
+    ~FrozenFunction() {}
+
+    FrozenFunction(const FrozenFunction &) = delete;
+    FrozenFunction &operator=(const FrozenFunction &) = delete;
+    FrozenFunction(FrozenFunction &&) = default;
+    FrozenFunction &operator=(FrozenFunction &&) = default;
+
+    /// Returns the number of recordings currently cached.
+    uint32_t n_cached_recordings() { return this->recordings.size(); }
+
+    /// Clears the frozen function recordings and resets the counters.
+    void clear();
+
+    /// Operator to call the frozen function and either record a new version or
+    /// replay an old one. It expects a dictionary input containing the args,
+    /// kwargs, and closure of the Python function.
+    nb::object operator()(nb::dict input);
+};
+
+extern void export_freeze(nb::module_ &);
diff --git a/src/python/main.cpp b/src/python/main.cpp
index 7d34a7a02..0000d0b0f 100644
--- a/src/python/main.cpp
+++ b/src/python/main.cpp
@@ -8,7 +8,7 @@
  BSD-style license that can be found in the LICENSE.txt file.
 */
-#define NB_INTRUSIVE_EXPORT NB_EXPORT
+#define NB_INTRUSIVE_EXPORT NB_IMPORT
 
 #include
 #include
@@ -22,6 +22,7 @@
 #include "cuda.h"
 #include "reduce.h"
 #include "eval.h"
+#include "freeze.h"
 #include "iter.h"
 #include "init.h"
 #include "memop.h"
@@ -109,6 +110,9 @@ NB_MODULE(_drjit_ext, m_) {
         .value("SymbolicConditionals", JitFlag::SymbolicConditionals, doc_JitFlag_SymbolicConditionals)
         .value("SymbolicScope", JitFlag::SymbolicScope, doc_JitFlag_SymbolicScope)
         .value("ShaderExecutionReordering", JitFlag::ShaderExecutionReordering, doc_JitFlag_ShaderExecutionReordering)
+        .value("KernelFreezing", JitFlag::KernelFreezing, doc_JitFlag_KernelFreezing)
+        .value("FreezingScope", JitFlag::FreezingScope, doc_JitFlag_FreezingScope)
+        .value("EnableObjectTraversal", JitFlag::EnableObjectTraversal, doc_JitFlag_EnableObjectTraversal)
         .value("Default", JitFlag::Default, doc_JitFlag_Default)
 
         // Deprecated aliases
@@ -245,6 +249,7 @@ NB_MODULE(_drjit_ext, m_) {
     export_iter(detail);
     export_reduce(m);
     export_eval(m);
+    export_freeze(m);
     export_memop(m);
     export_slice(m);
     export_dlpack(m);
diff --git a/src/python/reduce.cpp b/src/python/reduce.cpp
index c492c7b42..36102c8b1 100644
--- a/src/python/reduce.cpp
+++ b/src/python/reduce.cpp
@@ -516,6 +516,18 @@ nb::object mean(nb::handle value, nb::handle axis, nb::handle mode) {
         return out;
     }
 
+    if (jit_flag(JitFlag::FreezingScope) && width(out) == 1 &&
+        width(value) > 1) {
+        // To avoid incorrect values when replaying frozen functions, we have
+        // to avoid baking the size of the array into the kernel as a literal.
+        // We therefore use ``opaque_n_elements`` (which builds on
+        // ``jit_var_opaque_width``) to compute the number of elements.
+        auto num_input = opaque_n_elements(value);
+        auto num_output = prod(shape(out), nb::none());
+
+        return (out * num_output) / num_input;
+    }
+
     // mean = sum / (num_input/num_output)
     return (out * prod(shape(out), nb::none())) /
            prod(shape(value), nb::none());
 }
diff --git a/src/python/shape.cpp b/src/python/shape.cpp
index 526043aea..690738359 100644
--- a/src/python/shape.cpp
+++ b/src/python/shape.cpp
@@ -9,8 +9,10 @@
 */
 
 #include "shape.h"
-#include "base.h"
 #include "apply.h"
+#include "base.h"
+#include "meta.h"
+#include
 
 Py_ssize_t sq_length(PyObject *o) noexcept {
     const ArraySupplement &s = supp(Py_TYPE(o));
@@ -178,12 +180,118 @@ size_t width(nb::handle h) {
     return to.width;
 }
-
 /// Return the vectorization width of the given input array or PyTree
 extern size_t width(nb::handle h);
+
+nb::object opaque_width(nb::handle h) {
+    struct TraverseOp : TraverseCallback {
+        bool ragged = false;
+        size_t width = 0, items = 0;
+        ArrayMeta meta;
+        uint64_t index;
+
+        void operator()(nb::handle h) override {
+            nb::handle tp = h.type();
+            const ArraySupplement &s = supp(tp);
+
+            if (s.index) {
+                index = s.index(inst_ptr(h));
+                meta = supp(tp);
+            }
+
+            size_t value = s.len(inst_ptr(h));
+            if (items++ == 0)
+                width = value;
+            else if (width != 1 && value != 1 && width != value)
+                ragged = true;
+            if (value > width)
+                width = value;
+        }
+
+        void traverse_unknown(nb::handle) override {
+            if (width == 0)
+                width = 1;
+            items++;
+        }
+    };
+
+    TraverseOp to;
+    traverse("drjit.opaque_width", to, h);
+    if (to.ragged)
+        nb::raise("drjit.opaque_width(): the input is ragged (i.e., it does not have a consistent size).");
+
+    uint32_t opaque_width = jit_var_opaque_width(to.index);
+
+    ArrayMeta meta = to.meta;
+    meta.type = (uint16_t) VarType::UInt32;
+
+    nb::handle width_tp = meta_get_type(meta);
+    const ArraySupplement width_s = supp(width_tp);
+
+    if
(!width_s.init_index)
+        nb::raise("drjit.opaque_width(): unsupported dtype.");
+
+    nb::object width = nb::inst_alloc(width_tp);
+    width_s.init_index(opaque_width, inst_ptr(width));
+    nb::inst_mark_ready(width);
+
+    jit_var_dec_ref(opaque_width);
+
+    return width;
+}
+
+/// Same as \c width, but returns the width as an opaque array, allowing this
+/// relationship to be recorded as part of a frozen function. Used in \c dr::mean.
+extern nb::object opaque_width(nb::handle h);
+
+
+/// Recursively traverses the PyTree of this object to compute the number of
+/// elements. If a leaf object is a JIT array, the result will be an opaque
+/// array.
+nb::object opaque_n_elements(nb::handle h) {
+    nb::handle tp = h.type();
+
+    // We use dr::shape() to test for ragged arrays
+    auto s = shape(h);
+
+    if (is_drjit_type(tp)) {
+
+        const ArraySupplement &s = supp(tp);
+
+        if (s.is_tensor)
+            return opaque_n_elements(nb::steal(s.tensor_array(h.ptr())));
+
+        if (!s.index)
+            jit_raise("opaque_n_elements(): Could not find indexing function");
+
+        uint32_t index = s.index(inst_ptr(h));
+
+        // Construct the opaque width Python object
+        uint32_t opaque_width = jit_var_opaque_width(index);
+
+        ArrayMeta meta = supp(tp);
+        meta.type = (uint16_t) VarType::UInt32;
+        nb::handle width_tp = meta_get_type(meta);
+        const ArraySupplement width_s = supp(width_tp);
+
+        nb::object width = nb::inst_alloc(width_tp);
+        width_s.init_index(opaque_width, inst_ptr(width));
+        nb::inst_mark_ready(width);
+
+        jit_var_dec_ref(opaque_width);
+
+        return width;
+    } else {
+        Py_ssize_t rv = PyObject_Length(h.ptr());
+
+        return opaque_n_elements(h[0]) * nb::int_(rv);
+    }
+}
+
 void export_shape(nb::module_ &m) {
     m.def("shape", &shape, doc_shape,
           nb::sig("def shape(arg: object) -> tuple[int, ...]"));
     m.def("width", &width, doc_width)
      .def("width", [](nb::args args) { return width(args); });
+    m.def("opaque_width", &opaque_width);
 }
diff --git a/src/python/shape.h b/src/python/shape.h
index bef0c327a..b804dbdca 100644
--- a/src/python/shape.h
+++ b/src/python/shape.h
@@ -24,6 +24,15 @@ extern size_t ndim(nb::handle_t h) noexcept;
 
 /// Return the vectorization width of the given input array or PyTree
 extern size_t width(nb::handle h);
+
+/// Same as \c width, but returns the width as an opaque array, allowing this
+/// relationship to be recorded as part of a frozen function. Used in \c dr::mean.
+extern nb::object opaque_width(nb::handle h);
+
+/// Recursively traverses the PyTree of this object to compute the number of
+/// elements. If a leaf object is a JIT array, the result will be an opaque
+/// array.
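+/// For example (illustrative): for an ``Array3f`` whose components have width
+/// ``n``, this yields an opaque ``UInt32`` holding ``3 * n``, so the element
+/// count is not baked into recorded kernels as a literal.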
+extern nb::object opaque_n_elements(nb::handle h); + /// Convert vector into a python tuple extern nb::tuple cast_shape(const vector &shape); diff --git a/src/python/texture.h b/src/python/texture.h index e8d009856..4068ec226 100644 --- a/src/python/texture.h +++ b/src/python/texture.h @@ -173,6 +173,8 @@ void bind_texture(nb::module_ &m, const char *name) { #undef def_tex_eval_cubic_helper tex.attr("IsTexture") = true; + + drjit::bind_traverse(tex); } template diff --git a/src/python/tracker.cpp b/src/python/tracker.cpp index f974cb0c9..7158bbac9 100644 --- a/src/python/tracker.cpp +++ b/src/python/tracker.cpp @@ -192,8 +192,8 @@ struct VariableTracker::Context { check_size(check_size), index_offset(0) { } // Internal API for type-erased traversal - uint64_t _traverse_write(uint64_t idx); - void _traverse_read(uint64_t index); + uint64_t _traverse_write(uint64_t idx, const char *, const char *); + void _traverse_read(uint64_t index, const char *, const char *); }; // Temporarily push a value onto the stack @@ -611,7 +611,8 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) { return changed; } -uint64_t VariableTracker::Context::_traverse_write(uint64_t idx) { +uint64_t VariableTracker::Context::_traverse_write(uint64_t idx, const char *, + const char *) { if (!idx) return 0; if (index_offset >= indices.size()) @@ -648,7 +649,7 @@ uint64_t VariableTracker::Context::_traverse_write(uint64_t idx) { return idx_new; } -void VariableTracker::Context::_traverse_read(uint64_t index) { +void VariableTracker::Context::_traverse_read(uint64_t index, const char *, const char *) { if (!index) return; indices.push_back(ad_var_inc_ref(index)); diff --git a/tests/call_ext.cpp b/tests/call_ext.cpp index 90dd869c2..affbbc05c 100644 --- a/tests/call_ext.cpp +++ b/tests/call_ext.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace nb = nanobind; namespace dr = drjit; @@ -13,30 +14,28 @@ namespace dr = drjit; using namespace nb::literals; template -struct Sampler { +struct Sampler : dr::TraversableBase { + Sampler() : rng(1) {} Sampler(size_t size) : rng(size) { } T next() { return rng.next_float32(); } - void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) const { - traverse_1_fn_ro(rng, payload, fn); - } - - void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) { - traverse_1_fn_rw(rng, payload, fn); - } - dr::PCG32> rng; + + DR_TRAVERSE_CB(dr::TraversableBase, rng); }; -template struct Base : nb::intrusive_base { +template struct Base : drjit::TraversableBase { using Mask = dr::mask_t; using UInt32 = dr::uint32_array_t; virtual std::pair f(Float x, Float y) = 0; virtual std::pair f_masked(const std::pair &xy, Mask active) = 0; virtual Float g(Float, Mask) = 0; + virtual Float h(Float) = 0; virtual Float nested(Float x, UInt32 s) = 0; + /// Nested vcall, using a member variable as a pointer. + virtual Float nested_self(Float x) = 0; virtual void dummy() = 0; virtual float scalar_getter() = 0; virtual Float opaque_getter() = 0; @@ -48,12 +47,19 @@ template struct Base : nb::intrusive_base { virtual void scatter_packet(UInt32, dr::Array) = 0; virtual void scatter_add_packet(UInt32, dr::Array) = 0; + static constexpr const char *variant_() { + return Float::Backend == JitBackend::CUDA ? 
"cuda" : "llvm"; + } + Base() { - if constexpr (dr::is_jit_v) - jit_registry_put("", "Base", this); + if constexpr (dr::is_jit_v){ + drjit::registry_put(Base::variant_(), "Base", this); + } } virtual ~Base() { jit_registry_remove(this); } + + DR_TRAVERSE_CB(drjit::TraversableBase) }; template struct A : Base { @@ -74,10 +80,18 @@ template struct A : Base { return value; } + virtual Float h(Float x) override{ + return value + x; + } + virtual Float nested(Float x, UInt32 /*s*/) override { return x + dr::gather(value, UInt32(0)); } + virtual Float nested_self(Float x) override { + return x + dr::gather(value, UInt32(0)); + } + virtual std::pair *, Float> sample(Sampler *s) override { return { s, s->next() }; } @@ -112,6 +126,8 @@ template struct A : Base { uint32_t scalar_property; Float value, extra_value; Float opaque = dr::opaque(1.f); + + DR_TRAVERSE_CB(Base, value, opaque) }; template struct B : Base { @@ -132,12 +148,22 @@ template struct B : Base { return value*x; } + virtual Float h(Float x) override{ + return value - x; + } + virtual Float nested(Float x, UInt32 s) override { using BaseArray = dr::replace_value_t*>; BaseArray self = dr::reinterpret_array(s); return self->nested(x, s); } + virtual Float nested_self(Float x) override { + using BaseArray = dr::replace_value_t*>; + BaseArray self = dr::reinterpret_array(this->s); + return self->nested(x, this->s); + } + virtual std::pair *, Float> sample(Sampler *s) override { return { s, 0 }; } @@ -160,14 +186,23 @@ template struct B : Base { Float value; Float opaque = dr::opaque(2.f); + UInt32 s; + + DR_TRAVERSE_CB(Base, value, opaque) }; +template constexpr const char *get_variant() { + return Float::Backend == JitBackend::CUDA ? "cuda" : "llvm"; +} + DRJIT_CALL_TEMPLATE_BEGIN(Base) DRJIT_CALL_METHOD(f) DRJIT_CALL_METHOD(f_masked) DRJIT_CALL_METHOD(dummy) DRJIT_CALL_METHOD(g) + DRJIT_CALL_METHOD(h) DRJIT_CALL_METHOD(nested) + DRJIT_CALL_METHOD(nested_self) DRJIT_CALL_METHOD(sample) DRJIT_CALL_METHOD(gather_packet) DRJIT_CALL_METHOD(scatter_packet) @@ -177,6 +212,8 @@ DRJIT_CALL_TEMPLATE_BEGIN(Base) DRJIT_CALL_GETTER(complex_getter) DRJIT_CALL_GETTER(constant_getter) DRJIT_CALL_METHOD(get_self) +public: + static constexpr const char *variant_() { return get_variant(); } DRJIT_CALL_END(Base) @@ -198,20 +235,23 @@ void bind(nb::module_ &m) { using Sampler = ::Sampler; auto sampler = nb::class_(m, "Sampler") + .def(nb::init<>()) .def(nb::init()) .def("next", &Sampler::next) .def_rw("rng", &Sampler::rng); bind_traverse(sampler); - nb::class_(m, "Base") + auto base_cls = nb::class_(m, "Base") .def("f", &BaseT::f) .def("f_masked", &BaseT::f_masked) .def("g", &BaseT::g) .def("nested", &BaseT::nested) + .def("nested_self", &BaseT::nested_self) .def("sample", &BaseT::sample); + bind_traverse(base_cls); - nb::class_(m, "A") + auto a_cls = nb::class_(m, "A") .def(nb::init<>()) .def("a_get_property", &AT::a_get_property) .def("a_gather_extra_value", &AT::a_gather_extra_value) @@ -219,11 +259,14 @@ void bind(nb::module_ &m) { .def_rw("value", &AT::value) .def_rw("extra_value", &AT::extra_value) .def_rw("scalar_property", &AT::scalar_property); + bind_traverse(a_cls); - nb::class_(m, "B") + auto b_cls = nb::class_(m, "B") .def(nb::init<>()) .def_rw("opaque", &BT::opaque) - .def_rw("value", &BT::value); + .def_rw("value", &BT::value) + .def_rw("s", &BT::s); + bind_traverse(b_cls); using BaseArray = dr::DiffArray; m.def("dispatch_f", [](BaseArray &self, Float a, Float b) { @@ -243,9 +286,13 @@ void bind(nb::module_ &m) { .def("g", [](BaseArray 
&self, Float x, Mask m) { return self->g(x, m); }, "x"_a, "mask"_a = true) + .def("h", [](BaseArray &self, Float x) { return self->h(x); }, "x"_a) .def("nested", [](BaseArray &self, Float x, UInt32 s) { return self->nested(x, s); }, "x"_a, "s"_a) + .def("nested_self", + [](BaseArray &self, Float x) { return self->nested_self(x); }, + "x"_a) .def("dummy", [](BaseArray &self) { return self->dummy(); }) .def("scalar_getter", [](BaseArray &self, Mask m) { return self->scalar_getter(m); diff --git a/tests/custom_type_ext.cpp b/tests/custom_type_ext.cpp index 50c936ee0..7bff08779 100644 --- a/tests/custom_type_ext.cpp +++ b/tests/custom_type_ext.cpp @@ -1,6 +1,10 @@ -#include #include #include +#include +#include +#include +#include +#include namespace nb = nanobind; namespace dr = drjit; @@ -42,6 +46,65 @@ struct CustomHolder { Value m_value; }; +class Object : public drjit::TraversableBase { + DR_TRAVERSE_CB(drjit::TraversableBase); +}; + +template +class CustomBase : public Object{ + Value m_base_value; + +public: + CustomBase(const Value &base_value) : Object(), m_base_value(base_value) {} + + Value &base_value() { return m_base_value; } + virtual Value &value() = 0; + + DR_TRAVERSE_CB(Object, m_base_value); +}; + +template +class PyCustomBase : public CustomBase{ +public: + using Base = CustomBase; + NB_TRAMPOLINE(Base, 1); + + PyCustomBase(const Value &base_value) : Base(base_value) {} + + Value &value() override { NB_OVERRIDE_PURE(value); } + + DR_TRAMPOLINE_TRAVERSE_CB(Base); +}; + +template +class CustomA: public CustomBase{ +public: + using Base = CustomBase; + + CustomA(const Value &value, const Value &base_value) : Base(base_value), m_value(value) {} + + Value &value() override { return m_value; } + +private: + Value m_value; + + DR_TRAVERSE_CB(Base, m_value); +}; + +template +class Nested: Object{ + using Base = Object; + + std::vector, size_t>> m_nested; + +public: + Nested(nb::ref a, nb::ref b) { + m_nested.push_back(std::make_pair(a, 0)); + m_nested.push_back(std::make_pair(b, 1)); + } + + DR_TRAVERSE_CB(Base, m_nested); +}; template void bind(nb::module_ &m) { dr::ArrayBinding b; @@ -64,12 +127,50 @@ template void bind(nb::module_ &m) { .def(nb::init()) .def("value", &CustomFloatHolder::value, nanobind::rv_policy::reference); + using CustomBase = CustomBase; + using PyCustomBase = PyCustomBase; + using CustomA = CustomA; + using Nested = Nested; + + auto object = nb::class_( + m, "Object", + nb::intrusive_ptr( + [](Object *o, PyObject *po) noexcept { o->set_self_py(po); })); + + auto base = + nb::class_(m, "CustomBase") + .def(nb::init()) + .def("value", nb::overload_cast<>(&CustomBase::value)) + .def("base_value", nb::overload_cast<>(&CustomBase::base_value)); + + drjit::bind_traverse(base); + + auto a = nb::class_(m, "CustomA") + .def(nb::init()); + + drjit::bind_traverse(a); + + auto nested = nb::class_(m, "Nested") + .def(nb::init, nb::ref>()); + + drjit::bind_traverse(nested); + m.def("cpp_make_opaque", [](CustomFloatHolder &holder) { dr::make_opaque(holder); } ); } NB_MODULE(custom_type_ext, m) { + nb::intrusive_init( + [](PyObject *o) noexcept { + nb::gil_scoped_acquire guard; + Py_INCREF(o); + }, + [](PyObject *o) noexcept { + nb::gil_scoped_acquire guard; + Py_DECREF(o); + }); + #if defined(DRJIT_ENABLE_LLVM) nb::module_ llvm = m.def_submodule("llvm"); bind(llvm); diff --git a/tests/test_custom_type_ext.py b/tests/test_custom_type_ext.py index 90c9b7a23..a997c2dda 100644 --- a/tests/test_custom_type_ext.py +++ b/tests/test_custom_type_ext.py @@ -1,7 +1,6 @@ import drjit 
as dr import pytest - def get_pkg(t): with dr.detail.scoped_rtld_deepbind(): m = pytest.importorskip("custom_type_ext") @@ -69,3 +68,135 @@ def test03_cpp_make_opaque(t): pkg.cpp_make_opaque(holder) assert holder.value().state == dr.VarState.Evaluated + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test04_traverse_opaque(t): + """ + Tests that it is possible to traverse an opaque C++ object. + """ + pkg = get_pkg(t) + Float = t + + value = dr.arange(Float, 10) + base_value = dr.arange(Float, 10) + + a = pkg.CustomA(value, base_value) + assert dr.detail.collect_indices(a) == [base_value.index, value.index] + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test05_traverse_py(t): + """ + Tests the implementation of ``traverse_py_cb_ro``, which is used to traverse + python objects in trampoline classes. + """ + Float = t + + v = dr.arange(Float, 10) + + class PyClass: + def __init__(self, v) -> None: + self.v = v + + c = PyClass(v) + + result = [] + + def callback(index, domain, variant): + result.append(index) + + dr.detail.traverse_py_cb_ro(c, callback) + + assert result == [v.index] + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test06_trampoline_traversal(t): + """ + Tests that classes inheriting from trampoline classes are traversed + automatically. + """ + pkg = get_pkg(t) + Float = t + + value = dr.opaque(Float, 0, 3) + base_value = dr.opaque(Float, 1, 3) + + class B(pkg.CustomBase): + def __init__(self, value, base_value) -> None: + super().__init__(base_value) + self._value = value + + def value(self): + return self._value + + b = B(value, base_value) + + assert dr.detail.collect_indices(b) == [base_value.index, value.index] + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test07_nested_traversal(t): + """ + Test traversal of nested objects, and more specifically the traversal of + ``std::vector, size_t>>`` members. + """ + pkg = get_pkg(t) + Float = t + + value = dr.arange(Float, 10) + 0 + base_value = dr.arange(Float, 10) + 1 + + a = pkg.CustomA(value, base_value) + + value = dr.arange(Float, 10) + 2 + base_value = dr.arange(Float, 10) + 3 + + b = pkg.CustomA(value, base_value) + + nested = pkg.Nested(a, b) + + indices_a = dr.detail.collect_indices(a) + indices_b = dr.detail.collect_indices(b) + indices_nested = dr.detail.collect_indices(nested) + + assert indices_nested == indices_a + indices_b + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test08_custom_type_refcycle(t): + """ + Tests that it is possible to collect indices from PyTrees with refcycles, + without throwing runaway recursion errors, if ``EnableObjectTraversal`` is + set to ``True``. 
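+    Here the flag is not set, so collecting indices from the cyclic PyTree is
+    expected to raise a ``RuntimeError`` rather than recursing endlessly.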
+ """ + pkg = get_pkg(t) + Float = t + + value = dr.opaque(Float, 0, 3) + base_value = dr.opaque(Float, 1, 3) + + class B(pkg.CustomBase): + def __init__(self, value, base_value) -> None: + super().__init__(base_value) + self._value = value + + def value(self): + return self._value + + class C(pkg.CustomBase): + def __init__(self, value, base_value, ref) -> None: + super().__init__(base_value) + self._value = value + self._ref = ref + + def value(self): + return self._value + + # Construct a reference cycle + b = B(value, base_value) + c = C(value, base_value, b) + b.child = c + + with pytest.raises(RuntimeError): + indices = dr.detail.collect_indices(b) + diff --git a/tests/test_freeze.py b/tests/test_freeze.py new file mode 100644 index 000000000..690bf0974 --- /dev/null +++ b/tests/test_freeze.py @@ -0,0 +1,3722 @@ +import drjit as dr +import pytest +from math import ceil +from dataclasses import dataclass +import sys + +def skip_if_coopvec_not_supported(t): + if dr.backend_v(t) == dr.JitBackend.CUDA: + if dr.detail.cuda_version() < (12, 8): + pytest.skip("CUDA driver does not support cooperative vectors (Driver R570) or later is required") + +def get_single_entry(x): + tp = type(x) + result = x + shape = dr.shape(x) + if len(shape) == 2: + result = result[shape[0] - 1] + if len(shape) == 3: + result = result[shape[0] - 1][shape[1] - 1] + return result + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test01_basic(t, auto_opaque): + """ + Tests a very basic frozen function, adding two integers x, y. + """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x, y): + return x + y + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + o0 = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + o0 = func(i0, i1) + assert dr.all(t(4, 4, 4) == o0) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test02_flush_kernel_cache(t, auto_opaque): + """ + Tests that flushing the kernel between recording and replaying a frozen + function causes the function to be re-traced. + """ + + def func(x, y): + return x + y + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + x = t(0, 1, 2) + y = t(2, 1, 0) + + res = frozen(x, y) + ref = func(x, y) + assert dr.all(res == ref) + + dr.flush_kernel_cache() + + x = t(1, 2, 3) + y = t(3, 2, 1) + + # Flushing the kernel cache should force a re-trace + res = frozen(x, y) + ref = func(x, y) + assert dr.all(res == ref) + assert frozen.n_recordings == 2 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test03_output_tuple(t, auto_opaque): + """ + Tests that returning tuples from frozen functions is possible. + """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x, y): + return (x + y, x * y) + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + (o0, o1) = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + assert dr.all(t(0, 1, 0) == o1) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + (o0, o1) = func(i0, i1) + assert dr.all(t(4, 4, 4) == o0) + assert dr.all(t(3, 4, 3) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test04_output_list(t, auto_opaque): + """ + Tests that returning lists from forzen functions is possible. 
+ """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x, y): + return [x + y, x * y] + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + [o0, o1] = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + assert dr.all(t(0, 1, 0) == o1) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + [o0, o1] = func(i0, i1) + assert dr.all(t(4, 4, 4) == o0) + assert dr.all(t(3, 4, 3) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test05_output_dict(t, auto_opaque): + """ + Tests that returning dictionaries from forzen functions is possible. + """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x, y): + return {"add": x + y, "mul": x * y} + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + o = func(i0, i1) + o0 = o["add"] + o1 = o["mul"] + assert dr.all(t(2, 2, 2) == o0) + assert dr.all(t(0, 1, 0) == o1) + + i0 = t(1, 2, 3) + i1 = t(3, 2, 1) + + o = func(i0, i1) + o0 = o["add"] + o1 = o["mul"] + assert dr.all(t(4, 4, 4) == o0) + assert dr.all(t(3, 4, 3) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test06_nested_tuple(t, auto_opaque): + """ + Tests that returning nested tuples from forzen functions is possible. + """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x): + return (x + 1, x + 2, (x + 3, x + 4)) + + i0 = t(0, 1, 2) + + (o0, o1, (o2, o3)) = func(i0) + assert dr.all(t(1, 2, 3) == o0) + assert dr.all(t(2, 3, 4) == o1) + assert dr.all(t(3, 4, 5) == o2) + assert dr.all(t(4, 5, 6) == o3) + + i0 = t(1, 2, 3) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test07_drjit_struct(t, auto_opaque): + """ + Tests that returning custom classes, annotated with ``DRJIT_STRUCT`` from + forzen functions is possible. + """ + + class Point: + x: t + y: t + DRJIT_STRUCT = {"x": t, "y": t} + + @dr.freeze(auto_opaque=auto_opaque) + def func(x): + p = Point() + p.x = x + 1 + p.y = x + 2 + return p + + i0 = t(0, 1, 2) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(1, 2, 3) == o0) + assert dr.all(t(2, 3, 4) == o1) + + i0 = t(1, 2, 3) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(2, 3, 4) == o0) + assert dr.all(t(3, 4, 5) == o1) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test08_dataclass(t, auto_opaque): + """ + Tests that returning custom dataclasses from forzen functions is possible. + """ + + @dataclass + class Point: + x: t + y: t + + @dr.freeze(auto_opaque=auto_opaque) + def func(x): + p = Point(x + 1, x + 2) + return p + + i0 = t(0, 1, 2) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(1, 2, 3) == o0) + assert dr.all(t(2, 3, 4) == o1) + + i0 = t(1, 2, 3) + + o = func(i0) + o0 = o.x + o1 = o.y + assert dr.all(t(2, 3, 4) == o0) + assert dr.all(t(3, 4, 5) == o1) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test09_traverse_cb(t, auto_opaque): + """ + Tests that passing opaque C++ objects to frozen functions is possible. + It should not be possible to return these from frozen functions. 
+ """ + pkg = get_pkg(t) + Sampler = pkg.Sampler + + def func(sampler): + return sampler.next() + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + sampler_frozen = Sampler(10) + sampler_func = Sampler(10) + + result1_frozen = frozen(sampler_frozen) + result1_func = func(sampler_func) + assert dr.allclose(result1_frozen, result1_func) + + sampler_frozen = Sampler(10) + sampler_func = Sampler(10) + + result2_frozen = frozen(sampler_frozen) + result2_func = func(sampler_func) + assert dr.allclose(result2_frozen, result2_func) + + assert frozen.n_recordings == 1 + + result3_frozen = frozen(sampler_frozen) + result3_func = func(sampler_func) + assert dr.allclose(result3_func, result3_frozen) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test10_scatter(t, auto_opaque): + """ + Tests that it is possible to scatter to the input of a frozen function, + while leaving variables depending on the input the same (scattering problem). + """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x): + dr.scatter(x, 0, dr.arange(t, 3)) + + x = t(0, 1, 2) + func(x) + + x = t(0, 1, 2) + y = x + 1 + z = x + w = t(x) + + func(x) + + assert dr.all(t(0, 0, 0) == x) + assert dr.all(t(1, 2, 3) == y) + assert dr.all(t(0, 0, 0) == z) + assert dr.all(t(0, 1, 2) == w) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test11_optimization(t, auto_opaque): + """ + Implements a simple gradient descent optimization of a variable in a + frozen function. This verifies that gradient descent kernels are evaluated + correctly. + """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(state, ref): + for k, x in state.items(): + dr.enable_grad(x) + loss = dr.mean(dr.square(x - ref)) + + dr.backward(loss) + + grad = dr.grad(x) + dr.disable_grad(x) + state[k] = x - grad + + state = {"x": t(0, 0, 0, 0)} + + ref = t(1, 1, 1, 1) + func(state, ref) + assert dr.allclose(t(0.5, 0.5, 0.5, 0.5), state["x"]) + + state = {"x": t(0, 0, 0, 0)} + ref = t(1, 1, 1, 1) + func(state, ref) + + assert dr.allclose(t(0.5, 0.5, 0.5, 0.5), state["x"]) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test12_resized(t, auto_opaque): + """ + Tests that it is possible to call a frozen function with inputs of different + size compared to the recording without having to re-trace the function. + """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x, y): + return x + y + + i0 = t(0, 1, 2) + i1 = t(2, 1, 0) + + o0 = func(i0, i1) + assert dr.all(t(2, 2, 2) == o0) + + i0 = dr.arange(t, 64) + dr.opaque(t, 0) + i1 = dr.arange(t, 64) + dr.opaque(t, 0) + r0 = i0 + i1 + dr.eval(i0, i1, r0) + + o0 = func(i0, i1) + assert dr.all(r0 == o0) + assert func.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test13_changed_input_dict(t, auto_opaque): + """ + Test that it is possible to pass a dictionary to a frozen function, that is + inserting the result at a new key in said dictionary. This ensures that the + input is evaluated correctly, and the dictionary is back-assigned to the input. 
+ """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(d: dict): + d["y"] = d["x"] + 1 + + x = t(0, 1, 2) + d = {"x": x} + + func(d) + assert dr.all(t(1, 2, 3) == d["y"]) + + x = t(1, 2, 3) + d = {"x": x} + + func(d) + assert dr.all(t(2, 3, 4) == d["y"]) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test14_changed_input_dataclass(t, auto_opaque): + """ + Tests that it is possible to asing to the input of a dataclass inside a + frozen function. This also relies on correct back-assignment of the input. + """ + + @dataclass + class Point: + x: t + + @dr.freeze(auto_opaque=auto_opaque) + def func(p: Point): + p.x = p.x + 1 + + p = Point(x=t(0, 1, 2)) + + func(p) + assert dr.all(t(1, 2, 3) == p.x) + + p = Point(x=t(1, 2, 3)) + + func(p) + assert dr.all(t(2, 3, 4) == p.x) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test15_kwargs(t, auto_opaque): + """ + Tests that it is possible to pass keyword arguments to a frozen function + that modifies them. + """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x=t(0, 1, 2)): + return x + 1 + + y = func(x=t(0, 1, 2)) + assert dr.all(t(1, 2, 3) == y) + + y = func(x=t(1, 2, 3)) + assert dr.all(t(2, 3, 4) == y) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test16_opaque(t, auto_opaque): + """ + Tests that changing from an opaque (1-sized array) to an array of size + larger than 1 causes the funcion to be re-traced. This is necessary, because + different kernels are compiled for the two cases. + """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x, y): + return x + y + + x = t(0, 1, 2) + dr.set_label(x, "x") + y = dr.opaque(t, 1) + dr.set_label(y, "y") + z = func(x, y) + assert dr.all(t(1, 2, 3) == z) + + x = t(1, 2, 3) + y = t(1, 2, 3) + z = func(x, y) + assert dr.all(t(2, 4, 6) == z) + + assert func.n_recordings == 2 + + +@pytest.test_arrays("float32, jit, -is_diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test17_performance(t, auto_opaque): + """ + Tests the performance of a frozen function versus a non-frozen function. + """ + import time + + n = 1024 + n_iter = 1_000 + n_iter_warmeup = 10 + + def func(x, y): + z = 0.5 + result = dr.fma(dr.square(x), y, z) + result = dr.sqrt(dr.abs(result) + dr.power(result, 10)) + result = dr.log(1 + result) + return result + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for name, fn in [("normal", func), ("frozen", frozen)]: + x = dr.arange(t, n) # + dr.opaque(t, i) + y = dr.arange(t, n) # + dr.opaque(t, i) + dr.eval(x, y) + for i in range(n_iter + n_iter_warmeup): + if i == n_iter_warmeup: + t0 = time.time() + + result = fn(x, y) + + dr.eval(result) + + dr.sync_thread() + elapsed = time.time() - t0 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test18_aliasing(t, auto_opaque): + """ + Tests that changing the inputs from being the same variable to two different + variables causes the function to be re-traced. 
+ """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x, y): + return x + y + + x = t(0, 1, 2) + y = x + z = func(x, y) + assert dr.all(t(0, 2, 4) == z) + + x = t(1, 2, 3) + y = x + z = func(x, y) + assert dr.all(t(2, 4, 6) == z) + + x = t(1, 2, 3) + y = t(2, 3, 4) + z = func(x, y) + assert dr.all(t(3, 5, 7) == z) + assert func.n_recordings == 2 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test19_non_jit_types(t, auto_opaque): + """ + Tests that it is possible to pass non-jit types such as integers to frozen + functions. + """ + + def func(x, y): + return x + y + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + x = t(1, 2, 3) + y = i + + res = frozen(x, y) + ref = func(x, y) + assert dr.all(res == ref) + + assert frozen.n_recordings == 3 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test20_literal(t, auto_opaque): + """ + Test that drjit literals, passed to frozen functions do not cause the + function to be re-traced if they change. This is enabled by making the input + opaque. + """ + + @dr.freeze(auto_opaque=auto_opaque) + def func(x, y): + z = x + y + w = t(1) + return z, w + + # Literals + x = t(0, 1, 2) + dr.set_label(x, "x") + y = t(1) + dr.set_label(y, "y") + z, w = func(x, y) + assert dr.all(z == t(1, 2, 3)) + assert dr.all(w == t(1)) + + x = t(0, 1, 2) + y = t(1) + z, w = func(x, y) + assert dr.all(z == t(1, 2, 3)) + assert dr.all(w == t(1)) + + assert func.n_recordings == 1 + + x = t(0, 1, 2) + y = t(2) + z = func(x, y) + + if auto_opaque: + assert func.n_recordings == 2 + else: + assert func.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test21_pointers(t, auto_opaque): + """ + Test that it is possible to gather from a same-sized variable. This tests + the kernel size inference algorithm as well as having two kernels in a + frozen function. + """ + UInt32 = dr.uint32_array_t(t) + + @dr.freeze(auto_opaque=auto_opaque) + def func(x): + idx = dr.arange(UInt32, 0, dr.width(x), 3) + + return dr.gather(t, x, idx) + + y = func(t(0, 1, 2, 3, 4, 5, 6)) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test22_gather_memcpy(t, auto_opaque): + """ + The gather operation might be elided in favor of a memcpy + if the index is a literal of size 1. + The source of the memcpy is however not known to the recording + mechansim as it might index into the source array. + """ + + def func(x, idx: int): + idx = t(idx) + return dr.gather(t, x, idx) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + x = dr.arange(t, i, 3 + i) + dr.make_opaque(x) + ref = func(x, 2) + result = frozen(x, 2) + + assert dr.all(ref == result) + + assert frozen.n_recordings == 1 + + +def get_pkg(t): + with dr.detail.scoped_rtld_deepbind(): + m = pytest.importorskip("call_ext") + backend = dr.backend_v(t) + if backend == dr.JitBackend.LLVM: + return m.llvm + elif backend == dr.JitBackend.CUDA: + return m.cuda + + +@pytest.mark.parametrize("symbolic", [True]) +@pytest.test_arrays("float32, jit, -is_diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test23_vcall(t, auto_opaque, symbolic): + """ + Tests a basic symbolic vcall being called inside a frozen function. 
+ """ + pkg = get_pkg(t) + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + Mask = dr.mask_t(t) + a, b = A(), B() + + c = BasePtr(a, a, None, a, a) + + xi = t(1, 2, 8, 3, 4) + yi = t(5, 6, 8, 7, 8) + + @dr.freeze(auto_opaque=auto_opaque) + def func(c, xi, yi): + return c.f(xi, yi) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + xo, yo = func(c, xi, yi) + + assert dr.all(xo == t(10, 12, 0, 14, 16)) + assert dr.all(yo == t(-1, -2, 0, -3, -4)) + + c = BasePtr(a, a, None, b, b) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + xo, yo = func(c, xi, yi) + + assert func.n_recordings == 1 + + assert dr.all(xo == t(10, 12, 0, 21, 24)) + assert dr.all(yo == t(-1, -2, 0, 3, 4)) + + +@pytest.mark.parametrize("symbolic", [True]) +@pytest.mark.parametrize("optimize", [True, False]) +@pytest.mark.parametrize("opaque", [True, False]) +@pytest.test_arrays("float32, -is_diff, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test24_vcall_optimize(t, auto_opaque, symbolic, optimize, opaque): + """ + Test a basic vcall being called inside a frozen function, with the + "OptimizeCalls" flag either being set or not set. As well as opaque and + non-opaque inputs. + """ + pkg = get_pkg(t) + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + Mask = dr.mask_t(t) + a, b = B(), B() + + dr.set_label(A.opaque, "A.opaque") + dr.set_label(B.opaque, "B.opaque") + + a.value = t(2) + b.value = t(3) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, a, a) + dr.set_label(c, "c") + + x = t(1, 2, 8, 3, 4) + dr.set_label(x, "x") + + def func(c, xi): + return c.g(xi) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + assert dr.all(xo == func(c, x)) + + a.value = t(3) + b.value = t(2) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, b, b) + dr.set_label(c, "c") + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + if not auto_opaque: + assert frozen.n_recordings == 1 + assert dr.all(xo == func(c, x)) + + +@pytest.mark.parametrize("symbolic", [True]) +@pytest.mark.parametrize("optimize", [True, False]) +@pytest.mark.parametrize("opaque", [True, False]) +@pytest.test_arrays("float32, -is_diff, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test25_multiple_vcalls(t, auto_opaque, symbolic, optimize, opaque): + """ + Test calling multiple vcalls in a frozen function, where the result of the + first is used as the input to the second. 
+ """ + pkg = get_pkg(t) + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + Mask = dr.mask_t(t) + a, b = B(), B() + + dr.set_label(A.opaque, "A.opaque") + dr.set_label(B.opaque, "B.opaque") + + a.value = t(2) + b.value = t(3) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, a, a) + dr.set_label(c, "c") + + x = t(1, 2, 8, 3, 4) + dr.set_label(x, "x") + + def func(c, xi): + x = c.h(xi) + dr.make_opaque(x) + return c.g(x) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + assert dr.all(xo == func(c, x)) + + a.value = t(3) + b.value = t(2) + + if opaque: + dr.make_opaque(a.value, b.value) + + c = BasePtr(a, a, None, b, b) + dr.set_label(c, "c") + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, symbolic): + with dr.scoped_set_flag(dr.JitFlag.OptimizeCalls, optimize): + xo = frozen(c, x) + + if not auto_opaque: + assert frozen.n_recordings == 1 + + assert dr.all(xo == func(c, x)) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test26_freeze(t, auto_opaque): + """ + Test freezing a simple frozen function. + """ + UInt32 = dr.uint32_array_t(t) + Float = dr.float32_array_t(t) + + @dr.freeze(auto_opaque=auto_opaque) + def my_kernel(x): + x_int = UInt32(x) + result = x * x + result_int = UInt32(result) + + return result, x_int, result_int + + for i in range(3): + x = Float([1.0, 2.0, 3.0]) + dr.opaque(Float, i) + + y, x_int, y_int = my_kernel(x) + dr.schedule(y, x_int, y_int) + assert dr.allclose(y, dr.square(x)) + assert dr.allclose(x_int, UInt32(x)) + assert dr.allclose(y_int, UInt32(y)) + + +@pytest.mark.parametrize("freeze_first", (True, False)) +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test27_calling_frozen_from_frozen(t, auto_opaque, freeze_first): + """ + Test calling a frozen function from within another frozen function. + The inner frozen function should behave as a normal function. + """ + mod = sys.modules[t.__module__] + Float = mod.Float32 + Array3f = mod.Array3f + n = 37 + x = dr.full(Float, 1.5, n) + dr.opaque(Float, 2) + y = dr.full(Float, 0.5, n) + dr.opaque(Float, 10) + dr.eval(x, y) + + @dr.freeze(auto_opaque=auto_opaque) + def fun1(x): + return dr.square(x) + + @dr.freeze(auto_opaque=auto_opaque) + def fun2(x, y): + return fun1(x) + fun1(y) + + # Calling a frozen function from a frozen function. + if freeze_first: + dr.eval(fun1(x)) + + result1 = fun2(x, y) + assert dr.allclose(result1, dr.square(x) + dr.square(y)) + + if not freeze_first: + # If the nested function hasn't been recorded yet, calling it + # while freezing the outer function shouldn't freeze it with + # those arguments. + # In other words, any freezing mechanism should be completely + # disabled while recording a frozen function. + # assert fun1.frozen.kernels is None + + # We should therefore be able to freeze `fun1` with a different + # type of argument, and both `fun1` and `fun2` should work fine. 
+ result2 = fun1(Array3f(0.5, x, y)) + assert dr.allclose(result2, Array3f(0.5 * 0.5, dr.square(x), dr.square(y))) + + result3 = fun2(2 * x, 0.5 * y) + assert dr.allclose(result3, dr.square(2 * x) + dr.square(0.5 * y)) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test28_recorded_size(t, auto_opaque): + """ + Tests that a frozen function, producing a variable with a constant size, + can be replayed and produces an output of the same size. + """ + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + Float = mod.Float + + @dr.freeze(auto_opaque=auto_opaque) + def fun(a): + x = t(dr.linspace(Float, -1, 1, 10)) + a + source = x + 2 * x + # source = get_single_entry(x + 2 * x) + index = dr.arange(UInt32, dr.width(source)) + active = index % UInt32(2) != 0 + + return dr.gather(Float, source, index, active) + + a = t(0.1) + res1 = fun(a) + res2 = fun(a) + res3 = fun(a) + + assert dr.allclose(res1, res2) + assert dr.allclose(res1, res3) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test29_with_gathers(t, auto_opaque): + """ + Test gathering from an array at every second index in a frozen function. + """ + import numpy as np + + n = 20 + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + + rng = np.random.default_rng(seed=1234) + shape = tuple(reversed(dr.shape(dr.zeros(t, n)))) + + def fun(x, idx): + active = idx % 2 != 0 + source = get_single_entry(x) + return dr.gather(type(source), source, idx, active=active) + + fun_frozen = dr.freeze(fun, auto_opaque=auto_opaque) + + # 1. Recording call + x1 = t(rng.uniform(low=-1, high=1, size=shape)) + idx1 = dr.arange(UInt32, n) + result1 = fun_frozen(x1, idx1) + assert dr.allclose(result1, fun(x1, idx1)) + + # 2. Different source as during recording + x2 = t(rng.uniform(low=-2, high=-1, size=shape)) + idx2 = idx1 + + result2 = fun_frozen(x2, idx2) + assert dr.allclose(result2, fun(x2, idx2)) + + x3 = x2 + idx3 = UInt32([i for i in reversed(range(n))]) + result3 = fun_frozen(x3, idx3) + assert dr.allclose(result3, fun(x3, idx3)) + + # 3. Same source as during recording + result4 = fun_frozen(x1, idx1) + assert dr.allclose(result4, result1) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test30_scatter_with_op(t, auto_opaque): + """ + Tests scattering into the input of a frozen function. + """ + import numpy as np + + n = 16 + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + + rng = np.random.default_rng(seed=1234) + + def func(x, idx): + active = idx % 2 != 0 + + result = x - 0.5 + dr.scatter(x, result, idx, active=active) + return result + + func_frozen = dr.freeze(func, auto_opaque=auto_opaque) + + # 1. Recording call + x1 = t(rng.uniform(low=-1, high=1, size=[n])) + x1_copy = t(x1) + x1_copy_copy = t(x1) + idx1 = dr.arange(UInt32, n) + + result1 = func_frozen(x1, idx1) + + # assert dr.allclose(x1, x1_copy) + assert dr.allclose(result1, func(x1_copy, idx1)) + + # 2. Different source as during recording + # TODO: problem: during trace, the actual x1 Python variable changes + # from index r2 to index r12 as a result of the `scatter`. + # But in subsequent launches, even if we successfully create a new + # output buffer equivalent to r12, it doesn't get assigned to `x2`. 
+ x2 = t(rng.uniform(low=-2, high=-1, size=[n])) + x2_copy = t(x2) + idx2 = idx1 + + result2 = func_frozen(x2, idx2) + assert dr.allclose(result2, func(x2_copy, idx2)) + # assert dr.allclose(x2, x2_copy) + + x3 = x2 + x3_copy = t(x3) + idx3 = UInt32([i for i in reversed(range(n))]) + result3 = func_frozen(x3, idx3) + assert dr.allclose(result3, func(x3_copy, idx3)) + # assert dr.allclose(x3, x3_copy) + + # # 3. Same source as during recording + result4 = func_frozen(x1_copy_copy, idx1) + assert dr.allclose(result4, result1) + # # assert dr.allclose(x1_copy_copy, x1) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test31_with_gather_and_scatter(t, auto_opaque): + """ + Tests a combination of scatters and gathers in a frozen function. + """ + + import numpy as np + + n = 20 + mod = sys.modules[t.__module__] + UInt32 = mod.UInt32 + + rng = np.random.default_rng(seed=1234) + shape = tuple(reversed(dr.shape(dr.zeros(t, n)))) + + def fun(x, idx): + active = idx % 2 != 0 + dest = get_single_entry(x) + + values = dr.gather(UInt32, idx, idx, active=active) + values = type(dest)(values) + dr.scatter(dest, values, idx, active=active) + return dest, values + + fun_frozen = dr.freeze(fun, auto_opaque=auto_opaque) + + # 1. Recording call + x1 = t(rng.uniform(low=-1, high=1, size=shape)) + x1_copy = t(x1) + x1_copy_copy = t(x1) + idx1 = dr.arange(UInt32, n) + + result1 = fun_frozen(x1, idx1) + assert dr.allclose(result1, fun(x1_copy, idx1)) + assert dr.allclose(x1, x1_copy) + + # 2. Different source as during recording + x2 = t(rng.uniform(low=-2, high=-1, size=shape)) + x2_copy = t(x2) + idx2 = idx1 + + result2 = fun_frozen(x2, idx2) + assert dr.allclose(result2, fun(x2_copy, idx2)) + assert dr.allclose(x2, x2_copy) + + x3 = x2 + x3_copy = t(x3) + idx3 = UInt32([i for i in reversed(range(n))]) + result3 = fun_frozen(x3, idx3) + assert dr.allclose(result3, fun(x3_copy, idx3)) + assert dr.allclose(x3, x3_copy) + + # 3. Same source as during recording + result4 = fun_frozen(x1_copy_copy, idx1) + assert dr.allclose(result4, result1) + assert dr.allclose(x1_copy_copy, x1) + + +@pytest.mark.parametrize("relative_size", ["<", "=", ">"]) +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test32_gather_only_pointer_as_input(t, auto_opaque, relative_size): + """ + Tests that it is possible to infer the launch size of kernels, if the width + of the resulting variable is a multiple/fraction of the variables from which + the result was gathered. + """ + mod = sys.modules[t.__module__] + Array3f = mod.Array3f + Float = mod.Float32 + UInt32 = mod.UInt32 + + import numpy as np + + rng = np.random.default_rng(seed=1234) + + if relative_size == "<": + + def fun(v): + idx = dr.arange(UInt32, 0, dr.width(v), 3) + return Array3f( + dr.gather(Float, v, idx), + dr.gather(Float, v, idx + 1), + dr.gather(Float, v, idx + 2), + ) + + elif relative_size == "=": + + def fun(v): + idx = dr.arange(UInt32, 0, dr.width(v)) // 2 + return Array3f( + dr.gather(Float, v, idx), + dr.gather(Float, v, idx + 1), + dr.gather(Float, v, idx + 2), + ) + + elif relative_size == ">": + + def fun(v): + max_width = dr.width(v) + idx = dr.arange(UInt32, 0, 5 * max_width) + # TODO(!): what can we do against this literal being baked into the kernel? 
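+            # One possible direction (sketch, not what this test does): make
+            # the width opaque, e.g. ``idx < dr.opaque(UInt32, max_width)``,
+            # so that it becomes a kernel input instead of a literal.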
+ active = (idx + 2) < max_width + return Array3f( + dr.gather(Float, v, idx, active=active), + dr.gather(Float, v, idx + 1, active=active), + dr.gather(Float, v, idx + 2, active=active), + ) + + fun_freeze = dr.freeze(fun, auto_opaque=auto_opaque) + + def check_results(v, result): + size = v.size + if relative_size == "<": + expected = v.T + if relative_size == "=": + idx = np.arange(0, size) // 2 + expected = v.ravel() + expected = np.stack( + [ + expected[idx], + expected[idx + 1], + expected[idx + 2], + ], + axis=0, + ) + elif relative_size == ">": + idx = np.arange(0, 5 * size) + mask = (idx + 2) < size + expected = v.ravel() + expected = np.stack( + [ + np.where(mask, expected[(idx) % size], 0), + np.where(mask, expected[(idx + 1) % size], 0), + np.where(mask, expected[(idx + 2) % size], 0), + ], + axis=0, + ) + + assert np.allclose(result.numpy(), expected) + + # Note: Does not fail for n=1 + n = 7 + + for i in range(3): + v = rng.uniform(size=[n, 3]) + result = fun(Float(v.ravel())) + check_results(v, result) + + for i in range(10): + if i <= 5: + n_lanes = n + else: + n_lanes = n + i + + v = rng.uniform(size=[n_lanes, 3]) + result = fun_freeze(Float(v.ravel())) + + expected_width = { + "<": n_lanes, + "=": n_lanes * 3, + ">": n_lanes * 3 * 5, + }[relative_size] + + # if i == 0: + # assert len(fun_freeze.frozen.kernels) + # for kernel in fun_freeze.frozen.kernels.values(): + # assert kernel.original_input_size == n * 3 + # if relative_size == "<": + # assert kernel.original_launch_size == expected_width + # assert kernel.original_launch_size_ratio == (False, 3, True) + # elif relative_size == "=": + # assert kernel.original_launch_size == expected_width + # assert kernel.original_launch_size_ratio == (False, 1, True) + # else: + # assert kernel.original_launch_size == expected_width + # assert kernel.original_launch_size_ratio == (True, 5, True) + + assert dr.width(result) == expected_width + if relative_size == ">" and n_lanes != n: + pytest.xfail( + reason="The width() of the original input is baked into the kernel to compute the `active` mask during the first launch, so results are incorrect once the width changes." 
+ ) + + check_results(v, result) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test33_multiple_kernels(t, auto_opaque): + def fn(x: dr.ArrayBase, y: dr.ArrayBase, flag: bool): + + # First kernel uses only `x` + quantity = 0.5 if flag else -0.5 + intermediate1 = x + quantity + intermediate2 = x * quantity + dr.eval(intermediate1, intermediate2) + + # Second kernel uses `x`, `y` and one of the intermediate result + result = intermediate2 + y + + # The function returns some mix of outputs + return intermediate1, None, y, result + + n = 15 + x = dr.full(t, 1.5, n) + dr.opaque(t, 0.2) + y = dr.full(t, 0.5, n) + dr.opaque(t, 0.1) + dr.eval(x, y) + + ref_results = fn(x, y, flag=True) + dr.eval(ref_results) + + fn_frozen = dr.freeze(fn, auto_opaque=auto_opaque) + for _ in range(2): + results = fn_frozen(x, y, flag=True) + assert dr.allclose(results[0], ref_results[0]) + assert results[1] is None + assert dr.allclose(results[2], y) + assert dr.allclose(results[3], ref_results[3]) + + for i in range(4): + new_y = y + float(i) + new_flag = (i % 2) == 0 + results = fn_frozen(x, new_y, flag=new_flag) + ref_results = fn(x, new_y, flag = new_flag) + assert dr.allclose(results[0], ref_results[0]) + assert results[1] is None + assert dr.allclose(results[2], new_y) + assert dr.allclose(results[3], ref_results[3]) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test34_global_flag(t, auto_opaque): + Float = t + + @dr.freeze(auto_opaque=auto_opaque) + def my_fn(a, b, c=0.5): + return a + b + c + + # Recording + one = Float([1.0] * 9) + result1 = my_fn(one, one, c=0.1) + assert dr.allclose(result1, 2.1) + + # Can change the type of an input + result2 = my_fn(one, one, c=Float(0.6)) + assert dr.allclose(result2, 2.6) + + assert my_fn.n_recordings == 2 + + # Disable frozen kernels globally, now the freezing + # logic should be completely bypassed + with dr.scoped_set_flag(dr.JitFlag.KernelFreezing, False): + result3 = my_fn(one, one, c=0.9) + assert dr.allclose(result3, 2.9) + + +# @pytest.mark.parametrize("struct_style", ["drjit", "dataclass"]) +@pytest.mark.parametrize("struct_style", ["drjit", "dataclass"]) +# @pytest.test_arrays("float32, llvm, jit, -is_diff, shape=(*)") +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test35_return_types(t, auto_opaque, struct_style): + # WARN: only working on CUDA! + mod = sys.modules[t.__module__] + Float = t + Array3f = mod.Array3f + UInt32 = mod.UInt32 + + import numpy as np + + if struct_style == "drjit": + + class ToyDataclass: + DRJIT_STRUCT: dict = {"a": Float, "b": Float} + a: Float + b: Float + + def __init__(self, a=None, b=None): + self.a = a + self.b = b + + else: + assert struct_style == "dataclass" + + @dataclass(frozen=True) + class ToyDataclass: + a: Float + b: Float + + # 1. 
+ @dr.freeze(auto_opaque=auto_opaque)
+ def toy1(x: Float):
+ y = x**2 + dr.sin(x)
+ z = x**2 + dr.cos(x)
+ return (x, y, z, ToyDataclass(a=x, b=y), {"x": x, "yi": UInt32(y)}, [[[[x]]]])
+
+ for i in range(2):
+ input = Float(np.full(17, i))
+ # input = dr.full(Float, i, 17)
+ result = toy1(input)
+ assert isinstance(result[0], Float)
+ assert isinstance(result[1], Float)
+ assert isinstance(result[2], Float)
+ assert isinstance(result[3], ToyDataclass)
+ assert isinstance(result[4], dict)
+ assert result[4].keys() == set(("x", "yi"))
+ assert isinstance(result[4]["x"], Float)
+ assert isinstance(result[4]["yi"], UInt32)
+ assert isinstance(result[5], list)
+ assert isinstance(result[5][0], list)
+ assert isinstance(result[5][0][0], list)
+ assert isinstance(result[5][0][0][0], list)
+
+ # 2. Side effects only, returning None
+ @dr.freeze(auto_opaque=auto_opaque)
+ def toy2(x: Float, target: Float):
+ dr.scatter(target, 0.5 + x, dr.arange(UInt32, dr.width(x)))
+ return None
+
+ for i in range(3):
+ input = Float([i] * 17)
+ target = dr.opaque(Float, 0, dr.width(input))
+ # target = dr.full(Float, 0, dr.width(input))
+ # target = dr.empty(Float, dr.width(input))
+
+ result = toy2(input, target)
+ assert dr.allclose(target, 0.5 + input)
+ assert result is None
+
+ # 3. DRJIT_STRUCT as input and returning nested dictionaries
+ @dr.freeze(auto_opaque=auto_opaque)
+ def toy3(x: Float, y: ToyDataclass):
+ x_d = dr.detach(x, preserve_type=False)
+ return {
+ "a": x,
+ "b": (x, UInt32(2 * y.a + y.b)),
+ "c": None,
+ "d": {
+ "d1": x + x,
+ "d2": Array3f(x_d, -x_d, 2 * x_d),
+ "d3": None,
+ "d4": {},
+ "d5": tuple(),
+ "d6": list(),
+ "d7": ToyDataclass(a=x, b=2 * x),
+ },
+ "e": [x, {"e1": None}],
+ }
+
+ for i in range(3):
+ input = Float([i] * 5)
+ input2 = ToyDataclass(a=input, b=Float(4.0))
+ result = toy3(input, input2)
+ assert isinstance(result, dict)
+ assert isinstance(result["a"], Float)
+ assert isinstance(result["b"], tuple)
+ assert isinstance(result["b"][0], Float)
+ assert isinstance(result["b"][1], UInt32)
+ assert result["c"] is None
+ assert isinstance(result["d"], dict)
+ assert isinstance(result["d"]["d1"], Float)
+ assert isinstance(result["d"]["d2"], Array3f)
+ assert result["d"]["d3"] is None
+ assert isinstance(result["d"]["d4"], dict) and len(result["d"]["d4"]) == 0
+ assert isinstance(result["d"]["d5"], tuple) and len(result["d"]["d5"]) == 0
+ assert isinstance(result["d"]["d6"], list) and len(result["d"]["d6"]) == 0
+ assert isinstance(result["d"]["d7"], ToyDataclass)
+ assert dr.allclose(result["d"]["d7"].a, input)
+ assert dr.allclose(result["d"]["d7"].b, 2 * input)
+ assert isinstance(result["e"], list)
+ assert isinstance(result["e"][0], Float)
+ assert isinstance(result["e"][1], dict)
+ assert result["e"][1]["e1"] is None
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test36_drjit_struct_and_matrix(t, auto_opaque):
+ package = sys.modules[t.__module__]
+ Float = package.Float
+ Array4f = package.Array4f
+ Matrix4f = package.Matrix4f
+
+ class MyTransform4f:
+ DRJIT_STRUCT = {
+ "matrix": Matrix4f,
+ "inverse": Matrix4f,
+ }
+
+ def __init__(self, matrix: Matrix4f = None, inverse: Matrix4f = None):
+ self.matrix = matrix
+ self.inverse = inverse
+
+ @dataclass(frozen=False)
+ class Camera:
+ to_world: MyTransform4f
+
+ @dataclass(frozen=False)
+ class Batch:
+ camera: Camera
+ value: float = 0.5
+ offset: float = 0.5
+
+ @dataclass(frozen=False)
+ class Result:
+
value: Float + constant: int = 5 + + def fun(batch: Batch, x: Array4f): + res1 = batch.camera.to_world.matrix @ x + res2 = batch.camera.to_world.matrix @ x + batch.offset + res3 = batch.value + x + res4 = Result(value=batch.value) + return res1, res2, res3, res4 + + fun_frozen = dr.freeze(fun, auto_opaque=auto_opaque) + + n = 7 + for i in range(4): + x = Array4f( + *(dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + k for k in range(4)) + ) + mat = Matrix4f( + *( + dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + ii + jj + for jj in range(4) + for ii in range(4) + ) + ) + trafo = MyTransform4f() + trafo.matrix = mat + trafo.inverse = dr.rcp(mat) + + batch = Batch( + camera=Camera(to_world=trafo), + value=dr.linspace(Float, -1, 0, n) - dr.opaque(Float, i), + ) + # dr.eval(x, trafo, batch.value) + + results = fun_frozen(batch, x) + expected = fun(batch, x) + + assert len(results) == len(expected) + for result_i, (value, expected) in enumerate(zip(results, expected)): + + assert type(value) == type(expected) + if isinstance(value, Result): + value = value.value + expected = expected.value + assert dr.allclose(value, expected), str(result_i) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test37_with_dataclass_in_out(t, auto_opaque): + mod = sys.modules[t.__module__] + Int32 = mod.Int32 + UInt32 = mod.UInt32 + Bool = mod.Bool + + @dataclass(frozen=False) + class MyRecord: + step_in_segment: Int32 + total_steps: UInt32 + short_segment: Bool + + def acc_fn(record: MyRecord): + record.step_in_segment += Int32(2) + return Int32(record.total_steps + record.step_in_segment) + + # Initialize MyRecord + n_rays = 100 + record = MyRecord( + step_in_segment=UInt32([1] * n_rays), + total_steps=UInt32([0] * n_rays), + short_segment=dr.zeros(Bool, n_rays), + ) + + # Create frozen kernel that contains another function + frozen_acc_fn = dr.freeze(acc_fn, auto_opaque=auto_opaque) + + accumulation = dr.zeros(UInt32, n_rays) + n_iter = 12 + for _ in range(n_iter): + accumulation += frozen_acc_fn(record) + + expected = 0 + for i in range(n_iter): + expected += 0 + 2 * (i + 1) + 1 + assert dr.all(accumulation == expected) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test38_allocated_scratch_buffer(t, auto_opaque): + """ + Frozen functions may want to allocate some scratch space, scatter to it + in a first kernel, and read / use the values later on. As long as the + size of the scratch space can be guessed (e.g. a multiple of the launch width, + or matching the width of an existing input), we should be able to support it. + + On the other hand, the "scattering to an unknown buffer" pattern may actually + be scattering to an actual pre-existing buffer, which the user simply forgot + to include in the `state` lambda. In order to catch that case, we at least + check that the "scratch buffer" was read from in one of the kernels. + Otherwise, we assume it was meant as a side-effect into a pre-existing buffer. + """ + mod = sys.modules[t.__module__] + # dr.set_flag(dr.JitFlag.KernelFreezing, False) + UInt32 = mod.UInt32 + + # Note: we are going through an object / method, otherwise the closure + # checker would already catch the `forgotten_target_buffer` usage. 
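+ # The DRJIT_STRUCT below deliberately omits `forgotten_target_buffer`, so
+ # scattering to it from `fn1` should be rejected, whereas `fn2` and `fn3`
+ # use freshly allocated scratch buffers and should be allowed.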
+ class Model: + DRJIT_STRUCT = { + "some_state": UInt32, + # "forgotten_target_buffer": UInt32, + } + + def __init__(self): + self.some_state = UInt32([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + self.forgotten_target_buffer = self.some_state + 1 + dr.eval(self.some_state, self.forgotten_target_buffer) + + @dr.freeze(auto_opaque=auto_opaque) + def fn1(self, x): + # Note: assuming here that the width of `forgotten_target_buffer` doesn't change + index = dr.arange(UInt32, dr.width(x)) % dr.width( + self.forgotten_target_buffer + ) + dr.scatter(self.forgotten_target_buffer, x, index) + + return 2 * x + + @dr.freeze(auto_opaque=auto_opaque) + def fn2(self, x): + # Scratch buffer with width equal to a state variable + scratch = dr.zeros(UInt32, dr.width(self.some_state)) + # Kernel 1: write to `scratch` + index = dr.arange(UInt32, dr.width(x)) % dr.width(self.some_state) + dr.scatter(scratch, x, index) + # Kernel 2: use values from `scratch` directly + result = dr.square(scratch) + # We don't actually return `scratch`, its lifetime is limited to the frozen function. + return result + + @dr.freeze(auto_opaque=auto_opaque) + def fn3(self, x): + # Scratch buffer with width equal to a state variable + scratch = dr.zeros(UInt32, dr.width(self.some_state)) + # Kernel 1: write to `scratch` + index = dr.arange(UInt32, dr.width(x)) % dr.width(self.some_state) + dr.scatter(scratch, x, index) + # Kernel 2: use values from `scratch` via a gather + result = x + dr.gather(UInt32, scratch, index) + # We don't actually return `scratch`, its lifetime is limited to the frozen function. + return result + + model = Model() + + # Suspicious usage, should not allow it to avoid silent surprising behavior + for i in range(4): + x = UInt32(list(range(i + 2))) + assert dr.width(x) < dr.width(model.forgotten_target_buffer) + + if dr.flag(dr.JitFlag.KernelFreezing): + with pytest.raises(RuntimeError): + result = model.fn1(x) + break + + else: + result = model.fn1(x) + assert dr.allclose(result, 2 * x) + + expected = UInt32(model.some_state + 1) + dr.scatter(expected, x, dr.arange(UInt32, dr.width(x))) + assert dr.allclose(model.forgotten_target_buffer, expected) + + # Expected usage, we should allocate the buffer and allow the launch + for i in range(4): + x = UInt32(list(range(i + 2))) # i+1 + assert dr.width(x) < dr.width(model.some_state) + result = model.fn2(x) + expected = dr.zeros(UInt32, dr.width(model.some_state)) + dr.scatter(expected, x, dr.arange(UInt32, dr.width(x))) + assert dr.allclose(result, dr.square(expected)) + + # Expected usage, we should allocate the buffer and allow the launch + for i in range(4): + x = UInt32(list(range(i + 2))) # i+1 + assert dr.width(x) < dr.width(model.some_state) + result = model.fn3(x) + assert dr.allclose(result, 2 * x) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test39_simple_reductions(t, auto_opaque): + import numpy as np + + mod = sys.modules[t.__module__] + Float = mod.Float32 + n = 37 + + @dr.freeze(auto_opaque=auto_opaque) + def simple_sum(x): + return dr.sum(x) + + @dr.freeze(auto_opaque=auto_opaque) + def simple_product(x): + return dr.prod(x) + + @dr.freeze(auto_opaque=auto_opaque) + def simple_min(x): + return dr.min(x) + + @dr.freeze(auto_opaque=auto_opaque) + def simple_max(x): + return dr.max(x) + + @dr.freeze(auto_opaque=auto_opaque) + def sum_not_returned_wide(x): + return dr.sum(x) + x + + @dr.freeze(auto_opaque=auto_opaque) + def sum_not_returned_single(x): + return dr.sum(x) + 4 + + def 
check_expected(fn, expected): + result = fn(x) + + assert dr.width(result) == dr.width(expected) + assert isinstance(result, Float) + assert dr.allclose(result, expected) + + for i in range(3): + x = dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + + x_np = x.numpy() + check_expected(simple_sum, np.sum(x_np).item()) + check_expected(simple_product, np.prod(x_np).item()) + check_expected(simple_min, np.min(x_np).item()) + check_expected(simple_max, np.max(x_np).item()) + + check_expected(sum_not_returned_wide, np.sum(x_np).item() + x) + check_expected(sum_not_returned_single, np.sum(x_np).item() + 4) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test40_prefix_reductions(t, auto_opaque): + import numpy as np + + mod = sys.modules[t.__module__] + Float = mod.Float32 + n = 37 + + @dr.freeze(auto_opaque=auto_opaque) + def prefix_sum(x): + return dr.prefix_reduce(dr.ReduceOp.Add, x, exclusive=False) + + def check_expected(fn, expected): + result = fn(x) + + assert dr.width(result) == dr.width(expected) + assert isinstance(result, Float) + assert dr.allclose(result, expected) + + for i in range(3): + x = dr.linspace(Float, 0, 1, n) + dr.opaque(Float, i) + + x_np = x.numpy() + check_expected(prefix_sum, Float(np.cumsum(x_np))) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test41_reductions_with_ad(t, auto_opaque): + Float = t + n = 37 + + @dr.freeze(auto_opaque=auto_opaque) + def sum_with_ad(x, width_opaque): + intermediate = 2 * x + 1 + dr.enable_grad(intermediate) + + result = dr.square(intermediate) + + # Unfortunately, as long as we don't support creating opaque values + # within a frozen kernel, we can't use `dr.mean()` directly. 
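+ # Dividing by an opaque width emulates dr.mean() without baking the
+ # array width into the recording as a literal.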
+ loss = dr.sum(result) / width_opaque
+ # loss = dr.mean(result)
+ dr.backward(loss)
+
+ return result, intermediate
+
+ @dr.freeze(auto_opaque=auto_opaque)
+ def product_with_ad(x):
+ dr.enable_grad(x)
+ loss = dr.prod(x)
+ dr.backward_from(loss)
+
+ for i in range(3):
+ x = dr.linspace(Float, 0, 1, n + i) + dr.opaque(Float, i)
+ result, intermediate = sum_with_ad(x, dr.opaque(Float, dr.width(x)))
+ assert dr.width(result) == n + i
+
+ assert dr.grad_enabled(result)
+ assert dr.grad_enabled(intermediate)
+ assert not dr.grad_enabled(x)
+ intermediate_expected = 2 * x + 1
+ assert dr.allclose(intermediate, intermediate_expected)
+ assert dr.allclose(result, dr.square(intermediate_expected))
+ assert sum_with_ad.n_recordings == 1
+ assert dr.allclose(dr.grad(result), 0)
+ assert dr.allclose(
+ dr.grad(intermediate), 2 * intermediate_expected / dr.width(x)
+ )
+
+ for i in range(3):
+ x = dr.linspace(Float, 0.1, 1, n + i) + dr.opaque(Float, i)
+ result = product_with_ad(x)
+
+ assert result is None
+ assert dr.grad_enabled(x)
+ with dr.suspend_grad():
+ expected_grad = dr.prod(x) / x
+ assert dr.allclose(dr.grad(x), expected_grad)
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test42_size_aliasing(t, auto_opaque):
+ def func(x, y):
+ return x + 1, y + 2
+
+ frozen_func = dr.freeze(func, auto_opaque=auto_opaque)
+
+ n = 3
+
+ for i in range(3):
+ x = dr.linspace(t, 0, 1, n) + dr.opaque(t, i)
+ y = dr.linspace(t, 0, 1, n + i) + dr.opaque(t, i)
+
+ result = frozen_func(x, y)
+
+ assert dr.allclose(result, func(x, y))
+
+ # We expect two recordings: one where the expressions x+1 and y+2 are
+ # compiled into the same kernel because x and y have the same size, and
+ # one where they are compiled separately because their sizes differ.
+ assert frozen_func.n_recordings == 2
+
+
+@pytest.test_arrays("float32, jit, -is_diff, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test43_pointer_aliasing(t, auto_opaque):
+ """
+ Dr.Jit employs a memory cache, which means that two variables may be
+ allocated the same memory region if one is destroyed before the other
+ is created. Since we track variables by their pointers in the
+ `RecordThreadState`, we have to update the `ptr_to_slot` mapping for
+ new variables.
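+ The frozen function below is therefore re-created in every iteration, so
+ that new recordings encounter freshly allocated variables whose pointers
+ may alias those of previously destroyed ones.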
+ """ + + n = 4 + + def func(x): + y = x + 1 + dr.make_opaque(y) + for i in range(3): + y = y + 1 + dr.make_opaque(y) + return y + + for i in range(10): + frozen_func = dr.freeze(func, auto_opaque=auto_opaque) + + x = dr.linspace(t, 0, 1, n) + dr.opaque(t, i) + assert dr.allclose(frozen_func(x), func(x)) + + x = dr.linspace(t, 0, 1, n) + dr.opaque(t, i) + assert dr.allclose(frozen_func(x), func(x)) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test44_simple_ad_fully_inside(t, auto_opaque): + mod = sys.modules[t.__module__] + Float = mod.Float + + def my_kernel(x): + dr.enable_grad(x) + + result = x * x + dr.backward(result) + + return result + + for start_enabled in (True, False): + # Re-freeze + my_kernel_frozen = dr.freeze(my_kernel, auto_opaque=auto_opaque) + + for i in range(3): + x = Float([1.0, 2.0, 3.0]) + dr.opaque(Float, i) + if start_enabled: + dr.enable_grad(x) + + y = my_kernel_frozen(x) + grad_x = dr.grad(x) + grad_y = dr.grad(y) + dr.schedule(y, grad_x, grad_y) + assert dr.allclose(y, dr.square(x)) + assert dr.allclose(grad_y, 0) + assert dr.allclose(grad_x, 2 * x) + + # Status of grad_enabled should be restored (side-effect of the function), + # even if it wasn't enabled at first + assert dr.grad_enabled(x) + + +@pytest.mark.parametrize("set_some_literal_grad", (False,)) +@pytest.mark.parametrize("inputs_end_enabled", (True, False)) +@pytest.mark.parametrize("inputs_start_enabled", (True,)) +@pytest.mark.parametrize("params_end_enabled", (False, False)) +@pytest.mark.parametrize("params_start_enabled", (True,)) +@pytest.mark.parametrize("freeze", (True,)) +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test45_suspend_resume( + t, + auto_opaque, + params_start_enabled, + params_end_enabled, + inputs_start_enabled, + inputs_end_enabled, + set_some_literal_grad, + freeze, +): + + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + log_level = dr.log_level() + + # TODO: remove this + # dr.set_flag(dr.JitFlag.KernelFreezing, False) + + class MyModel: + DRJIT_STRUCT = {"params": Float} + + def __init__(self, params): + self.params = params + self.frozen_eval = ( + dr.freeze(type(self).eval, auto_opaque=auto_opaque) + if freeze + else type(self).eval + ) + + def eval( + self, + x: Float, + params_end_enabled: bool, + inputs_end_enabled: bool, + set_some_literal_grad: bool, + ): + idx = dr.arange(UInt32, dr.width(x)) % dr.width(self.params) + latents = dr.gather(Float, self.params, idx) + result = x * latents + + with dr.resume_grad(): + dr.set_grad_enabled(self.params, params_end_enabled) + dr.set_grad_enabled(x, inputs_end_enabled) + if set_some_literal_grad: + # If grads are not enabled, this will get ignored, which is fine + dr.set_grad(x, Float(6.66)) + + return result + + model = MyModel(params=Float([1, 2, 3, 4, 5])) + + for i in range(3): + # Inputs of different widths + x = Float([0.1, 0.2, 0.3, 0.4, 0.5, 0.6] * (i + 1)) + dr.opaque(Float, i) + + dr.set_grad_enabled(model.params, params_start_enabled) + dr.set_grad_enabled(x, inputs_start_enabled) + + dr.eval(x, dr.grad(x)) + + with dr.suspend_grad(): + result = model.frozen_eval( + model, x, params_end_enabled, inputs_end_enabled, set_some_literal_grad + ) + + # dr.eval(result, model.params, dr.grad(model.params)) + assert not dr.grad_enabled(result) + assert dr.grad_enabled(model.params) == params_end_enabled + assert dr.grad_enabled(x) == 
inputs_end_enabled + + # The frozen function should restore the right width, even for a zero-valued literal. + # The default gradients are a zero-valued literal array + # with a width equal to the array's width + grads = dr.grad(model.params) + assert dr.width(grads) == dr.width(model.params) + assert dr.all(grads == 0) + + grads = dr.grad(x) + assert dr.width(grads) == dr.width(x) + if inputs_end_enabled and set_some_literal_grad: + assert dr.all(grads == 6.66) + else: + assert dr.all(grads == 0) + + if not auto_opaque: + assert model.frozen_eval.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +@pytest.mark.parametrize("freeze", (True,)) +@pytest.mark.parametrize("change_params_width", (False,)) +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test46_with_grad_scatter(t, auto_opaque, freeze: bool, change_params_width): + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + log_level = dr.log_level() + + class Model: + DRJIT_STRUCT = {"params": Float} + + def __init__(self, n): + self.params = Float(list(range(1, n + 1))) + assert dr.width(self.params) == n + dr.enable_grad(self.params) + + def __call__(self): + # Cheeky workaround for the frozen kernel signature checking + pass + + def my_kernel(model, x, opaque_params_width): + idx = dr.arange(UInt32, dr.width(x)) % opaque_params_width + + with dr.resume_grad(): + latents = dr.gather(Float, model.params, idx) + contrib = x * latents + dr.backward_from(contrib) + + return dr.detach(contrib) + + model = Model(5) + my_kernel_frozen = ( + dr.freeze(my_kernel, auto_opaque=auto_opaque) if freeze else my_kernel + ) + + for i in range(6): + # Different width at each iteration + x = Float([1.0, 2.0, 3.0] * (i + 1)) + dr.opaque(Float, i) + + # The frozen kernel should also support the params (and therefore its gradient buffer) + # changing width without issues. + if change_params_width and (i == 3): + model = Model(10) + # Reset gradients + dr.set_grad(model.params, 0) + assert dr.grad_enabled(model.params) + + with dr.suspend_grad(): + y = my_kernel_frozen(model, x, dr.opaque(UInt32, dr.width(model.params))) + assert not dr.grad_enabled(x) + assert not dr.grad_enabled(y) + assert dr.grad_enabled(model.params) + + grad_x = dr.grad(x) + grad_y = dr.grad(y) + grad_p = dr.grad(model.params) + # assert dr.allclose(y, dr.sqr(x)) + + # Expected grads + assert dr.allclose(grad_y, 0) + assert dr.allclose(grad_x, 0) + grad_p_expected = dr.zeros(Float, dr.width(model.params)) + idx = dr.arange(UInt32, dr.width(x)) % dr.width(model.params) + dr.scatter_reduce(dr.ReduceOp.Add, grad_p_expected, x, idx) + assert dr.allclose(grad_p, grad_p_expected) + + +@pytest.test_arrays("float32, jit, is_diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test47_tutorial_example(t, auto_opaque): + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + @dr.freeze(auto_opaque=auto_opaque) + def frozen_eval(inputs, idx, params, target_value, grad_factor): + intermediate = dr.gather(Float, params, idx) + result = 0.5 * dr.square(intermediate) * inputs + + # Since reductions are not supported yet, we cannot compute a single + # loss value here. It's not really a problem though, since DrJit can + # backpropagate starting from arrays of any widths. 
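+ # `grad_factor` is derived from an opaque width below, so this scaling
+ # does not bake the input width into the recording.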
+ loss_per_entry = dr.square(result - target_value) * grad_factor
+
+ # The gradients resulting from backpropagation will be directly accumulated
+ # (via dr.scatter_add()) into the gradient buffer of `params` (= `dr.grad(params)`).
+ dr.backward_from(loss_per_entry)
+
+ # It's fine to return the primal values of `result`, but keep in mind that they
+ # will not be differentiable w.r.t. `params`.
+ return dr.detach(result)
+
+ params = Float([1, 2, 3, 4, 5])
+
+ for _ in range(3):
+ dr.disable_grad(params)
+ dr.enable_grad(params)
+ assert dr.all(dr.grad(params) == 0)
+
+ inputs = Float([0.1, 0.2, 0.3])
+ idx = UInt32([1, 2, 3])
+ # Represents the optimizer's loss scale
+ grad_factor = 4096 / dr.opaque(Float, dr.width(inputs))
+
+ result = frozen_eval(
+ inputs, idx, params, target_value=0.5, grad_factor=grad_factor
+ )
+ assert not dr.grad_enabled(result)
+ # Gradients were correctly accumulated into `params`'s gradient buffer.
+ assert not dr.all(dr.grad(params) == 0)
+
+
+@pytest.test_arrays("uint32, jit, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test48_compress(t, auto_opaque):
+
+ mod = sys.modules[t.__module__]
+ Float = mod.Float32
+ UInt32 = mod.UInt32
+
+ pkg = get_pkg(t)
+ Sampler = pkg.Sampler
+
+ def func(sampler: Sampler) -> UInt32:
+ indices = dr.compress(sampler.next() < 0.5)
+ return indices
+
+ frozen = dr.freeze(func, auto_opaque=auto_opaque)
+
+ sampler_func = Sampler(10)
+ sampler_frozen = Sampler(10)
+ for i in range(3):
+ assert dr.all(frozen(sampler_frozen) == func(sampler_func))
+
+ sampler_func = Sampler(11)
+ sampler_frozen = Sampler(11)
+ for i in range(3):
+ assert dr.all(frozen(sampler_frozen) == func(sampler_func))
+
+ assert frozen.n_recordings == 1
+
+
+@pytest.test_arrays("uint32, llvm, -is_diff, jit, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test49_scatter_reduce_expanded(t, auto_opaque):
+
+ def func(target: t, src: t):
+ dr.scatter_reduce(dr.ReduceOp.Add, target, src, dr.arange(t, dr.width(src)) % 2)
+
+ frozen = dr.freeze(func, auto_opaque=auto_opaque)
+
+ for i in range(4):
+ src = dr.full(t, 1, 10 + i)
+ dr.make_opaque(src)
+
+ result = t([0] * (i + 2))
+ dr.make_opaque(result)
+ frozen(result, src)
+
+ reference = t([0] * (i + 2))
+ dr.make_opaque(reference)
+ func(reference, src)
+
+ assert dr.all(result == reference)
+
+ assert frozen.n_cached_recordings == 1
+ assert frozen.n_recordings == 4
+
+
+@pytest.test_arrays("uint32, llvm, -is_diff, jit, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test50_scatter_reduce_expanded_identity(t, auto_opaque):
+
+ def func(src: t):
+ target = dr.zeros(t, 5)
+ dr.scatter_reduce(dr.ReduceOp.Add, target, src, dr.arange(t, dr.width(src)) % 2)
+ return target
+
+ frozen = dr.freeze(func, auto_opaque=auto_opaque)
+
+ for i in range(4):
+ src = dr.full(t, 1, 10 + i)
+ dr.make_opaque(src)
+
+ result = frozen(src)
+
+ reference = func(src)
+
+ assert dr.all(result == reference)
+
+ assert frozen.n_recordings == 1
+
+
+@pytest.test_arrays("uint32, llvm, -is_diff, jit, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test51_scatter_reduce_expanded_no_memset(t, auto_opaque):
+
+ def func(src: t):
+ target = dr.full(t, 5)
+ dr.scatter_reduce(dr.ReduceOp.Add, target, src, dr.arange(t, dr.width(src)) % 2)
+ return target
+
+ frozen = dr.freeze(func, auto_opaque=auto_opaque)
+
+ for i in range(4):
+ src = dr.full(t, 1, 10 + i)
+ dr.make_opaque(src)
+
+ result = frozen(src)
+
+ reference = func(src)
+
+ assert dr.all(result == reference)
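+ # A single recording should suffice, even though the scatter target is
+ # initialized with a non-zero value.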
+ + assert frozen.n_recordings == 1 + assert frozen.n_cached_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test52_python_inputs(t, auto_opaque): + + def func(x: t, neg: bool): + if neg: + return -x + 1 + else: + return x + 1 + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + for neg in [False, True]: + x = t(1, 2, 3) + dr.opaque(t, i) + + res = frozen(x, neg) + ref = func(x, neg) + assert dr.all(res == ref) + + assert frozen.n_recordings == 2 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test53_scatter_inc(t, auto_opaque): + + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + n = 37 + + def acc_with_scatter_inc(x, counter, out_buffer, max_points): + active = x > 0.5 + out_idx = dr.scatter_inc(counter, UInt32(0), active=active) + active &= out_idx < max_points + + # TODO: also test within a loop + dr.scatter(out_buffer, x, out_idx, active=active) + + def test(i, func): + x = dr.linspace(Float, 0.1, 1, n + i) + dr.opaque(Float, i) / 100 + counter = UInt32(dr.opaque(UInt32, 0)) + out_buffer = dr.zeros(Float, 10) + max_points = UInt32(dr.opaque(UInt32, dr.width(out_buffer))) + + dr.set_label(x, "x") + dr.set_label(counter, "counter") + dr.set_label(out_buffer, "out_buffer") + dr.set_label(max_points, "max_points") + dr.eval(x, counter, out_buffer, max_points) + + func(x, counter, out_buffer, max_points) + + return out_buffer, counter + + def func(i): + return test(i, acc_with_scatter_inc) + + acc_with_scatter_inc = dr.freeze(acc_with_scatter_inc, auto_opaque=auto_opaque) + + def frozen(i): + return test(i, acc_with_scatter_inc) + + for i in range(3): + + res, _ = frozen(i) + ref, _ = func(i) + + # assert dr.all(res == ref) + + # Should have filled all of the entries + assert dr.all(res > 0.5) + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test54_read_while_frozen(t, auto_opaque): + # dr.set_flag(dr.JitFlag.KernelFreezing, True) + assert dr.flag(dr.JitFlag.KernelFreezing) + + def func(x): + return x[1] + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + x = t(1, 2, 3) + with pytest.raises(RuntimeError): + frozen(x) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test55_var_upload(t, auto_opaque): + def func(x): + + arrays = [] + + for i in range(3): + y = dr.arange(t, 3) + dr.make_opaque(y) + arrays.append(y) + + del arrays + del y + + return x / t(10, 10, 10) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + x = dr.arange(t, 3) + dr.make_opaque(x) + + # with pytest.raises(RuntimeError, match = "created while recording"): + with pytest.raises(RuntimeError): + z = frozen(x) + + # assert dr.allclose(z, func(x)) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test56_grad_isolate(t, auto_opaque): + dr.set_flag(dr.JitFlag.ReuseIndices, False) + + def f(x): + return x * 2 + + def g(y): + return y * 3 + + def func(y): + z = g(y) + dr.backward(z) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + + x = dr.arange(t, 3) + dr.make_opaque(x) + dr.enable_grad(x) + + y = f(x) + with dr.isolate_grad(): + func(y) + + ref = dr.grad(x) + + x = dr.arange(t, 3) + dr.make_opaque(x) + dr.enable_grad(x) + + y = f(x) + dr.make_opaque(y) + frozen(y) + 
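+ # The frozen function should have propagated the same gradients to `x`
+ # as the isolated reference above.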
+ res = dr.grad(x) + + assert dr.allclose(ref, res) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test57_isolate_grad_fwd(t, auto_opaque): + + def f(x): + return x * x + + def g(y): + return y * 2 + + def func(x): + with dr.isolate_grad(): + y = f(x) + dr.forward(x) + return y + + def frozen(x): + y = f(x) + dr.forward(x) + return y + + frozen = dr.freeze(frozen, auto_opaque=auto_opaque) + + for i in range(3): + x = t(i) + dr.make_opaque(x) + dr.enable_grad(x) + + y = func(x) + # z = g(y) + + ref = dr.grad(y) + + x = t(i) + dr.make_opaque(x) + dr.enable_grad(x) + + y = frozen(x) + # z = g(y) + + res = dr.grad(y) + + assert dr.allclose(ref, res) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test58_grad_postponed_part(t, auto_opaque): + dr.set_flag(dr.JitFlag.ReuseIndices, False) + + def f(x): + return x * x * 2 + + def g(y): + return y * y * 3 + + def func(y1, y2): + z1 = g(y1) + z2 = g(y2) + dr.backward(z1) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + def run(i, name, func): + x1 = dr.arange(t, 3) + i + dr.make_opaque(x1) + dr.enable_grad(x1) + y1 = f(x1) + + x2 = dr.arange(t, 3) + i + dr.make_opaque(x2) + dr.enable_grad(x2) + dr.set_grad(x2, 2) + y2 = f(x2) + + func(y1, y2) + + dx1 = dr.grad(x1) + dx2_1 = dr.grad(x2) + + dr.set_grad(x2, 1) + dr.backward(x2) + dx1_2 = dr.grad(x2) + + return [dx1, dx1_2, dx2_1] + + for i in range(3): + + def isolated(y1, y2): + with dr.isolate_grad(): + func(y1, y2) + + ref = run(i, "reference", isolated) + res = run(i, "frozen", frozen) + + for ref, res in zip(ref, res): + assert dr.allclose(ref, res) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", (False, True)) +def test59_nested(t, auto_opaque): + + pkg = get_pkg(t) + mod = sys.modules[t.__module__] + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + a, b = A(), B() + a.value = dr.ones(t, 16) + dr.enable_grad(a.value) + + U = mod.UInt32 + xi = t(1, 2, 8, 3, 4) + yi = dr.reinterpret_array(U, BasePtr(a, a, a, a, a)) + + def nested(self, xi, yi): + return self.nested(xi, yi) + + def func(c, xi, yi): + return dr.dispatch(c, nested, xi, yi) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + c = BasePtr(a, a, a, b, b) + xi = t(1, 2, 8, 3, 4) + yi = dr.reinterpret_array(U, BasePtr(a, a, a, a, a)) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, True): + xref = func(c, xi, yi) + + assert dr.all(xref == xi + 1) + + c = BasePtr(a, a, a, b, b) + xi = t(1, 2, 8, 3, 4) + yi = dr.reinterpret_array(U, BasePtr(a, a, a, a, a)) + + with dr.scoped_set_flag(dr.JitFlag.SymbolicCalls, True): + xfrozen = frozen(c, xi, yi) + + assert dr.all(xfrozen == xref) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test60_call_raise(t, auto_opaque): + + mod = sys.modules[t.__module__] + pkg = get_pkg(t) + + UInt = mod.UInt + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + a, b = A(), B() + + def f(x: t): + raise RuntimeError("test") + + def g(self, x: t): + if isinstance(self, B): + raise RuntimeError + return x + 1 + + c = BasePtr(a, a, a, b, b) + + with pytest.raises(RuntimeError): + dr.dispatch(c, g, t(1, 1, 2, 2, 2)) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test61_reduce_dot(t, auto_opaque): + def func(x, y): + return 
dr.dot(x, y) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + x = dr.arange(t, 10 + i) + y = dr.arange(t, 10 + i) + + result = frozen(x, y) + reference = func(x, y) + + assert dr.allclose(result, reference) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test62_clear(t, auto_opaque): + @dr.freeze(auto_opaque=auto_opaque) + def func(x): + return x + 1 + + x = dr.arange(t, 10) + y = func(x) + assert func.n_recordings == 1 + + func.clear() + assert func.n_recordings == 0 + + x = dr.arange(t, 10) + y = func(x) + assert func.n_recordings == 1 + + +@pytest.test_arrays("uint32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test63_method_decorator(t, auto_opaque): + mod = sys.modules[t.__module__] + + class Custom: + DRJIT_STRUCT = {"state": t} + + def __init__(self) -> None: + self.state = t([1, 2, 3]) + + @dr.freeze(auto_opaque=auto_opaque) + def frozen(self, x): + return x + self.state + + def func(self, x): + return x + self.state + + c = Custom() + for i in range(3): + x = dr.arange(t, 3) + i + dr.make_opaque(x) + res = c.frozen(x) + ref = c.func(x) + + assert dr.all(res == ref) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test64_tensor(t, auto_opaque): + """ + Tests that constructing tensors in frozen functions is possible, and does + not cause leaks. + """ + mod = sys.modules[t.__module__] + Float32 = mod.Float32 + TensorXf = mod.TensorXf + + def func(x): + return TensorXf(x + 1) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + x = dr.arange(Float32, 100) + ref = func(x) + res = frozen(x) + assert dr.all(res == ref) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test65_assign_tensor(t, auto_opaque): + """ + Tests that assigning tensors to the input of frozen functions is possible, + and does not cause leaks. + """ + mod = sys.modules[t.__module__] + Float32 = mod.Float32 + TensorXf = mod.TensorXf + + def func(x): + x += 1 + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + x = TensorXf(dr.arange(Float32, 100)) + func(x) + ref = x + + x = TensorXf(dr.arange(Float32, 100)) + frozen(x) + res = x + assert dr.all(res == ref) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test66_closure(t, auto_opaque): + + c1 = 1 + c2 = t(2) + + def func(x): + return x + c1 + c2 + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + x = dr.arange(t, i + 2) + ref = func(x) + + x = dr.arange(t, i + 2) + res = frozen(x) + + assert dr.allclose(ref, res) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test67_mutable_closure(t, auto_opaque): + """ + Test that it is possible to use and modify closures in frozen functions. 
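+ The closure variables are traversed as part of the function state, so
+ the in-place update performed while recording should also be applied
+ when the function is replayed.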
+ """ + y1 = t(1, 2, 3) + y2 = t(1, 2, 3) + + def func(x): + nonlocal y1 + y1 += x + + @dr.freeze(auto_opaque=auto_opaque) + def frozen(x): + nonlocal y2 + y2 += x + + for i in range(3): + x = t(i) + dr.make_opaque(x) + + func(x) + frozen(x) + + assert dr.allclose(y1, y2) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test68_state_decorator(t, auto_opaque): + mod = sys.modules[t.__module__] + Float = mod.Float32 + UInt32 = mod.UInt32 + + # Note: not a dataclass or DRJIT_STRUCT + class MyClass: + def __init__(self): + self.something1 = 4.5 + self.something2 = Float([1, 2, 3, 4, 5]) + + @dr.freeze( + auto_opaque=auto_opaque, state_fn=lambda self, *_, **__: (self.something2) + ) + def frozen(self, x: Float, idx: UInt32) -> Float: + return x * self.something1 * dr.gather(Float, self.something2, idx) + + def func(self, x: Float, idx: UInt32) -> Float: + return x * self.something1 * dr.gather(Float, self.something2, idx) + + c = MyClass() + + for i in range(3): + x = dr.arange(Float, i + 2) + idx = dr.arange(UInt32, i + 2) + + res = c.frozen(x, idx) + ref = c.func(x, idx) + + assert dr.allclose(res, ref) + assert c.frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("limit", (-1, None, 0, 1, 2)) +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test69_max_cache_size(t, auto_opaque, limit): + """ + Tests different cache size limitations for the frozen function. + """ + + def func(x, p): + return x + p + + frozen = dr.freeze(func, auto_opaque=auto_opaque, limit=limit) + + n = 3 + for i in range(n): + + x = t(i, i + 1, i + 2) + + res = frozen(x, i) + ref = func(x, i) + + assert dr.allclose(res, ref) + + if limit == -1 or limit is None: + assert frozen.n_recordings == n + assert frozen.n_cached_recordings == n + elif limit == 0: + assert frozen.n_recordings == 0 + assert frozen.n_cached_recordings == 0 + else: + assert frozen.n_recordings == n + assert frozen.n_cached_recordings == limit + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test69_lru_eviction(t, auto_opaque): + """ + Tests that the least recently used cache entry is evicted from the frozen + function if the cache size is limited. + """ + + def func(x, p): + return x + p + + frozen = dr.freeze(func, auto_opaque=auto_opaque, limit=2) + + x = t(0, 1, 2) + + # Create two entries in the cache + frozen(x, 0) + frozen(x, 1) + + # This should evict the first one + frozen(x, 2) + + assert frozen.n_recordings == 3 + + # p = 1 should still be in the cache, and calling it should not increment + # the recording counter. + frozen(x, 1) + + assert frozen.n_recordings == 3 + + # p = 0 should be evicted, and calling it will increment the counter + frozen(x, 0) + + assert frozen.n_recordings == 4 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test70_warn_recordings(t, auto_opaque): + """ + This test simply calls the frozen function with incompattible inputs, and + should print two warnings. 
+ """ + + def func(x, i): + return x + i + + frozen = dr.freeze(func, auto_opaque=auto_opaque, warn_after=2) + + for i in range(4): + x = t(1, 2, 3) + frozen(x, i) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("force_optix", [True, False]) +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test71_texture(t, auto_opaque, force_optix): + mod = sys.modules[t.__module__] + Texture1f = mod.Texture1f + Float = mod.Float32 + + def func(tex: Texture1f, pos: Float): + return tex.eval(pos) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + with dr.scoped_set_flag(dr.JitFlag.ForceOptiX, force_optix): + n = 4 + for i in range(3): + tex = Texture1f([2], 1, True, dr.FilterMode.Linear, dr.WrapMode.Repeat) + tex.set_value(t(0, 1)) + + pos = dr.arange(Float, i + 2) / n + + res = frozen(tex, pos) + ref = func(tex, pos) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings < n + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test72_no_input(t): + mod = sys.modules[t.__module__] + + backend = dr.backend_v(t) + + def func(): + return dr.arange(t, 10) + + frozen = dr.freeze(func, backend=backend) + + for i in range(3): + res = frozen() + ref = func() + + assert dr.allclose(res, ref) + + wrong_backend = ( + dr.JitBackend.CUDA if backend == dr.JitBackend.LLVM else dr.JitBackend.LLVM + ) + + frozen = dr.freeze(func, backend=wrong_backend) + + with pytest.raises(RuntimeError): + for i in range(3): + res = frozen() + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test73_opaque_width(t, auto_opaque): + def func(x: t): + return dr.mean(x) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + n = 3 + for i in range(n): + x = dr.arange(t, 3 + i) + + res = frozen(x) + ref = func(x) + + assert dr.allclose(ref, res) + + assert frozen.n_recordings < n + + +@pytest.test_arrays("float32, jit, shape=(*)") +def test74_auto_opaque_retraverse(t): + """ + Tests that the auto_oaque feature correctly evaluates the changing literal. + """ + + def func(x: t): + return x + 1 + + frozen = dr.freeze(func, auto_opaque=True) + + n = 3 + for i in range(n): + x = t(i) # Create as literal + + res = frozen(x) + ref = func(x) + + assert dr.allclose(ref, res) + + # The literal is made opaque by the auto_opaque feature before re-tracing + # the function in the second iteration (i = 1). + if i >= 1: + assert x.state == dr.VarState.Evaluated + else: + assert x.state == dr.VarState.Literal + + assert frozen.n_recordings == 2 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test75_changing_literal_width(t, auto_opaque): + """ + Tests that the auot opaque feature correctly forces evaluation of literals, + if the literal size changes between calls to the frozen function. + """ + + def func(x: t, lit: t): + return x + 1 + + frozen = dr.freeze(func, warn_after=3, auto_opaque=auto_opaque) + + n = 10 + for i in range(n): + lit = dr.zeros(t, (i + 1) * 10) + x = lit + 0.5 + dr.make_opaque(x) + + res = frozen(x, lit) + ref = func(x, lit) + + assert dr.allclose(ref, res) + + if auto_opaque: + # The literal is made opaque by the auto_opaque feature before re-tracing + # the function in the second iteration (i = 1). + if i >= 1: + assert lit.state == dr.VarState.Evaluated + else: + assert lit.state == dr.VarState.Literal + else: + # Otherwise, all literals are made opaque. 
+ assert lit.state == dr.VarState.Evaluated
+
+ if auto_opaque:
+ assert frozen.n_recordings == 2
+ else:
+ assert frozen.n_recordings == 1
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test76_changing_literal_width_holder(t, auto_opaque):
+ class MyHolder:
+ DRJIT_STRUCT = {"lit": t}
+
+ def __init__(self, lit):
+ self.lit = lit
+
+ def func(x: t, lit: MyHolder):
+ return x + 1
+
+ frozen = dr.freeze(func, warn_after=3, auto_opaque=auto_opaque)
+
+ n = 10
+ for i in range(n):
+ holder = MyHolder(dr.zeros(dr.tensor_t(t), (i + 1) * 10))
+ x = holder.lit + 0.5
+ dr.make_opaque(x)
+
+ res = frozen(x, holder)
+ ref = func(x, holder)
+
+ assert dr.allclose(ref, res)
+
+ if auto_opaque:
+ assert frozen.n_recordings == 2
+ else:
+ assert frozen.n_recordings == 1
+
+
+@pytest.test_arrays("float32, jit, diff, shape=(*)")
+@pytest.mark.parametrize("optimizer", ["sgd", "rmsprop", "adam"])
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test77_optimizers(t, optimizer, auto_opaque):
+ n = 10
+
+ def func(target, opt):
+ loss = dr.mean(dr.square(opt["x"] - target))
+
+ dr.backward(loss)
+
+ opt.step()
+
+ return [opt["x"], opt["y"]], loss
+
+ def init_optimizer():
+ if optimizer == "sgd":
+ opt = dr.opt.SGD(lr=0.001, momentum=0.9)
+ elif optimizer == "rmsprop":
+ opt = dr.opt.RMSProp(lr=0.001)
+ elif optimizer == "adam":
+ opt = dr.opt.Adam(lr=0.001)
+ return opt
+
+ frozen = dr.freeze(func)
+
+ opt_func = init_optimizer()
+ opt_frozen = init_optimizer()
+
+ for i in range(n):
+ x = dr.full(t, 1, 10)
+ y = dr.full(t, -1, 10)
+ target = dr.full(t, 0, 10)
+
+ opt_func["x"] = x
+ opt_frozen["x"] = x
+
+ opt_func["y"] = y
+ opt_frozen["y"] = y
+
+ opt_func.set_learning_rate({"x": 1e-4, "y": 1e-3})
+ opt_frozen.set_learning_rate({"x": 1e-4, "y": 1e-3})
+
+ res_params, res_loss = frozen(target, opt_frozen)
+ ref_params, ref_loss = func(target, opt_func)
+
+ assert dr.allclose(res_params, ref_params)
+ assert dr.allclose(res_loss, ref_loss)
+
+ if optimizer == "adam":
+ assert frozen.n_recordings == 2
+ else:
+ assert frozen.n_recordings == 1
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test78_hash_id_fallback(t, auto_opaque):
+ """
+ Tests the hash-to-id fallback for object hashing if the object is neither
+ traversable nor hashable.
+ """
+
+ n = 3
+
+ class Test:
+ x = 0
+
+ def __init__(self, x) -> None:
+ self.x = x
+
+ __hash__ = None
+
+ def func(x, test):
+ return x + test.x
+
+ frozen = dr.freeze(func)
+
+ for i in range(n):
+ x = dr.arange(t, 3)
+ y = Test(i)
+
+ res = frozen(x, y)
+ ref = func(x, y)
+
+ assert dr.allclose(res, ref)
+
+ assert frozen.n_recordings == 3
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test79_empty(t, auto_opaque):
+
+ n = 5
+
+ mod = sys.modules[t.__module__]
+
+ def func(x, i, v):
+ dr.scatter(x, v, i)
+
+ frozen = dr.freeze(func, auto_opaque=auto_opaque)
+
+ for i in range(n):
+ i = mod.UInt32(i)
+
+ res = dr.empty(t, n)
+ frozen(res, i, 1)
+
+ ref = dr.empty(t, n)
+ func(ref, i, 1)
+
+ assert res[i] == ref[i]
+
+
+@pytest.test_arrays("float32, jit, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test80_tensor_mean(t, auto_opaque):
+ """
+ Tests that the mean of a tensor inside a frozen function is computed
+ correctly when changing the last tensor dimension.
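+ Changing only the leading dimension should not force a re-recording,
+ whereas changing a trailing dimension should.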
+ """ + mod = sys.modules[t.__module__] + Float32 = mod.Float32 + TensorXf = mod.TensorXf + + def func(x): + return dr.mean(x) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + shape = ((i + 3), 10, 5) + x = TensorXf(dr.arange(Float32, dr.prod(shape)), shape=shape) + + res = frozen(x) + ref = func(x) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings == 1 + + # Changing the trailing dimensions should cause the function to be re-traced + for i in range(3): + shape = (10, (i + 3), 5) + x = TensorXf(dr.arange(Float32, dr.prod(shape)), shape=shape) + + res = frozen(x) + ref = func(x) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings == 4 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test81_changing_closures(t, auto_opaque): + + y = 1 + + def func(x): + return x + y + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + x = dr.arange(t, i + 3) + + res = frozen(x) + ref = func(x) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings == 1 + + for i in range(3): + x = dr.arange(t, i + 3) + + y += 1 + + res = frozen(x) + ref = func(x) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings == 4 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test82_changing_closures_methods(t, auto_opaque): + + y = 1 + + class Test: + def func(self, x): + return x + y + + @dr.freeze(auto_opaque=auto_opaque) + def frozen(self, x): + return x + y + + test = Test() + + for i in range(3): + x = dr.arange(t, i + 3) + + res = test.frozen(x) + ref = test.func(x) + + assert dr.allclose(res, ref) + + assert test.frozen.n_recordings == 1 + + for i in range(3): + x = dr.arange(t, i + 3) + + y += 1 + + res = test.frozen(x) + ref = test.func(x) + + assert dr.allclose(res, ref) + + assert test.frozen.n_recordings == 4 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test83_any(t, auto_opaque): + mod = sys.modules[t.__module__] + + def func(x): + return dr.any(x) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(4): + x = dr.zeros(mod.Bool, i + 3) + if i % 2: + x[2] = True + + res = frozen(x) + ref = func(x) + + dr.all(res == ref) + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test84_block_sum(t, auto_opaque): + """ + Tests a dry-run, resulting from a block sum. 
+ """ + mod = sys.modules[t.__module__] + + def func(x): + return dr.block_sum(x, 2) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(7): + x = dr.arange(t, i + 4) + + res = frozen(x) + ref = func(x) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings == 4 + assert frozen.n_cached_recordings == 1 + + +@pytest.test_arrays("float32, jit, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test85_dry_run_failure(t, auto_opaque): + """ + Test the dry-run block_sum + compress failure case + """ + n = 4 + + mod = sys.modules[t.__module__] + + def func(x): + y = dr.block_sum(x, 2) + return dr.compress(y > 3) + + frozen = dr.freeze(func, auto_opaque=auto_opaque, warn_after=0) + + for i in range(n): + x = dr.arange(t, i + 4) + + res = frozen(x) + ref = func(x) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings == n + assert frozen.n_cached_recordings == 1 + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test86_nested_vcalls(t, auto_opaque): + """ + Test that nested vcalls are in principle possible. + """ + mod = sys.modules[t.__module__] + + pkg = get_pkg(t) + mod = sys.modules[t.__module__] + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + a, b = A(), B() + a.value = dr.ones(t, 16) + + c = BasePtr(a, a, None, None, b, b) + s = BasePtr(a, b, None, b, b, a) + s = dr.reinterpret_array(mod.UInt32, s) + + def func(c, x, s): + return c.nested(x, s) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + + x = dr.arange(t, 6) + + res = frozen(c, x, s) + ref = func(c, x, s) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test87_nested_vcalls_member(t, auto_opaque): + """ + Tests that a vcall, using an opaque member as the pointer fails correctly. + """ + pkg = get_pkg(t) + mod = sys.modules[t.__module__] + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + a, b = A(), B() + a.value = dr.ones(t, 16) + + c = BasePtr(a, a, None, None, b, b) + + b.s = dr.reinterpret_array(mod.UInt32, BasePtr(a)) + dr.make_opaque(b.s) + + def func(c, x): + return c.nested_self(x) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + x = dr.arange(t, 6) + + with pytest.raises(RuntimeError): + frozen(c, x) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test88_tensor_indexing(t, auto_opaque): + """ + Tests that changes in the first dimension of a tensor do not cause re-tracing. + """ + mod = sys.modules[t.__module__] + + def func(x: mod.TensorXf, row: mod.UInt32, col: mod.UInt32): + return dr.gather(mod.Float, x.array, row * dr.shape(x)[1] + col) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + shape = ((i + 4), 10) + x = mod.TensorXf(dr.arange(mod.Float, dr.prod(shape)), shape=shape) + row = dr.arange(mod.UInt32, i + 3) + col = dr.arange(mod.UInt32, i + 3) + 1 + + res = frozen(x, row, col) + ref = func(x, row, col) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings == 1 + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test89_tensor_slicing(t, auto_opaque): + """ + Tests that changes in the first dimension of a tensor do not cause re-tracing, + and slicing works inside of frozen functions. 
+ """ + mod = sys.modules[t.__module__] + + def func(x: mod.TensorXf): + return x[:, 3, 1] + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + shape = ((i + 4), 10, 3) + x = mod.TensorXf(dr.arange(mod.Float, dr.prod(shape)), shape=shape) + + res = frozen(x) + ref = func(x) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings == 1 + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test90_tensor_slicing(t, auto_opaque): + """ + Tests dynamic indexing of tensors using the Dr.Jit's slicing implementation + inside of frozen functions. + """ + mod = sys.modules[t.__module__] + + def func(x: mod.TensorXf, row: mod.UInt32, col: mod.UInt32): + res = x[row, col] + return res + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + shape = ((i + 5), 10) + x = mod.TensorXf(dr.arange(mod.Float, dr.prod(shape)), shape=shape) + row = dr.arange(mod.UInt32, i + 4) + col = dr.arange(mod.UInt32, 3) + 1 + + res = frozen(x, row, col) + ref = func(x, row, col) + + assert dr.allclose(res, ref) + + assert frozen.n_recordings == 1 + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test91_grad_doc(t, auto_opaque): + """ + Tests the code snippet from the docs section on gradients. + """ + + @dr.freeze + def func(y): + # Some differentiable operation... + z = dr.mean(y) + # Propagate the gradients to the input of the function... + dr.backward(z) + + x = dr.arange(t, 3) + dr.enable_grad(x) + + y = dr.square(x) + + # The first time the function is called, it will be recorded and the correct + # gradients will be accumulated into x. + func(y) + + # Compare against manually calculated gradient + assert dr.allclose(dr.grad(x), 2 * 1 / dr.width(x) * x) + + dr.clear_grad(x) + + y = x * 2 + + # On subsequent calls the the function will be replayed, and gradients will + # be accumulated in x. + func(y) + + # Compare against manually calculated gradient + assert dr.allclose(dr.grad(x), [2 * 1 / dr.width(x)] * dr.width(x)) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test92_refcycle(t, auto_opaque): + """ + Tests that traversing PyTrees with direct ref cycles is possible. + """ + + def func(l: list): + l[0] += 1 + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + for i in range(3): + res = [dr.arange(t, i + 3)] + res.append(res) + frozen(res) + + ref = [dr.arange(t, i + 3)] + ref.append(ref) + func(ref) + + assert dr.allclose(ref[0], res[0]) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +def test93_auto_opaque_list(t): + """ + Tests that traversing PyTrees with direct ref cycles is possible. + """ + + @dr.freeze + def frozen(x, y, l, c): + return x + 1 + ... + + @dataclass + class MyClass: + z: t + + for i in range(3): + x = dr.arange(t, i + 2) + y = t(i) + l = [t(1), t(i)] + c = MyClass(t(i)) + + frozen(x, y, l, c=c) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test94_exception(t, auto_opaque): + """ + Tests that raising an exception inside a frozen function does not cause memory + leaks. 
+ """ + + @dr.freeze + def frozen(x): + raise RuntimeError("test") + return x + 1 + + x = dr.arange(t, 3) + + with pytest.raises(RuntimeError): + frozen(x) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test95_vcall_exception(t, auto_opaque): + """ + Tests that raising an exception inside a frozen function does not cause memory + leaks, even if the function uses a vcall. + """ + pkg = get_pkg(t) + mod = sys.modules[t.__module__] + + A, B, Base, BasePtr = pkg.A, pkg.B, pkg.Base, pkg.BasePtr + + a, b = A(), B() + a.value = dr.ones(t, 16) + + b.s = dr.reinterpret_array(mod.UInt32, BasePtr(a)) + dr.make_opaque(b.s) + + c = BasePtr(a, a, None, None, b, b) + + @dr.freeze + def frozen(c, x): + c.nested_self(x) + raise RuntimeError("test") + + x = dr.arange(t, 3) + + with pytest.raises(RuntimeError): + frozen(x) + + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +@pytest.mark.parametrize("layout", ["training", "inference"]) +def test96_coop_vec(t, auto_opaque, layout): + """ + Tests that it is possible to evaluate a neural network inside a frozen function. + """ + skip_if_coopvec_not_supported(t) + mod = sys.modules[t.__module__] + Float16 = mod.Float16 + ArrayXf = mod.ArrayXf + TensorXf16 = mod.TensorXf16 + + import drjit.nn as nn + from drjit.opt import Adam + + def func(net, x: ArrayXf): + return ArrayXf(net(nn.CoopVec(x))) + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + net = nn.Sequential( + nn.Cast(Float16), + nn.Linear(2, 32, bias=False), + nn.LeakyReLU(), + nn.Linear(-1, 3, bias=False), + nn.Exp(), + ) + + net = net.alloc(TensorXf16, 2) + + weights, net = nn.pack(net, layout=layout) + + for i in range(3): + x = dr.rand(ArrayXf, (2, 2 * i + 4)) + + res = frozen(net=net, x=x) + ref = func(net = net, x = x) + + assert dr.allclose(res, ref) + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test97_coop_vec_bwd(t, auto_opaque): + """ + Tests that it is possible to evaluate a neural network inside a frozen function, + and calculate gradients w.r.t. some loss. + """ + skip_if_coopvec_not_supported(t) + mod = sys.modules[t.__module__] + Float16 = mod.Float16 + ArrayXf = mod.ArrayXf + TensorXf16 = mod.TensorXf16 + + import drjit.nn as nn + from drjit.opt import Adam + + def func(net, x: ArrayXf): + y = ArrayXf(net(nn.CoopVec(x))) + + loss = dr.squared_norm(y - 1) + + dr.backward(loss) + + return loss + + frozen = dr.freeze(func, auto_opaque=auto_opaque) + + net = nn.Sequential( + nn.Cast(Float16), + nn.Linear(2, 32, bias=False), + nn.LeakyReLU(), + nn.Linear(-1, 3, bias=False), + nn.Exp(), + ) + + net = net.alloc(TensorXf16, 2) + + weights, net = nn.pack(net, layout="training") + + x = dr.rand(ArrayXf, (2, 2)) + + for i in range(3): + dr.enable_grad(weights) + dr.clear_grad(weights) + + x = dr.rand(ArrayXf, (2, i * 2 + 4)) + + loss_res = frozen(net, x) + grad_res = dr.grad(weights) + + dr.clear_grad(weights) + + loss_ref = frozen(net, x) + grad_ref = dr.grad(weights) + + assert dr.allclose(loss_res, loss_ref) + assert dr.allclose(grad_res, grad_ref) + + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test98_changing_list(t, auto_opaque): + """ + Tests that changing the number of elements in a list fails gracefully. 
+ """ + mod = sys.modules[t.__module__] + + @dr.freeze + def frozen(x: list): + x.append(x[0] + 1) + + x = [t(1, 2, 3)] + frozen(x) + + x = [t(1, 2, 3, 4)] + with pytest.raises(RuntimeError): + frozen(x) + +@pytest.test_arrays("float32, jit, diff, shape=(*)") +@pytest.mark.parametrize("auto_opaque", [False, True]) +def test99_construction_failure(t, auto_opaque): + """ + Tests that trying to return a variable that cannot be constructed from a frozen + function fails gracefully. + """ + mod = sys.modules[t.__module__] + + class NonConstructable: + DRJIT_STRUCT = { + "x": t, + } + + x: t + + def __init__(self, x: t): + self.x = x + + @dr.freeze + def frozen(x: t): + return NonConstructable(x) + + x = t(1, 2, 3) + with pytest.raises(RuntimeError): + frozen(x) + + x = t(1, 2, 3, 4) + with pytest.raises(RuntimeError): + frozen(x) + diff --git a/tests/while_loop_ext.cpp b/tests/while_loop_ext.cpp index 64613e7f3..e207a38a7 100644 --- a/tests/while_loop_ext.cpp +++ b/tests/while_loop_ext.cpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace nb = nanobind; namespace dr = drjit; @@ -37,11 +38,13 @@ struct Sampler { T next() { return rng.next_float32(); } - void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) const { + void traverse_1_cb_ro(void *payload, + dr::detail::traverse_callback_ro fn) const { traverse_1_fn_ro(rng, payload, fn); } - void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) { + void traverse_1_cb_rw(void *payload, + dr::detail::traverse_callback_rw fn) { traverse_1_fn_rw(rng, payload, fn); }